sumdb.go

v1.4.2
Doc Versions Source
1
package sumdb
2
3
import (
4
	"archive/zip"
5
	"bytes"
6
	"context"
7
	"fmt"
8
	"io"
9
	"strings"
10
	"sync"
11
12
	"golang.org/x/mod/module"
13
	"golang.org/x/mod/sumdb/dirhash"
14
	"golang.org/x/mod/sumdb/note"
15
	"golang.org/x/mod/sumdb/tlog"
16
17
	"github.com/go-git/go-git/v6/plumbing/transport"
18
19
	"go.bigb.es/curator/internal/build"
20
	"go.bigb.es/curator/internal/config"
21
	"go.bigb.es/curator/internal/git"
22
	"go.bigb.es/curator/internal/metrics"
23
	"go.bigb.es/curator/internal/store"
24
)
25
26
// Ops implements sumdb.ServerOps backed by database persistence.
27
// Records and hashes are read from the database on demand.
28
// A mutex serializes writes to ensure tlog consistency.
29
type Ops struct {
30
	host     string
31
	resolver store.ModuleResolver
32
	creds    store.CredentialResolver
33
	gitc     *git.Cache
34
	db       store.SumdbStore
35
	signer   string
36
37
	// mu serializes writes (new record creation).
38
	mu sync.Mutex
39
40
	hashCache   *lruCache // caches hash blobs by ID
41
	recordCache *lruCache // caches record data by ID
42
}
43
44
// Default sizes passed to newLRUCache by NewOps.
const (
	// defaultHashCacheSize is the hash cache size; with 32-byte hashes
	// that is roughly 320 KB of hash data.
	defaultHashCacheSize = 10000

	// defaultRecordCacheSize is the record cache size.
	defaultRecordCacheSize = 1000
)
48
49
func NewOps(host string, resolver store.ModuleResolver, creds store.CredentialResolver, gitc *git.Cache, db store.SumdbStore, signer string) *Ops {
50
	return &Ops{
51
		host:        host,
52
		resolver:    resolver,
53
		creds:       creds,
54
		gitc:        gitc,
55
		db:          db,
56
		signer:      signer,
57
		hashCache:   newLRUCache(defaultHashCacheSize),
58
		recordCache: newLRUCache(defaultRecordCacheSize),
59
	}
60
}
61
62
// authForModule returns a transport.AuthMethod for the given module's credential,
63
// or nil if none is configured.
64
func (s *Ops) authForModule(mod config.Module) transport.AuthMethod {
65
	if mod.CredentialName == "" || s.creds == nil {
66
		return nil
67
	}
68
	cred, err := s.creds.GetCredential(mod.CredentialName)
69
	if err != nil {
70
		return nil
71
	}
72
	auth, err := git.AuthMethodFromCredential(cred.Type, cred.Data)
73
	if err != nil {
74
		return nil
75
	}
76
	return auth
77
}
78
79
// RecordCount returns the current number of records (for metrics).
80
func (s *Ops) RecordCount() int64 {
81
	n, _ := s.db.RecordCount(context.Background())
82
	return n
83
}
84
85
// Signed returns the signed hash of the latest tree.
86
func (s *Ops) Signed(ctx context.Context) ([]byte, error) {
87
	size, err := s.db.RecordCount(ctx)
88
	if err != nil {
89
		return nil, err
90
	}
91
92
	h, err := tlog.TreeHash(size, &dbHashReader{db: s.db, ctx: ctx, cache: s.hashCache})
93
	if err != nil {
94
		return nil, err
95
	}
96
97
	text := tlog.FormatTree(tlog.Tree{N: size, Hash: h})
98
99
	signer, err := note.NewSigner(s.signer)
100
	if err != nil {
101
		return nil, err
102
	}
103
104
	return note.Sign(&note.Note{Text: string(text)}, signer)
105
}
106
107
// ReadRecords returns the content for records id through id+n-1.
108
func (s *Ops) ReadRecords(ctx context.Context, id, n int64) ([][]byte, error) {
109
	result := make([][]byte, 0, n)
110
	var missingStart, missingCount int64
111
112
	flush := func() error {
113
		if missingCount == 0 {
114
			return nil
115
		}
116
		recs, err := s.db.GetRecords(ctx, missingStart, missingCount)
117
		if err != nil {
118
			return err
119
		}
120
		for i, rec := range recs {
121
			s.recordCache.put(missingStart+int64(i), rec)
122
			result = append(result, rec)
123
		}
124
		missingCount = 0
125
		return nil
126
	}
127
128
	for i := range n {
129
		rid := id + i
130
		if cached, ok := s.recordCache.get(rid); ok {
131
			if err := flush(); err != nil {
132
				return nil, err
133
			}
134
			result = append(result, cached)
135
		} else {
136
			if missingCount == 0 {
137
				missingStart = rid
138
			}
139
			missingCount++
140
		}
141
	}
142
	if err := flush(); err != nil {
143
		return nil, err
144
	}
145
	return result, nil
146
}
147
148
// Lookup looks up a record for the given module, creating it if needed.
149
func (s *Ops) Lookup(ctx context.Context, m module.Version) (int64, error) {
150
	key := m.String()
151
152
	// Fast path: check DB without lock.
153
	id, found, err := s.db.LookupRecord(ctx, key)
154
	if err != nil {
155
		return 0, err
156
	}
157
	if found {
158
		metrics.SumdbLookupsTotal.WithLabelValues("hit").Inc()
159
		return id, nil
160
	}
161
162
	metrics.SumdbLookupsTotal.WithLabelValues("miss").Inc()
163
164
	// Build go.sum record outside the lock.
165
	data, err := s.gosum(m.Path, m.Version)
166
	if err != nil {
167
		return 0, err
168
	}
169
170
	s.mu.Lock()
171
	defer s.mu.Unlock()
172
173
	// Double-check after acquiring lock.
174
	id, found, err = s.db.LookupRecord(ctx, key)
175
	if err != nil {
176
		return 0, err
177
	}
178
	if found {
179
		return id, nil
180
	}
181
182
	// Get current counts for the new record.
183
	id, err = s.db.RecordCount(ctx)
184
	if err != nil {
185
		return 0, fmt.Errorf("record count: %w", err)
186
	}
187
188
	hashReader := &dbHashReader{db: s.db, ctx: ctx, cache: s.hashCache}
189
	hashes, err := tlog.StoredHashesForRecordHash(id, tlog.RecordHash(data), hashReader)
190
	if err != nil {
191
		return 0, fmt.Errorf("stored hashes: %w", err)
192
	}
193
194
	hashStartID := tlog.StoredHashCount(id)
195
	rawHashes := make([][]byte, len(hashes))
196
	for i, h := range hashes {
197
		rawHashes[i] = h[:]
198
	}
199
200
	if err := s.db.SaveRecord(ctx, id, key, data, hashStartID, rawHashes); err != nil {
201
		return 0, fmt.Errorf("persist sumdb record: %w", err)
202
	}
203
204
	// Warm caches with the freshly written data.
205
	s.recordCache.put(id, data)
206
	for i, h := range rawHashes {
207
		s.hashCache.put(hashStartID+int64(i), h)
208
	}
209
210
	return id, nil
211
}
212
213
// ReadTileData reads the content of tile t.
214
func (s *Ops) ReadTileData(ctx context.Context, t tlog.Tile) ([]byte, error) {
215
	return tlog.ReadTileData(t, &dbHashReader{db: s.db, ctx: ctx, cache: s.hashCache})
216
}
217
218
// dbHashReader adapts SumdbStore to tlog.HashReader with an LRU cache.
219
type dbHashReader struct {
220
	db    store.SumdbStore
221
	ctx   context.Context
222
	cache *lruCache
223
}
224
225
// ReadHashes returns the hashes with the given stored-hash indexes, in the
// same order as indexes.
//
// Hashes found in the cache are used directly. The remaining IDs are
// de-duplicated, fetched from the store in one GetHashes call, added to the
// cache and copied into every position that asked for them. It returns an
// error if the store fails, returns a different number of hashes than
// requested, or returns a hash that is not tlog.HashSize bytes long.
func (r *dbHashReader) ReadHashes(indexes []int64) ([]tlog.Hash, error) {
	result := make([]tlog.Hash, len(indexes))

	var missing []int64                               // distinct uncached IDs, in first-seen order
	missingIdx := make(map[int64][]int, len(indexes)) // hash ID → positions in result

	for i, id := range indexes {
		if cached, ok := r.cache.get(id); ok {
			copy(result[i][:], cached)
			continue
		}
		positions, seen := missingIdx[id]
		if !seen {
			missing = append(missing, id)
		}
		missingIdx[id] = append(positions, i)
	}

	if len(missing) == 0 {
		return result, nil
	}

	raw, err := r.db.GetHashes(r.ctx, missing)
	if err != nil {
		return nil, err
	}
	// Indexing raw by position below would panic on a short reply.
	if len(raw) != len(missing) {
		return nil, fmt.Errorf("read hashes: store returned %d hashes, want %d", len(raw), len(missing))
	}

	for i, id := range missing {
		// copy would silently zero-pad or truncate a wrong-sized hash.
		if len(raw[i]) != tlog.HashSize {
			return nil, fmt.Errorf("read hashes: hash %d has %d bytes, want %d", id, len(raw[i]), tlog.HashSize)
		}
		r.cache.put(id, raw[i])
		for _, pos := range missingIdx[id] {
			copy(result[pos][:], raw[i])
		}
	}

	return result, nil
}
259
260
func (s *Ops) gosum(path, version string) ([]byte, error) {
261
	modName := strings.TrimPrefix(path, s.host+"/")
262
	if modName == path {
263
		return nil, fmt.Errorf("module %s not under host %s", path, s.host)
264
	}
265
266
	mod, ok := s.resolver.ResolveModule(modName)
267
	if !ok {
268
		return nil, fmt.Errorf("unknown module: %s", modName)
269
	}
270
271
	auth := s.authForModule(mod)
272
	repoPath, err := s.gitc.CloneOrFetch(modName, mod.Repo, auth)
273
	if err != nil {
274
		return nil, fmt.Errorf("git: %w", err)
275
	}
276
277
	modulePath := s.host + "/" + modName
278
279
	rv, err := git.ResolveVersion(repoPath, version)
280
	if err != nil {
281
		return nil, fmt.Errorf("resolve version: %w", err)
282
	}
283
284
	modData, _, err := build.Mod(repoPath, rv, modulePath)
285
	if err != nil {
286
		return nil, fmt.Errorf("build mod: %w", err)
287
	}
288
289
	zipData, _, err := build.Zip(repoPath, rv, modulePath)
290
	if err != nil {
291
		return nil, fmt.Errorf("build zip: %w", err)
292
	}
293
294
	zipHash, err := hashZipBytes(zipData)
295
	if err != nil {
296
		return nil, fmt.Errorf("hash zip: %w", err)
297
	}
298
299
	modHash, err := hashModBytes(modulePath, version, modData)
300
	if err != nil {
301
		return nil, fmt.Errorf("hash mod: %w", err)
302
	}
303
304
	record := fmt.Sprintf("%s %s %s\n%s %s/go.mod %s\n",
305
		path, version, zipHash,
306
		path, version, modHash)
307
308
	return []byte(record), nil
309
}
310
311
func hashZipBytes(data []byte) (string, error) {
312
	r, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
313
	if err != nil {
314
		return "", err
315
	}
316
317
	var files []string
318
	zfiles := make(map[string]*zip.File)
319
	for _, f := range r.File {
320
		files = append(files, f.Name)
321
		zfiles[f.Name] = f
322
	}
323
324
	return dirhash.Hash1(files, func(name string) (io.ReadCloser, error) {
325
		f := zfiles[name]
326
		if f == nil {
327
			return nil, fmt.Errorf("file %q not found in zip", name)
328
		}
329
		return f.Open()
330
	})
331
}
332
333
func hashModBytes(modulePath, version string, data []byte) (string, error) {
334
	gomodFile := fmt.Sprintf("%s@%s/go.mod", modulePath, version)
335
	return dirhash.Hash1([]string{gomodFile}, func(string) (io.ReadCloser, error) {
336
		return io.NopCloser(bytes.NewReader(data)), nil
337
	})
338
}
339

Source Files