notifications.go

v1.0.0
Doc Versions Source
1
package api
2
3
import (
4
	"bytes"
5
	"context"
6
	"encoding/json"
7
	"fmt"
8
	"io"
9
	"log"
10
	"strings"
11
	"sync"
12
	"time"
13
14
	"github.com/coder/websocket"
15
	"github.com/vmihailenco/msgpack/v5"
16
)
17
18
// NotificationType represents the type of server notification.
//
// The numeric values are what the server sends in the "Type" field of a
// notification, so they are part of the wire format: each constant below is
// its iota position (0 through 16) and the list must not be reordered.
type NotificationType int

const (
	NotifSyncCipherUpdate    NotificationType = iota // 0
	NotifSyncCipherCreate                            // 1
	NotifSyncLoginDelete                             // 2
	NotifSyncFolderDelete                            // 3
	NotifSyncCiphers                                 // 4
	NotifSyncVault                                   // 5
	NotifSyncOrgKeys                                 // 6
	NotifSyncFolderCreate                            // 7
	NotifSyncFolderUpdate                            // 8
	NotifSyncCipherDelete                            // 9
	NotifSyncSettings                                // 10
	NotifLogOut                                      // 11
	NotifSyncSendCreate                              // 12
	NotifSyncSendUpdate                              // 13
	NotifSyncSendDelete                              // 14
	NotifAuthRequest                                 // 15
	NotifAuthRequestResponse                         // 16
)
40
41
// NotificationMessage is the decoded notification from the server.
42
type NotificationMessage struct {
43
	Type      NotificationType
44
	ContextID string
45
	Payload   NotificationPayload
46
}
47
48
// NotificationPayload holds the payload fields of a notification.
//
// Fields the server did not send are left at their zero values. Date fields
// are kept as strings rather than parsed into time.Time.
type NotificationPayload struct {
	ID             string
	UserID         string
	OrganizationID string
	CollectionIDs  []string
	RevisionDate   string
	Date           string
}
57
58
// NotificationHandler is called when a notification is received.
59
type NotificationHandler func(msg NotificationMessage)
60
61
// NotificationsClient manages a WebSocket connection to the notifications hub.
62
type NotificationsClient struct {
63
	transport *Transport
64
	handler   NotificationHandler
65
	logger    *log.Logger
66
67
	mu     sync.Mutex
68
	cancel context.CancelFunc
69
	done   chan struct{}
70
}
71
72
// NewNotificationsClient creates a new notifications client.
73
func NewNotificationsClient(transport *Transport, handler NotificationHandler) *NotificationsClient {
74
	return &NotificationsClient{
75
		transport: transport,
76
		handler:   handler,
77
		logger:    log.New(io.Discard, "", 0),
78
	}
79
}
80
81
// SetLogger sets a logger for debug output. By default, logging is disabled.
82
func (nc *NotificationsClient) SetLogger(l *log.Logger) {
83
	nc.logger = l
84
}
85
86
// Connect starts the background WebSocket connection with automatic reconnection.
87
// The provided context controls the lifetime of the connection — cancel it to stop.
88
func (nc *NotificationsClient) Connect(ctx context.Context) error {
89
	nc.mu.Lock()
90
	defer nc.mu.Unlock()
91
92
	if nc.done != nil {
93
		return fmt.Errorf("notifications already connected")
94
	}
95
96
	ctx, cancel := context.WithCancel(ctx)
97
	nc.cancel = cancel
98
	nc.done = make(chan struct{})
99
100
	go nc.run(ctx)
101
	return nil
102
}
103
104
// Close stops the WebSocket connection and waits for the background goroutine to exit.
105
func (nc *NotificationsClient) Close() error {
106
	nc.mu.Lock()
107
	cancel := nc.cancel
108
	done := nc.done
109
	nc.mu.Unlock()
110
111
	if cancel == nil {
112
		return nil
113
	}
114
115
	cancel()
116
	if done != nil {
117
		<-done
118
	}
119
120
	nc.mu.Lock()
121
	nc.cancel = nil
122
	nc.done = nil
123
	nc.mu.Unlock()
124
125
	return nil
126
}
127
128
func (nc *NotificationsClient) run(ctx context.Context) {
129
	defer close(nc.done)
130
131
	backoff := 5 * time.Second
132
	const maxBackoff = 60 * time.Second
133
134
	for {
135
		err := nc.connectAndListen(ctx)
136
		if ctx.Err() != nil {
137
			return
138
		}
139
		nc.logger.Printf("notifications disconnected: %v; reconnecting in %v", err, backoff)
140
141
		select {
142
		case <-ctx.Done():
143
			return
144
		case <-time.After(backoff):
145
		}
146
147
		backoff *= 2
148
		if backoff > maxBackoff {
149
			backoff = maxBackoff
150
		}
151
	}
152
}
153
154
func (nc *NotificationsClient) connectAndListen(ctx context.Context) error {
155
	token, err := nc.transport.AccessToken(ctx)
156
	if err != nil {
157
		return fmt.Errorf("get access token: %w", err)
158
	}
159
160
	wsURL := httpToWS(nc.transport.BaseURL()) + "/notifications/hub?access_token=" + token
161
162
	conn, _, err := websocket.Dial(ctx, wsURL, nil)
163
	if err != nil {
164
		return fmt.Errorf("dial: %w", err)
165
	}
166
	defer conn.Close(websocket.StatusNormalClosure, "")
167
168
	// SignalR handshake — always JSON, terminated by 0x1e
169
	if err := nc.doHandshake(ctx, conn); err != nil {
170
		return fmt.Errorf("handshake: %w", err)
171
	}
172
173
	nc.logger.Printf("notifications connected")
174
175
	// Reset backoff on successful connection (done by caller seeing nil vs error)
176
	errCh := make(chan error, 1)
177
178
	// Reader goroutine
179
	go func() {
180
		errCh <- nc.readMessages(ctx, conn)
181
	}()
182
183
	// Ping ticker
184
	ticker := time.NewTicker(15 * time.Second)
185
	defer ticker.Stop()
186
187
	for {
188
		select {
189
		case <-ctx.Done():
190
			return ctx.Err()
191
		case err := <-errCh:
192
			return err
193
		case <-ticker.C:
194
			if err := nc.sendPing(ctx, conn); err != nil {
195
				return fmt.Errorf("send ping: %w", err)
196
			}
197
		}
198
	}
199
}
200
201
const recordSeparator = 0x1e
202
203
func (nc *NotificationsClient) doHandshake(ctx context.Context, conn *websocket.Conn) error {
204
	// Send handshake: {"protocol":"messagepack","version":1}\x1e
205
	handshake := []byte(`{"protocol":"messagepack","version":1}`)
206
	handshake = append(handshake, recordSeparator)
207
208
	if err := conn.Write(ctx, websocket.MessageText, handshake); err != nil {
209
		return fmt.Errorf("write handshake: %w", err)
210
	}
211
212
	// Read handshake response
213
	_, data, err := conn.Read(ctx)
214
	if err != nil {
215
		return fmt.Errorf("read handshake response: %w", err)
216
	}
217
218
	// Strip record separator
219
	data = bytes.TrimRight(data, string([]byte{recordSeparator}))
220
221
	var resp map[string]any
222
	if err := json.Unmarshal(data, &resp); err != nil {
223
		return fmt.Errorf("parse handshake response: %w", err)
224
	}
225
226
	// Check for error
227
	if errMsg, ok := resp["error"]; ok {
228
		return fmt.Errorf("handshake error: %v", errMsg)
229
	}
230
231
	return nil
232
}
233
234
// readMessages reads frames from conn until a read fails or a frame is
// malformed, and returns the error that stopped it.
//
// Each frame may batch several MessagePack messages, each preceded by its
// VarInt-encoded byte length. Every message is passed to handleMessage in
// order, on this goroutine.
func (nc *NotificationsClient) readMessages(ctx context.Context, conn *websocket.Conn) error {
	for {
		_, frame, err := conn.Read(ctx)
		if err != nil {
			return fmt.Errorf("read: %w", err)
		}

		r := bytes.NewReader(frame)
		for r.Len() != 0 {
			size, err := readVarInt(r)
			if err != nil {
				return fmt.Errorf("read varint: %w", err)
			}

			remaining := r.Len()
			if size < 1 || size > remaining {
				return fmt.Errorf("invalid message length: %d (remaining: %d)", size, remaining)
			}

			body := make([]byte, size)
			if _, err := io.ReadFull(r, body); err != nil {
				return fmt.Errorf("read message: %w", err)
			}
			nc.handleMessage(body)
		}
	}
}
262
263
func (nc *NotificationsClient) handleMessage(data []byte) {
264
	// Decode as generic array
265
	var msg []any
266
	if err := msgpack.Unmarshal(data, &msg); err != nil {
267
		nc.logger.Printf("failed to decode message: %v", err)
268
		return
269
	}
270
271
	if len(msg) == 0 {
272
		return
273
	}
274
275
	msgType, ok := asInt(msg[0])
276
	if !ok {
277
		return
278
	}
279
280
	switch msgType {
281
	case 1: // Invocation
282
		nc.handleInvocation(msg)
283
	case 6: // Ping — no action needed, server pinging us
284
		nc.logger.Printf("received server ping")
285
	case 7: // Close
286
		nc.logger.Printf("received close message")
287
	}
288
}
289
290
func (nc *NotificationsClient) handleInvocation(msg []any) {
291
	// Format: [1, headers, invocationId, target, arguments]
292
	if len(msg) < 5 {
293
		return
294
	}
295
296
	target, _ := msg[3].(string)
297
	if target != "ReceiveMessage" {
298
		nc.logger.Printf("unknown invocation target: %s", target)
299
		return
300
	}
301
302
	args, ok := msg[4].([]any)
303
	if !ok || len(args) == 0 {
304
		return
305
	}
306
307
	// The first argument is the notification object
308
	argMap, ok := args[0].(map[string]any)
309
	if !ok {
310
		return
311
	}
312
313
	notifMsg := NotificationMessage{}
314
	if ct, ok := argMap["ContextId"]; ok {
315
		notifMsg.ContextID, _ = ct.(string)
316
	}
317
	if t, ok := asInt(argMap["Type"]); ok {
318
		notifMsg.Type = NotificationType(t)
319
	}
320
321
	// Payload can be a string (JSON) or a map
322
	if payloadStr, ok := argMap["Payload"].(string); ok && payloadStr != "" {
323
		var p struct {
324
			ID             string   `json:"id"`
325
			UserID         string   `json:"userId"`
326
			OrganizationID string   `json:"organizationId"`
327
			CollectionIDs  []string `json:"collectionIds"`
328
			RevisionDate   string   `json:"revisionDate"`
329
			Date           string   `json:"date"`
330
		}
331
		if err := json.Unmarshal([]byte(payloadStr), &p); err == nil {
332
			notifMsg.Payload = NotificationPayload{
333
				ID:             p.ID,
334
				UserID:         p.UserID,
335
				OrganizationID: p.OrganizationID,
336
				CollectionIDs:  p.CollectionIDs,
337
				RevisionDate:   p.RevisionDate,
338
				Date:           p.Date,
339
			}
340
		}
341
	}
342
343
	nc.handler(notifMsg)
344
}
345
346
// sendPing writes a SignalR Ping message to conn as a single binary frame:
// the MessagePack array [6] preceded by its one-byte VarInt length prefix.
func (nc *NotificationsClient) sendPing(ctx context.Context, conn *websocket.Conn) error {
	const (
		lengthPrefix = 0x02 // the MessagePack body below is two bytes long
		fixArrayOne  = 0x91 // MessagePack fixarray header with one element
		pingType     = 0x06 // SignalR message type 6 (Ping)
	)
	frame := []byte{lengthPrefix, fixArrayOne, pingType}
	return conn.Write(ctx, websocket.MessageBinary, frame)
}
351
352
// httpToWS converts an HTTP(S) URL to a WS(S) URL by rewriting a leading
// "https://" to "wss://" or "http://" to "ws://". Any other input, including
// URLs that already use a ws scheme, is returned unchanged.
func httpToWS(u string) string {
	rewrites := [...]struct{ from, to string }{
		{"https://", "wss://"},
		{"http://", "ws://"},
	}
	for _, rw := range rewrites {
		if strings.HasPrefix(u, rw.from) {
			return rw.to + strings.TrimPrefix(u, rw.from)
		}
	}
	return u
}
362
363
// readVarInt reads a variable-length integer (SignalR MessagePack framing).
// Each byte contributes its low 7 bits, least-significant group first, and a
// set high bit means another byte follows. At most five bytes are read; a
// fifth byte that still has its continuation bit set yields "varint too long".
// Read errors from r are returned unchanged.
func readVarInt(r io.Reader) (int, error) {
	var one [1]byte
	value := 0
	for shift := uint(0); shift < 35; shift += 7 {
		if _, err := io.ReadFull(r, one[:]); err != nil {
			return 0, err
		}
		value |= int(one[0]&0x7F) << shift
		if one[0]&0x80 == 0 {
			return value, nil
		}
	}
	return 0, fmt.Errorf("varint too long")
}
384
385
// writeVarInt writes a variable-length integer to a buffer using SignalR's
// VarInt framing: 7 bits per byte, least-significant group first, with the
// high bit set on every byte except the last. It is the inverse of readVarInt
// for non-negative values.
//
// n is expected to be non-negative (a message length). A negative n is
// encoded from its 64-bit two's-complement bit pattern (10 bytes), so encoding
// always terminates. An arithmetic right shift of a negative int never reaches
// zero, which would otherwise loop forever.
func writeVarInt(buf *bytes.Buffer, n int) {
	u := uint64(n)
	for u >= 0x80 {
		buf.WriteByte(byte(u&0x7F) | 0x80)
		u >>= 7
	}
	buf.WriteByte(byte(u))
}
399
400
// asInt extracts an integer from a msgpack-decoded value, which may be a plain
// int or any sized signed or unsigned integer kind (int8…int64,
// uint8…uint64). The boolean reports whether v held one of those kinds.
// Conversion follows Go's rules, so values outside int's range wrap.
func asInt(v any) (int, bool) {
	var out int
	switch x := v.(type) {
	case int8:
		out = int(x)
	case int16:
		out = int(x)
	case int32:
		out = int(x)
	case int64:
		out = int(x)
	case int:
		out = x
	case uint8:
		out = int(x)
	case uint16:
		out = int(x)
	case uint32:
		out = int(x)
	case uint64:
		out = int(x)
	default:
		return 0, false
	}
	return out, true
}
426

Source Files