sumdb.go

v1.0.1
Doc Versions Source
1
package sumdb
2
3
import (
4
	"archive/zip"
5
	"bytes"
6
	"context"
7
	"fmt"
8
	"io"
9
	"strings"
10
	"sync"
11
12
	"golang.org/x/mod/module"
13
	"golang.org/x/mod/sumdb/dirhash"
14
	"golang.org/x/mod/sumdb/note"
15
	"golang.org/x/mod/sumdb/tlog"
16
17
	"example.com/curator/internal/build"
18
	"example.com/curator/internal/git"
19
	"example.com/curator/internal/metrics"
20
	"example.com/curator/internal/store"
21
)
22
23
// Ops implements sumdb.ServerOps backed by database persistence.
24
// Records and hashes are read from the database on demand.
25
// A mutex serializes writes to ensure tlog consistency.
26
type Ops struct {
27
	host     string
28
	resolver store.ModuleResolver
29
	gitc     *git.Cache
30
	db       store.SumdbStore
31
	signer   string
32
33
	// mu serializes writes (new record creation).
34
	mu sync.Mutex
35
36
	hashCache   *lruCache // caches hash blobs by ID
37
	recordCache *lruCache // caches record data by ID
38
}
39
40
// Default capacities, in entries, of the LRU caches created by NewOps.
const (
	// defaultHashCacheSize bounds the stored-hash cache; at 32 bytes per
	// hash that is roughly 320 KB of hash data.
	defaultHashCacheSize = 10_000

	// defaultRecordCacheSize bounds the go.sum record cache.
	defaultRecordCacheSize = 1_000
)
44
45
func NewOps(host string, resolver store.ModuleResolver, gitc *git.Cache, db store.SumdbStore, signer string) *Ops {
46
	return &Ops{
47
		host:        host,
48
		resolver:    resolver,
49
		gitc:        gitc,
50
		db:          db,
51
		signer:      signer,
52
		hashCache:   newLRUCache(defaultHashCacheSize),
53
		recordCache: newLRUCache(defaultRecordCacheSize),
54
	}
55
}
56
57
// RecordCount returns the current number of records (for metrics).
58
func (s *Ops) RecordCount() int64 {
59
	n, _ := s.db.RecordCount(context.Background())
60
	return n
61
}
62
63
// Signed returns the signed hash of the latest tree.
64
func (s *Ops) Signed(ctx context.Context) ([]byte, error) {
65
	size, err := s.db.RecordCount(ctx)
66
	if err != nil {
67
		return nil, err
68
	}
69
70
	h, err := tlog.TreeHash(size, &dbHashReader{db: s.db, ctx: ctx, cache: s.hashCache})
71
	if err != nil {
72
		return nil, err
73
	}
74
75
	text := tlog.FormatTree(tlog.Tree{N: size, Hash: h})
76
77
	signer, err := note.NewSigner(s.signer)
78
	if err != nil {
79
		return nil, err
80
	}
81
82
	return note.Sign(&note.Note{Text: string(text)}, signer)
83
}
84
85
// ReadRecords returns the content for records id through id+n-1.
86
func (s *Ops) ReadRecords(ctx context.Context, id, n int64) ([][]byte, error) {
87
	result := make([][]byte, 0, n)
88
	var missingStart, missingCount int64
89
90
	flush := func() error {
91
		if missingCount == 0 {
92
			return nil
93
		}
94
		recs, err := s.db.GetRecords(ctx, missingStart, missingCount)
95
		if err != nil {
96
			return err
97
		}
98
		for i, rec := range recs {
99
			s.recordCache.put(missingStart+int64(i), rec)
100
			result = append(result, rec)
101
		}
102
		missingCount = 0
103
		return nil
104
	}
105
106
	for i := range n {
107
		rid := id + i
108
		if cached, ok := s.recordCache.get(rid); ok {
109
			if err := flush(); err != nil {
110
				return nil, err
111
			}
112
			result = append(result, cached)
113
		} else {
114
			if missingCount == 0 {
115
				missingStart = rid
116
			}
117
			missingCount++
118
		}
119
	}
120
	if err := flush(); err != nil {
121
		return nil, err
122
	}
123
	return result, nil
124
}
125
126
// Lookup looks up a record for the given module, creating it if needed.
127
func (s *Ops) Lookup(ctx context.Context, m module.Version) (int64, error) {
128
	key := m.String()
129
130
	// Fast path: check DB without lock.
131
	id, found, err := s.db.LookupRecord(ctx, key)
132
	if err != nil {
133
		return 0, err
134
	}
135
	if found {
136
		metrics.SumdbLookupsTotal.WithLabelValues("hit").Inc()
137
		return id, nil
138
	}
139
140
	metrics.SumdbLookupsTotal.WithLabelValues("miss").Inc()
141
142
	// Build go.sum record outside the lock.
143
	data, err := s.gosum(m.Path, m.Version)
144
	if err != nil {
145
		return 0, err
146
	}
147
148
	s.mu.Lock()
149
	defer s.mu.Unlock()
150
151
	// Double-check after acquiring lock.
152
	id, found, err = s.db.LookupRecord(ctx, key)
153
	if err != nil {
154
		return 0, err
155
	}
156
	if found {
157
		return id, nil
158
	}
159
160
	// Get current counts for the new record.
161
	id, err = s.db.RecordCount(ctx)
162
	if err != nil {
163
		return 0, fmt.Errorf("record count: %w", err)
164
	}
165
166
	hashReader := &dbHashReader{db: s.db, ctx: ctx, cache: s.hashCache}
167
	hashes, err := tlog.StoredHashesForRecordHash(id, tlog.RecordHash(data), hashReader)
168
	if err != nil {
169
		return 0, fmt.Errorf("stored hashes: %w", err)
170
	}
171
172
	hashStartID := tlog.StoredHashCount(id)
173
	rawHashes := make([][]byte, len(hashes))
174
	for i, h := range hashes {
175
		rawHashes[i] = h[:]
176
	}
177
178
	if err := s.db.SaveRecord(ctx, id, key, data, hashStartID, rawHashes); err != nil {
179
		return 0, fmt.Errorf("persist sumdb record: %w", err)
180
	}
181
182
	// Warm caches with the freshly written data.
183
	s.recordCache.put(id, data)
184
	for i, h := range rawHashes {
185
		s.hashCache.put(hashStartID+int64(i), h)
186
	}
187
188
	return id, nil
189
}
190
191
// ReadTileData reads the content of tile t.
192
func (s *Ops) ReadTileData(ctx context.Context, t tlog.Tile) ([]byte, error) {
193
	return tlog.ReadTileData(t, &dbHashReader{db: s.db, ctx: ctx, cache: s.hashCache})
194
}
195
196
// dbHashReader adapts SumdbStore to tlog.HashReader with an LRU cache.
197
type dbHashReader struct {
198
	db    store.SumdbStore
199
	ctx   context.Context
200
	cache *lruCache
201
}
202
203
// ReadHashes implements tlog.HashReader. It returns the stored hashes at the
// given indexes, in the same order.
//
// Cached hashes are served from the LRU cache. All remaining indexes are
// fetched with one GetHashes call, requesting each distinct index once even if
// it appears several times in indexes. Fetched hashes are then cached.
//
// tlog.HashReader requires the result to have exactly len(indexes) entries or
// a non-nil error. A database response with the wrong number of hashes, or
// with a hash that is not tlog.HashSize bytes, is therefore reported as an
// error. Neither the result nor the cache is touched in that case. Without
// this check, a short response panics and a mis-sized hash is silently
// truncated or zero-padded.
func (r *dbHashReader) ReadHashes(indexes []int64) ([]tlog.Hash, error) {
	result := make([]tlog.Hash, len(indexes))

	// missing lists each uncached index once, in first-seen order; positions
	// maps it to every slot in result that needs it. positions is allocated
	// lazily so fully cached reads do not pay for it.
	var missing []int64
	var positions map[int64][]int

	for i, id := range indexes {
		if cached, ok := r.cache.get(id); ok {
			copy(result[i][:], cached)
			continue
		}
		if positions == nil {
			positions = make(map[int64][]int, len(indexes)-i)
		}
		if _, seen := positions[id]; !seen {
			missing = append(missing, id)
		}
		positions[id] = append(positions[id], i)
	}

	if len(missing) == 0 {
		return result, nil
	}

	raw, err := r.db.GetHashes(r.ctx, missing)
	if err != nil {
		return nil, err
	}
	if len(raw) != len(missing) {
		return nil, fmt.Errorf("read hashes: store returned %d hashes, want %d", len(raw), len(missing))
	}
	// Validate everything before caching so a bad response caches nothing.
	for i, id := range missing {
		if len(raw[i]) != tlog.HashSize {
			return nil, fmt.Errorf("read hashes: stored hash %d has %d bytes, want %d", id, len(raw[i]), tlog.HashSize)
		}
	}

	for i, id := range missing {
		r.cache.put(id, raw[i])
		for _, pos := range positions[id] {
			copy(result[pos][:], raw[i])
		}
	}
	return result, nil
}
237
238
func (s *Ops) gosum(path, version string) ([]byte, error) {
239
	modName := strings.TrimPrefix(path, s.host+"/")
240
	if modName == path {
241
		return nil, fmt.Errorf("module %s not under host %s", path, s.host)
242
	}
243
244
	mod, ok := s.resolver.ResolveModule(modName)
245
	if !ok {
246
		return nil, fmt.Errorf("unknown module: %s", modName)
247
	}
248
249
	repoPath, err := s.gitc.CloneOrFetch(modName, mod.Repo)
250
	if err != nil {
251
		return nil, fmt.Errorf("git: %w", err)
252
	}
253
254
	modulePath := s.host + "/" + modName
255
256
	rv, err := git.ResolveVersion(repoPath, version)
257
	if err != nil {
258
		return nil, fmt.Errorf("resolve version: %w", err)
259
	}
260
261
	modData, _, err := build.Mod(repoPath, rv, modulePath)
262
	if err != nil {
263
		return nil, fmt.Errorf("build mod: %w", err)
264
	}
265
266
	zipData, _, err := build.Zip(repoPath, rv, modulePath)
267
	if err != nil {
268
		return nil, fmt.Errorf("build zip: %w", err)
269
	}
270
271
	zipHash, err := hashZipBytes(zipData)
272
	if err != nil {
273
		return nil, fmt.Errorf("hash zip: %w", err)
274
	}
275
276
	modHash, err := hashModBytes(modulePath, version, modData)
277
	if err != nil {
278
		return nil, fmt.Errorf("hash mod: %w", err)
279
	}
280
281
	record := fmt.Sprintf("%s %s %s\n%s %s/go.mod %s\n",
282
		path, version, zipHash,
283
		path, version, modHash)
284
285
	return []byte(record), nil
286
}
287
288
func hashZipBytes(data []byte) (string, error) {
289
	r, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
290
	if err != nil {
291
		return "", err
292
	}
293
294
	var files []string
295
	zfiles := make(map[string]*zip.File)
296
	for _, f := range r.File {
297
		files = append(files, f.Name)
298
		zfiles[f.Name] = f
299
	}
300
301
	return dirhash.Hash1(files, func(name string) (io.ReadCloser, error) {
302
		f := zfiles[name]
303
		if f == nil {
304
			return nil, fmt.Errorf("file %q not found in zip", name)
305
		}
306
		return f.Open()
307
	})
308
}
309
310
func hashModBytes(modulePath, version string, data []byte) (string, error) {
311
	gomodFile := fmt.Sprintf("%s@%s/go.mod", modulePath, version)
312
	return dirhash.Hash1([]string{gomodFile}, func(string) (io.ReadCloser, error) {
313
		return io.NopCloser(bytes.NewReader(data)), nil
314
	})
315
}
316

Source Files