sumdb.go

v1.0.1
Doc Versions Source
1
package store
2
3
import (
4
	"context"
5
	"database/sql"
6
	"fmt"
7
	"strings"
8
)
9
10
// SumdbStore abstracts the storage behind a sumdb transparency log: the
// log's records, addressable both by numeric ID and by key, and the hashes
// stored alongside them.
type SumdbStore interface {
	// SaveRecord writes a new record plus its accompanying hashes as one
	// atomic unit. hashStartID is the ID of the first hash; the remaining
	// hashes follow consecutively.
	SaveRecord(ctx context.Context, recordID int64, key string, data []byte, hashStartID int64, hashes [][]byte) error

	// RecordCount reports how many sumdb records are currently stored.
	RecordCount(ctx context.Context) (int64, error)

	// LookupRecord maps a key to its record ID. found is false (with a nil
	// error) when no record has that key.
	LookupRecord(ctx context.Context, key string) (id int64, found bool, err error)

	// GetRecords fetches the data of the records whose IDs lie in the
	// half-open range [startID, startID+count).
	GetRecords(ctx context.Context, startID, count int64) ([][]byte, error)

	// GetHashes fetches the hashes with the given IDs; result[i] is the
	// hash for ids[i].
	GetHashes(ctx context.Context, ids []int64) ([][]byte, error)
}
27
28
// SaveRecord inserts a record and its hashes in a single transaction.
29
func (st *Store) SaveRecord(ctx context.Context, recordID int64, key string, data []byte, hashStartID int64, hashes [][]byte) error {
30
	tx, err := st.db.BeginTx(ctx, nil)
31
	if err != nil {
32
		return err
33
	}
34
	defer tx.Rollback()
35
36
	if _, err := tx.ExecContext(ctx, st.dialect.insertSumdbRecordSQL, recordID, key, data); err != nil {
37
		return err
38
	}
39
40
	for i, h := range hashes {
41
		if _, err := tx.ExecContext(ctx, st.dialect.insertSumdbHashSQL, hashStartID+int64(i), h); err != nil {
42
			return err
43
		}
44
	}
45
46
	return tx.Commit()
47
}
48
49
// GetRecords returns record data for IDs [startID, startID+count).
50
func (st *Store) GetRecords(ctx context.Context, startID, count int64) ([][]byte, error) {
51
	rows, err := st.db.QueryContext(ctx, st.dialect.getRecordsSQL, startID, startID+count)
52
	if err != nil {
53
		return nil, err
54
	}
55
	defer rows.Close()
56
57
	var records [][]byte
58
	for rows.Next() {
59
		var data []byte
60
		if err := rows.Scan(&data); err != nil {
61
			return nil, err
62
		}
63
		records = append(records, data)
64
	}
65
	if err := rows.Err(); err != nil {
66
		return nil, err
67
	}
68
69
	if int64(len(records)) != count {
70
		return nil, fmt.Errorf("expected %d records, got %d", count, len(records))
71
	}
72
	return records, nil
73
}
74
75
// GetHashes returns hashes for the given IDs, in the order requested.
76
func (st *Store) GetHashes(ctx context.Context, ids []int64) ([][]byte, error) {
77
	if len(ids) == 0 {
78
		return nil, nil
79
	}
80
81
	// Build query: SELECT id, hash FROM sumdb_hashes WHERE id IN (?, ?, ...)
82
	placeholders := make([]string, len(ids))
83
	args := make([]any, len(ids))
84
	for i, id := range ids {
85
		if st.dialect.driverName == "postgres" {
86
			placeholders[i] = fmt.Sprintf("$%d", i+1)
87
		} else {
88
			placeholders[i] = "?"
89
		}
90
		args[i] = id
91
	}
92
93
	query := fmt.Sprintf("SELECT id, hash FROM sumdb_hashes WHERE id IN (%s)",
94
		strings.Join(placeholders, ","))
95
96
	rows, err := st.db.QueryContext(ctx, query, args...)
97
	if err != nil {
98
		return nil, err
99
	}
100
	defer rows.Close()
101
102
	// Index results by ID for reordering.
103
	byID := make(map[int64][]byte, len(ids))
104
	for rows.Next() {
105
		var id int64
106
		var h []byte
107
		if err := rows.Scan(&id, &h); err != nil {
108
			return nil, err
109
		}
110
		byID[id] = h
111
	}
112
	if err := rows.Err(); err != nil {
113
		return nil, err
114
	}
115
116
	// Return in requested order.
117
	result := make([][]byte, len(ids))
118
	for i, id := range ids {
119
		h, ok := byID[id]
120
		if !ok {
121
			return nil, fmt.Errorf("hash %d not found", id)
122
		}
123
		result[i] = h
124
	}
125
	return result, nil
126
}
127
128
// LookupRecord returns the record ID for a key, or found=false if not present.
129
func (st *Store) LookupRecord(ctx context.Context, key string) (int64, bool, error) {
130
	var id int64
131
	err := st.db.QueryRowContext(ctx, st.dialect.lookupRecordSQL, key).Scan(&id)
132
	if err == sql.ErrNoRows {
133
		return 0, false, nil
134
	}
135
	if err != nil {
136
		return 0, false, err
137
	}
138
	return id, true, nil
139
}
140
141
// RecordCount returns the total number of sumdb records.
142
func (st *Store) RecordCount(ctx context.Context) (int64, error) {
143
	var count int64
144
	err := st.db.QueryRowContext(ctx, st.dialect.recordCountSQL).Scan(&count)
145
	return count, err
146
}
147

Source Files