| 1 | package store |
| 2 | |
| 3 | import ( |
| 4 | "database/sql" |
| 5 | "fmt" |
| 6 | "regexp" |
| 7 | "strings" |
| 8 | "sync" |
| 9 | |
| 10 | _ "modernc.org/sqlite" |
| 11 | |
| 12 | "go.bigb.es/curator/internal/config" |
| 13 | ) |
| 14 | |
// Store is a database-backed module store implementing ModuleResolver.
// It supports both SQLite and PostgreSQL backends.
//
// Exact module entries are read from the database on every lookup, while
// module patterns are additionally kept in memory in compiled form so that
// resolution does not have to query and recompile them each time.
type Store struct {
	db      *sql.DB // database handle opened by Open
	dialect dialect // backend-specific driver name and SQL statements

	// Compiled pattern cache for fast resolution.
	// mu guards patterns: lookups take the read lock, refreshes take the
	// write lock only to swap in a freshly built slice.
	mu       sync.RWMutex
	patterns []compiledPattern
}
| 25 | |
// compiledPattern is a pattern row together with its compiled regular
// expression, as cached in Store.patterns.
type compiledPattern struct {
	PatternRow
	// re is compiled from PatternRow.Pattern with start/end anchors added;
	// its capture groups supply the {1}, {2}, ... template placeholders.
	re *regexp.Regexp
}
| 30 | |
| 31 | // Open opens a database with the given driver and DSN, then runs migrations. |
| 32 | // Supported drivers: "sqlite", "postgres". |
| 33 | // |
| 34 | // For SQLite the DSN is a file path (WAL mode and busy timeout are set automatically). |
| 35 | // For PostgreSQL the DSN is a connection string (e.g. "postgres://user:pass@host/db?sslmode=disable"). |
| 36 | func Open(driver, dsn string) (*Store, error) { |
| 37 | d, err := dialectFor(driver) |
| 38 | if err != nil { |
| 39 | return nil, err |
| 40 | } |
| 41 | |
| 42 | // SQLite: append pragmas if not already present. |
| 43 | if driver == "sqlite" && !strings.Contains(dsn, "_pragma") { |
| 44 | dsn = dsn + "?_pragma=journal_mode(wal)&_pragma=busy_timeout(5000)" |
| 45 | } |
| 46 | |
| 47 | db, err := sql.Open(d.driverName, dsn) |
| 48 | if err != nil { |
| 49 | return nil, fmt.Errorf("open database: %w", err) |
| 50 | } |
| 51 | |
| 52 | if err := db.Ping(); err != nil { |
| 53 | db.Close() |
| 54 | return nil, fmt.Errorf("ping database: %w", err) |
| 55 | } |
| 56 | |
| 57 | if _, err := db.Exec(d.migrationSQL); err != nil { |
| 58 | db.Close() |
| 59 | return nil, fmt.Errorf("migrate database: %w", err) |
| 60 | } |
| 61 | |
| 62 | // Add credential_name column to existing tables (idempotent). |
| 63 | addColumnIfNotExists(db, "modules", "credential_name", "TEXT NOT NULL DEFAULT ''") |
| 64 | addColumnIfNotExists(db, "module_patterns", "credential_name", "TEXT NOT NULL DEFAULT ''") |
| 65 | |
| 66 | st := &Store{db: db, dialect: d} |
| 67 | if err := st.refreshPatterns(); err != nil { |
| 68 | db.Close() |
| 69 | return nil, fmt.Errorf("load patterns: %w", err) |
| 70 | } |
| 71 | |
| 72 | return st, nil |
| 73 | } |
| 74 | |
| 75 | // Close closes the database connection. |
| 76 | func (st *Store) Close() error { |
| 77 | return st.db.Close() |
| 78 | } |
| 79 | |
| 80 | // ResolveModule looks up a module by name: exact match first, then patterns. |
| 81 | func (st *Store) ResolveModule(name string) (config.Module, bool) { |
| 82 | // Try exact match from database. |
| 83 | var mod config.Module |
| 84 | var private int |
| 85 | var credName string |
| 86 | err := st.db.QueryRow(st.dialect.resolveModuleSQL, name).Scan(&mod.VCS, &mod.Repo, &mod.Web, &private, &credName) |
| 87 | if err == nil { |
| 88 | mod.Private = private != 0 |
| 89 | mod.CredentialName = credName |
| 90 | return mod, true |
| 91 | } |
| 92 | |
| 93 | // Try compiled patterns. |
| 94 | st.mu.RLock() |
| 95 | defer st.mu.RUnlock() |
| 96 | |
| 97 | for _, cp := range st.patterns { |
| 98 | matches := cp.re.FindStringSubmatch(name) |
| 99 | if matches == nil { |
| 100 | continue |
| 101 | } |
| 102 | |
| 103 | groups := matches[1:] |
| 104 | return config.Module{ |
| 105 | VCS: expandTemplate(cp.VCS, name, groups), |
| 106 | Repo: expandTemplate(cp.Repo, name, groups), |
| 107 | Web: expandTemplate(cp.Web, name, groups), |
| 108 | Private: cp.Private, |
| 109 | CredentialName: cp.CredentialName, |
| 110 | }, true |
| 111 | } |
| 112 | |
| 113 | return config.Module{}, false |
| 114 | } |
| 115 | |
| 116 | // ListModules returns all modules. |
| 117 | func (st *Store) ListModules() ([]ModuleRow, error) { |
| 118 | rows, err := st.db.Query("SELECT name, vcs, repo, web, private, credential_name, created_at FROM modules ORDER BY name") |
| 119 | if err != nil { |
| 120 | return nil, err |
| 121 | } |
| 122 | defer rows.Close() |
| 123 | |
| 124 | var modules []ModuleRow |
| 125 | for rows.Next() { |
| 126 | var m ModuleRow |
| 127 | var private int |
| 128 | if err := rows.Scan(&m.Name, &m.VCS, &m.Repo, &m.Web, &private, &m.CredentialName, &m.CreatedAt); err != nil { |
| 129 | return nil, err |
| 130 | } |
| 131 | m.Private = private != 0 |
| 132 | modules = append(modules, m) |
| 133 | } |
| 134 | return modules, rows.Err() |
| 135 | } |
| 136 | |
| 137 | // AddModule adds a module to the database. |
| 138 | func (st *Store) AddModule(m ModuleRow) error { |
| 139 | private := 0 |
| 140 | if m.Private { |
| 141 | private = 1 |
| 142 | } |
| 143 | |
| 144 | _, err := st.db.Exec(st.dialect.addModuleSQL, m.Name, m.VCS, m.Repo, m.Web, private, m.CredentialName) |
| 145 | return err |
| 146 | } |
| 147 | |
| 148 | // UpdateModule updates a module's fields (except name, which is the PK). |
| 149 | func (st *Store) UpdateModule(m ModuleRow) error { |
| 150 | private := 0 |
| 151 | if m.Private { |
| 152 | private = 1 |
| 153 | } |
| 154 | |
| 155 | result, err := st.db.Exec(st.dialect.updateModuleSQL, m.VCS, m.Repo, m.Web, private, m.CredentialName, m.Name) |
| 156 | if err != nil { |
| 157 | return err |
| 158 | } |
| 159 | |
| 160 | n, err := result.RowsAffected() |
| 161 | if err != nil { |
| 162 | return err |
| 163 | } |
| 164 | if n == 0 { |
| 165 | return sql.ErrNoRows |
| 166 | } |
| 167 | return nil |
| 168 | } |
| 169 | |
| 170 | // DeleteModule removes a module from the database. |
| 171 | func (st *Store) DeleteModule(name string) error { |
| 172 | result, err := st.db.Exec(st.dialect.deleteModuleSQL, name) |
| 173 | if err != nil { |
| 174 | return err |
| 175 | } |
| 176 | |
| 177 | n, err := result.RowsAffected() |
| 178 | if err != nil { |
| 179 | return err |
| 180 | } |
| 181 | if n == 0 { |
| 182 | return sql.ErrNoRows |
| 183 | } |
| 184 | return nil |
| 185 | } |
| 186 | |
| 187 | // ListPatterns returns all module patterns ordered by priority. |
| 188 | func (st *Store) ListPatterns() ([]PatternRow, error) { |
| 189 | rows, err := st.db.Query("SELECT id, pattern, vcs, repo, web, private, priority, credential_name, created_at FROM module_patterns ORDER BY priority, id") |
| 190 | if err != nil { |
| 191 | return nil, err |
| 192 | } |
| 193 | defer rows.Close() |
| 194 | |
| 195 | var patterns []PatternRow |
| 196 | for rows.Next() { |
| 197 | var p PatternRow |
| 198 | var private int |
| 199 | if err := rows.Scan(&p.ID, &p.Pattern, &p.VCS, &p.Repo, &p.Web, &private, &p.Priority, &p.CredentialName, &p.CreatedAt); err != nil { |
| 200 | return nil, err |
| 201 | } |
| 202 | p.Private = private != 0 |
| 203 | patterns = append(patterns, p) |
| 204 | } |
| 205 | return patterns, rows.Err() |
| 206 | } |
| 207 | |
| 208 | // AddPattern adds a module pattern to the database and refreshes the cache. |
| 209 | func (st *Store) AddPattern(p PatternRow) error { |
| 210 | private := 0 |
| 211 | if p.Private { |
| 212 | private = 1 |
| 213 | } |
| 214 | |
| 215 | _, err := st.db.Exec(st.dialect.addPatternSQL, p.Pattern, p.VCS, p.Repo, p.Web, private, p.Priority, p.CredentialName) |
| 216 | if err != nil { |
| 217 | return err |
| 218 | } |
| 219 | |
| 220 | return st.refreshPatterns() |
| 221 | } |
| 222 | |
| 223 | // UpdatePattern updates a module pattern and refreshes the cache. |
| 224 | func (st *Store) UpdatePattern(p PatternRow) error { |
| 225 | private := 0 |
| 226 | if p.Private { |
| 227 | private = 1 |
| 228 | } |
| 229 | |
| 230 | result, err := st.db.Exec(st.dialect.updatePatternSQL, p.Pattern, p.VCS, p.Repo, p.Web, private, p.Priority, p.CredentialName, p.ID) |
| 231 | if err != nil { |
| 232 | return err |
| 233 | } |
| 234 | |
| 235 | n, err := result.RowsAffected() |
| 236 | if err != nil { |
| 237 | return err |
| 238 | } |
| 239 | if n == 0 { |
| 240 | return sql.ErrNoRows |
| 241 | } |
| 242 | |
| 243 | return st.refreshPatterns() |
| 244 | } |
| 245 | |
| 246 | // DeletePattern removes a module pattern and refreshes the cache. |
| 247 | func (st *Store) DeletePattern(id int64) error { |
| 248 | result, err := st.db.Exec(st.dialect.deletePatternSQL, id) |
| 249 | if err != nil { |
| 250 | return err |
| 251 | } |
| 252 | |
| 253 | n, err := result.RowsAffected() |
| 254 | if err != nil { |
| 255 | return err |
| 256 | } |
| 257 | if n == 0 { |
| 258 | return sql.ErrNoRows |
| 259 | } |
| 260 | |
| 261 | return st.refreshPatterns() |
| 262 | } |
| 263 | |
| 264 | // ModuleCount returns the total number of configured modules (exact + patterns). |
| 265 | func (st *Store) ModuleCount() int { |
| 266 | var count int |
| 267 | st.db.QueryRow("SELECT COUNT(*) FROM modules").Scan(&count) |
| 268 | |
| 269 | st.mu.RLock() |
| 270 | count += len(st.patterns) |
| 271 | st.mu.RUnlock() |
| 272 | |
| 273 | return count |
| 274 | } |
| 275 | |
| 276 | func (st *Store) refreshPatterns() error { |
| 277 | rows, err := st.db.Query("SELECT id, pattern, vcs, repo, web, private, priority, credential_name FROM module_patterns ORDER BY priority, id") |
| 278 | if err != nil { |
| 279 | return err |
| 280 | } |
| 281 | defer rows.Close() |
| 282 | |
| 283 | var patterns []compiledPattern |
| 284 | for rows.Next() { |
| 285 | var cp compiledPattern |
| 286 | var private int |
| 287 | if err := rows.Scan(&cp.ID, &cp.Pattern, &cp.VCS, &cp.Repo, &cp.Web, &private, &cp.Priority, &cp.CredentialName); err != nil { |
| 288 | return err |
| 289 | } |
| 290 | cp.Private = private != 0 |
| 291 | |
| 292 | re, err := regexp.Compile("^" + cp.Pattern + "$") |
| 293 | if err != nil { |
| 294 | return fmt.Errorf("compile pattern %q: %w", cp.Pattern, err) |
| 295 | } |
| 296 | cp.re = re |
| 297 | |
| 298 | patterns = append(patterns, cp) |
| 299 | } |
| 300 | |
| 301 | st.mu.Lock() |
| 302 | st.patterns = patterns |
| 303 | st.mu.Unlock() |
| 304 | |
| 305 | return rows.Err() |
| 306 | } |
| 307 | |
// expandTemplate fills a VCS/repo/web template for a module resolved by a
// pattern. "{name}" is replaced with the requested module name, and "{1}",
// "{2}", ... with the corresponding capture groups (groups[0] fills "{1}").
// Placeholders without a corresponding group are left untouched.
//
// All placeholders are substituted in a single left-to-right pass. Text that
// comes from name or from a capture group is therefore never expanded again:
// a group value of "{2}" stays literal instead of being replaced by the
// second group.
func expandTemplate(tmpl, name string, groups []string) string {
	// Fast path: nothing to substitute (also covers empty templates).
	if !strings.Contains(tmpl, "{") {
		return tmpl
	}

	pairs := make([]string, 0, 2*(len(groups)+1))
	pairs = append(pairs, "{name}", name)
	for i, g := range groups {
		pairs = append(pairs, fmt.Sprintf("{%d}", i+1), g)
	}
	return strings.NewReplacer(pairs...).Replace(tmpl)
}
| 315 | |
// addColumnIfNotExists runs "ALTER TABLE <table> ADD COLUMN <column> <colDef>"
// as a best-effort, idempotent migration step for tables that predate column.
//
// Every error from the statement is discarded, not only the "column already
// exists" error that a re-run produces. A genuine failure (for example a
// locked or missing table) is therefore silent here and surfaces only later,
// when a query references the column.
//
// table, column and colDef are interpolated directly into the SQL text, so
// they must be trusted, package-internal constants, never user input.
//
// NOTE(review): consider returning the error and ignoring only the backend's
// duplicate-column error, so unexpected migration failures are visible.
func addColumnIfNotExists(db *sql.DB, table, column, colDef string) {
	stmt := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, colDef)
	_, _ = db.Exec(stmt)
}
| 321 | |
| 322 | // ListCredentials returns all credentials ordered by name. |
| 323 | func (st *Store) ListCredentials() ([]CredentialRow, error) { |
| 324 | rows, err := st.db.Query(st.dialect.listCredentialsSQL) |
| 325 | if err != nil { |
| 326 | return nil, err |
| 327 | } |
| 328 | defer rows.Close() |
| 329 | |
| 330 | var creds []CredentialRow |
| 331 | for rows.Next() { |
| 332 | var c CredentialRow |
| 333 | if err := rows.Scan(&c.Name, &c.Type, &c.Data, &c.CreatedAt); err != nil { |
| 334 | return nil, err |
| 335 | } |
| 336 | creds = append(creds, c) |
| 337 | } |
| 338 | return creds, rows.Err() |
| 339 | } |
| 340 | |
| 341 | // GetCredential returns a credential by name. |
| 342 | func (st *Store) GetCredential(name string) (CredentialRow, error) { |
| 343 | var c CredentialRow |
| 344 | err := st.db.QueryRow(st.dialect.getCredentialSQL, name).Scan(&c.Name, &c.Type, &c.Data, &c.CreatedAt) |
| 345 | return c, err |
| 346 | } |
| 347 | |
| 348 | // AddCredential inserts a new credential. |
| 349 | func (st *Store) AddCredential(c CredentialRow) error { |
| 350 | _, err := st.db.Exec(st.dialect.addCredentialSQL, c.Name, c.Type, c.Data) |
| 351 | return err |
| 352 | } |
| 353 | |
| 354 | // UpdateCredential updates an existing credential's type and data. |
| 355 | func (st *Store) UpdateCredential(c CredentialRow) error { |
| 356 | result, err := st.db.Exec(st.dialect.updateCredentialSQL, c.Type, c.Data, c.Name) |
| 357 | if err != nil { |
| 358 | return err |
| 359 | } |
| 360 | n, err := result.RowsAffected() |
| 361 | if err != nil { |
| 362 | return err |
| 363 | } |
| 364 | if n == 0 { |
| 365 | return sql.ErrNoRows |
| 366 | } |
| 367 | return nil |
| 368 | } |
| 369 | |
| 370 | // DeleteCredential removes a credential by name. |
| 371 | // It returns an error if the credential is referenced by any module or pattern. |
| 372 | func (st *Store) DeleteCredential(name string) error { |
| 373 | // Check for references in modules. |
| 374 | var count int |
| 375 | st.db.QueryRow("SELECT COUNT(*) FROM modules WHERE credential_name = ?", name).Scan(&count) |
| 376 | if count > 0 { |
| 377 | return fmt.Errorf("credential %q is referenced by %d module(s)", name, count) |
| 378 | } |
| 379 | |
| 380 | // Check for references in patterns. |
| 381 | st.db.QueryRow("SELECT COUNT(*) FROM module_patterns WHERE credential_name = ?", name).Scan(&count) |
| 382 | if count > 0 { |
| 383 | return fmt.Errorf("credential %q is referenced by %d pattern(s)", name, count) |
| 384 | } |
| 385 | |
| 386 | result, err := st.db.Exec(st.dialect.deleteCredentialSQL, name) |
| 387 | if err != nil { |
| 388 | return err |
| 389 | } |
| 390 | n, err := result.RowsAffected() |
| 391 | if err != nil { |
| 392 | return err |
| 393 | } |
| 394 | if n == 0 { |
| 395 | return sql.ErrNoRows |
| 396 | } |
| 397 | return nil |
| 398 | } |
| 399 | |
| 400 | // CredentialCount returns the total number of credentials. |
| 401 | func (st *Store) CredentialCount() int { |
| 402 | var count int |
| 403 | st.db.QueryRow("SELECT COUNT(*) FROM credentials").Scan(&count) |
| 404 | return count |
| 405 | } |
| 406 | |