| 1 | package api |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "context" |
| 6 | "encoding/json" |
| 7 | "fmt" |
| 8 | "io" |
| 9 | "log" |
| 10 | "strings" |
| 11 | "sync" |
| 12 | "time" |
| 13 | |
| 14 | "github.com/coder/websocket" |
| 15 | "github.com/vmihailenco/msgpack/v5" |
| 16 | ) |
| 17 | |
// NotificationType represents the type of server notification.
//
// The numeric values are wire values: handleInvocation converts the integer
// "Type" field of an incoming notification directly to a NotificationType,
// so the numbers assigned below must stay exactly as they are.
type NotificationType int

// Notification types, listed in order of their wire value.
const (
	NotifSyncCipherUpdate    NotificationType = 0
	NotifSyncCipherCreate    NotificationType = 1
	NotifSyncLoginDelete     NotificationType = 2
	NotifSyncFolderDelete    NotificationType = 3
	NotifSyncCiphers         NotificationType = 4
	NotifSyncVault           NotificationType = 5
	NotifSyncOrgKeys         NotificationType = 6
	NotifSyncFolderCreate    NotificationType = 7
	NotifSyncFolderUpdate    NotificationType = 8
	NotifSyncCipherDelete    NotificationType = 9
	NotifSyncSettings        NotificationType = 10
	NotifLogOut              NotificationType = 11
	NotifSyncSendCreate      NotificationType = 12
	NotifSyncSendUpdate      NotificationType = 13
	NotifSyncSendDelete      NotificationType = 14
	NotifAuthRequest         NotificationType = 15
	NotifAuthRequestResponse NotificationType = 16
)
| 40 | |
// NotificationMessage is the decoded notification from the server.
type NotificationMessage struct {
	// Type is read from the "Type" field of the incoming notification. It is
	// left at its zero value (NotifSyncCipherUpdate) when that field is
	// missing or not an integer.
	Type NotificationType
	// ContextID is read from the "ContextId" field; empty when the field is
	// missing or not a string.
	ContextID string
	// Payload holds the decoded "Payload" field; it is the zero value when
	// the notification carries no payload that could be decoded.
	Payload NotificationPayload
}
| 47 | |
// NotificationPayload holds the payload fields of a notification.
//
// Each field is copied from the payload key named in its comment; a key that
// is absent leaves the field at its zero value. The date fields are kept as
// the strings the server sent and are not parsed or validated here.
type NotificationPayload struct {
	ID             string   // "id"
	UserID         string   // "userId"
	OrganizationID string   // "organizationId"
	CollectionIDs  []string // "collectionIds"
	RevisionDate   string   // "revisionDate"
	Date           string   // "date"
}
| 57 | |
// NotificationHandler is called when a notification is received.
//
// It is invoked synchronously from the goroutine that reads the WebSocket
// (readMessages -> handleMessage -> handleInvocation), so a handler that
// blocks delays the processing of every subsequent message.
type NotificationHandler func(msg NotificationMessage)
| 60 | |
// NotificationsClient manages a WebSocket connection to the notifications hub.
type NotificationsClient struct {
	transport *Transport          // supplies the base URL and access token for each dial
	handler   NotificationHandler // receives every decoded "ReceiveMessage" notification
	logger    *log.Logger         // debug output; not guarded by mu

	// mu guards cancel and done.
	mu sync.Mutex
	// cancel stops the background goroutine started by Connect; nil while
	// no connection has been started.
	cancel context.CancelFunc
	// done is closed by run when the background goroutine exits. It is
	// non-nil from Connect until Close resets it.
	done chan struct{}
}
| 71 | |
| 72 | // NewNotificationsClient creates a new notifications client. |
| 73 | func NewNotificationsClient(transport *Transport, handler NotificationHandler) *NotificationsClient { |
| 74 | return &NotificationsClient{ |
| 75 | transport: transport, |
| 76 | handler: handler, |
| 77 | logger: log.New(io.Discard, "", 0), |
| 78 | } |
| 79 | } |
| 80 | |
| 81 | // SetLogger sets a logger for debug output. By default, logging is disabled. |
| 82 | func (nc *NotificationsClient) SetLogger(l *log.Logger) { |
| 83 | nc.logger = l |
| 84 | } |
| 85 | |
| 86 | // Connect starts the background WebSocket connection with automatic reconnection. |
| 87 | // The provided context controls the lifetime of the connection — cancel it to stop. |
| 88 | func (nc *NotificationsClient) Connect(ctx context.Context) error { |
| 89 | nc.mu.Lock() |
| 90 | defer nc.mu.Unlock() |
| 91 | |
| 92 | if nc.done != nil { |
| 93 | return fmt.Errorf("notifications already connected") |
| 94 | } |
| 95 | |
| 96 | ctx, cancel := context.WithCancel(ctx) |
| 97 | nc.cancel = cancel |
| 98 | nc.done = make(chan struct{}) |
| 99 | |
| 100 | go nc.run(ctx) |
| 101 | return nil |
| 102 | } |
| 103 | |
| 104 | // Close stops the WebSocket connection and waits for the background goroutine to exit. |
| 105 | func (nc *NotificationsClient) Close() error { |
| 106 | nc.mu.Lock() |
| 107 | cancel := nc.cancel |
| 108 | done := nc.done |
| 109 | nc.mu.Unlock() |
| 110 | |
| 111 | if cancel == nil { |
| 112 | return nil |
| 113 | } |
| 114 | |
| 115 | cancel() |
| 116 | if done != nil { |
| 117 | <-done |
| 118 | } |
| 119 | |
| 120 | nc.mu.Lock() |
| 121 | nc.cancel = nil |
| 122 | nc.done = nil |
| 123 | nc.mu.Unlock() |
| 124 | |
| 125 | return nil |
| 126 | } |
| 127 | |
| 128 | func (nc *NotificationsClient) run(ctx context.Context) { |
| 129 | defer close(nc.done) |
| 130 | |
| 131 | backoff := 5 * time.Second |
| 132 | const maxBackoff = 60 * time.Second |
| 133 | |
| 134 | for { |
| 135 | err := nc.connectAndListen(ctx) |
| 136 | if ctx.Err() != nil { |
| 137 | return |
| 138 | } |
| 139 | nc.logger.Printf("notifications disconnected: %v; reconnecting in %v", err, backoff) |
| 140 | |
| 141 | select { |
| 142 | case <-ctx.Done(): |
| 143 | return |
| 144 | case <-time.After(backoff): |
| 145 | } |
| 146 | |
| 147 | backoff *= 2 |
| 148 | if backoff > maxBackoff { |
| 149 | backoff = maxBackoff |
| 150 | } |
| 151 | } |
| 152 | } |
| 153 | |
| 154 | func (nc *NotificationsClient) connectAndListen(ctx context.Context) error { |
| 155 | token, err := nc.transport.AccessToken(ctx) |
| 156 | if err != nil { |
| 157 | return fmt.Errorf("get access token: %w", err) |
| 158 | } |
| 159 | |
| 160 | wsURL := httpToWS(nc.transport.BaseURL()) + "/notifications/hub?access_token=" + token |
| 161 | |
| 162 | conn, _, err := websocket.Dial(ctx, wsURL, nil) |
| 163 | if err != nil { |
| 164 | return fmt.Errorf("dial: %w", err) |
| 165 | } |
| 166 | defer conn.Close(websocket.StatusNormalClosure, "") |
| 167 | |
| 168 | // SignalR handshake — always JSON, terminated by 0x1e |
| 169 | if err := nc.doHandshake(ctx, conn); err != nil { |
| 170 | return fmt.Errorf("handshake: %w", err) |
| 171 | } |
| 172 | |
| 173 | nc.logger.Printf("notifications connected") |
| 174 | |
| 175 | // Reset backoff on successful connection (done by caller seeing nil vs error) |
| 176 | errCh := make(chan error, 1) |
| 177 | |
| 178 | // Reader goroutine |
| 179 | go func() { |
| 180 | errCh <- nc.readMessages(ctx, conn) |
| 181 | }() |
| 182 | |
| 183 | // Ping ticker |
| 184 | ticker := time.NewTicker(15 * time.Second) |
| 185 | defer ticker.Stop() |
| 186 | |
| 187 | for { |
| 188 | select { |
| 189 | case <-ctx.Done(): |
| 190 | return ctx.Err() |
| 191 | case err := <-errCh: |
| 192 | return err |
| 193 | case <-ticker.C: |
| 194 | if err := nc.sendPing(ctx, conn); err != nil { |
| 195 | return fmt.Errorf("send ping: %w", err) |
| 196 | } |
| 197 | } |
| 198 | } |
| 199 | } |
| 200 | |
// recordSeparator is the byte (0x1e) that terminates the JSON handshake
// messages exchanged in doHandshake.
const recordSeparator = 0x1e
| 202 | |
| 203 | func (nc *NotificationsClient) doHandshake(ctx context.Context, conn *websocket.Conn) error { |
| 204 | // Send handshake: {"protocol":"messagepack","version":1}\x1e |
| 205 | handshake := []byte(`{"protocol":"messagepack","version":1}`) |
| 206 | handshake = append(handshake, recordSeparator) |
| 207 | |
| 208 | if err := conn.Write(ctx, websocket.MessageText, handshake); err != nil { |
| 209 | return fmt.Errorf("write handshake: %w", err) |
| 210 | } |
| 211 | |
| 212 | // Read handshake response |
| 213 | _, data, err := conn.Read(ctx) |
| 214 | if err != nil { |
| 215 | return fmt.Errorf("read handshake response: %w", err) |
| 216 | } |
| 217 | |
| 218 | // Strip record separator |
| 219 | data = bytes.TrimRight(data, string([]byte{recordSeparator})) |
| 220 | |
| 221 | var resp map[string]any |
| 222 | if err := json.Unmarshal(data, &resp); err != nil { |
| 223 | return fmt.Errorf("parse handshake response: %w", err) |
| 224 | } |
| 225 | |
| 226 | // Check for error |
| 227 | if errMsg, ok := resp["error"]; ok { |
| 228 | return fmt.Errorf("handshake error: %v", errMsg) |
| 229 | } |
| 230 | |
| 231 | return nil |
| 232 | } |
| 233 | |
| 234 | func (nc *NotificationsClient) readMessages(ctx context.Context, conn *websocket.Conn) error { |
| 235 | for { |
| 236 | _, data, err := conn.Read(ctx) |
| 237 | if err != nil { |
| 238 | return fmt.Errorf("read: %w", err) |
| 239 | } |
| 240 | |
| 241 | // Binary frames contain one or more length-prefixed MessagePack messages |
| 242 | reader := bytes.NewReader(data) |
| 243 | for reader.Len() > 0 { |
| 244 | msgLen, err := readVarInt(reader) |
| 245 | if err != nil { |
| 246 | return fmt.Errorf("read varint: %w", err) |
| 247 | } |
| 248 | |
| 249 | if msgLen <= 0 || msgLen > reader.Len() { |
| 250 | return fmt.Errorf("invalid message length: %d (remaining: %d)", msgLen, reader.Len()) |
| 251 | } |
| 252 | |
| 253 | msgData := make([]byte, msgLen) |
| 254 | if _, err := io.ReadFull(reader, msgData); err != nil { |
| 255 | return fmt.Errorf("read message: %w", err) |
| 256 | } |
| 257 | |
| 258 | nc.handleMessage(msgData) |
| 259 | } |
| 260 | } |
| 261 | } |
| 262 | |
| 263 | func (nc *NotificationsClient) handleMessage(data []byte) { |
| 264 | // Decode as generic array |
| 265 | var msg []any |
| 266 | if err := msgpack.Unmarshal(data, &msg); err != nil { |
| 267 | nc.logger.Printf("failed to decode message: %v", err) |
| 268 | return |
| 269 | } |
| 270 | |
| 271 | if len(msg) == 0 { |
| 272 | return |
| 273 | } |
| 274 | |
| 275 | msgType, ok := asInt(msg[0]) |
| 276 | if !ok { |
| 277 | return |
| 278 | } |
| 279 | |
| 280 | switch msgType { |
| 281 | case 1: // Invocation |
| 282 | nc.handleInvocation(msg) |
| 283 | case 6: // Ping — no action needed, server pinging us |
| 284 | nc.logger.Printf("received server ping") |
| 285 | case 7: // Close |
| 286 | nc.logger.Printf("received close message") |
| 287 | } |
| 288 | } |
| 289 | |
| 290 | func (nc *NotificationsClient) handleInvocation(msg []any) { |
| 291 | // Format: [1, headers, invocationId, target, arguments] |
| 292 | if len(msg) < 5 { |
| 293 | return |
| 294 | } |
| 295 | |
| 296 | target, _ := msg[3].(string) |
| 297 | if target != "ReceiveMessage" { |
| 298 | nc.logger.Printf("unknown invocation target: %s", target) |
| 299 | return |
| 300 | } |
| 301 | |
| 302 | args, ok := msg[4].([]any) |
| 303 | if !ok || len(args) == 0 { |
| 304 | return |
| 305 | } |
| 306 | |
| 307 | // The first argument is the notification object |
| 308 | argMap, ok := args[0].(map[string]any) |
| 309 | if !ok { |
| 310 | return |
| 311 | } |
| 312 | |
| 313 | notifMsg := NotificationMessage{} |
| 314 | if ct, ok := argMap["ContextId"]; ok { |
| 315 | notifMsg.ContextID, _ = ct.(string) |
| 316 | } |
| 317 | if t, ok := asInt(argMap["Type"]); ok { |
| 318 | notifMsg.Type = NotificationType(t) |
| 319 | } |
| 320 | |
| 321 | // Payload can be a string (JSON) or a map |
| 322 | if payloadStr, ok := argMap["Payload"].(string); ok && payloadStr != "" { |
| 323 | var p struct { |
| 324 | ID string `json:"id"` |
| 325 | UserID string `json:"userId"` |
| 326 | OrganizationID string `json:"organizationId"` |
| 327 | CollectionIDs []string `json:"collectionIds"` |
| 328 | RevisionDate string `json:"revisionDate"` |
| 329 | Date string `json:"date"` |
| 330 | } |
| 331 | if err := json.Unmarshal([]byte(payloadStr), &p); err == nil { |
| 332 | notifMsg.Payload = NotificationPayload{ |
| 333 | ID: p.ID, |
| 334 | UserID: p.UserID, |
| 335 | OrganizationID: p.OrganizationID, |
| 336 | CollectionIDs: p.CollectionIDs, |
| 337 | RevisionDate: p.RevisionDate, |
| 338 | Date: p.Date, |
| 339 | } |
| 340 | } |
| 341 | } |
| 342 | |
| 343 | nc.handler(notifMsg) |
| 344 | } |
| 345 | |
| 346 | func (nc *NotificationsClient) sendPing(ctx context.Context, conn *websocket.Conn) error { |
| 347 | // MessagePack [6] = 0x91 0x06, with VarInt length prefix 0x02 |
| 348 | ping := []byte{0x02, 0x91, 0x06} |
| 349 | return conn.Write(ctx, websocket.MessageBinary, ping) |
| 350 | } |
| 351 | |
// httpToWS rewrites the scheme of an HTTP(S) URL to its WebSocket
// counterpart: "https://" becomes "wss://" and "http://" becomes "ws://".
// The match is case-sensitive; any other input is returned unchanged.
func httpToWS(u string) string {
	switch {
	case strings.HasPrefix(u, "https://"):
		return "wss://" + strings.TrimPrefix(u, "https://")
	case strings.HasPrefix(u, "http://"):
		return "ws://" + strings.TrimPrefix(u, "http://")
	default:
		return u
	}
}
| 362 | |
// readVarInt reads a variable-length integer (SignalR MessagePack framing).
// Uses 7 bits per byte, least-significant group first, with the high bit as
// continuation flag.
//
// At most five bytes are consumed. The fifth byte may only contribute three
// bits, so the decoded value never exceeds 2^31-1; this keeps the result in
// range for int on 32-bit platforms. A fifth byte carrying more than that is
// rejected instead of silently producing an oversized or wrapped length.
//
// It returns the reader's error (io.EOF or io.ErrUnexpectedEOF from
// io.ReadFull on truncated input) when a byte cannot be read, and a
// descriptive error when the encoding is too long or overflows.
func readVarInt(r io.Reader) (int, error) {
	const maxBytes = 5

	var (
		result uint32
		buf    [1]byte
	)
	for i := 0; i < maxBytes; i++ {
		if _, err := io.ReadFull(r, buf[:]); err != nil {
			return 0, err
		}
		b := buf[0]

		if i == maxBytes-1 && b > 0x07 {
			if b&0x80 != 0 {
				return 0, fmt.Errorf("varint too long")
			}
			return 0, fmt.Errorf("varint overflows 31 bits")
		}

		result |= uint32(b&0x7F) << (7 * uint(i))
		if b&0x80 == 0 {
			return int(result), nil
		}
	}
	return 0, fmt.Errorf("varint too long")
}
| 384 | |
// writeVarInt appends n to buf as a variable-length integer: 7 bits per byte,
// least-significant group first, with the high bit set on every byte except
// the last. This is the encoding readVarInt decodes.
//
// n is encoded as an unsigned value. Lengths are expected to be non-negative;
// a negative n is written as its two's-complement bit pattern so the function
// always terminates (a signed right shift would never reach zero and would
// grow the buffer without bound).
func writeVarInt(buf *bytes.Buffer, n int) {
	u := uint(n)
	for u >= 0x80 {
		buf.WriteByte(byte(u) | 0x80)
		u >>= 7
	}
	buf.WriteByte(byte(u))
}
| 399 | |
// asInt extracts an integer from a msgpack-decoded any value, which can be
// int, int8, int16, int32, int64, uint, uint8, uint16, uint32 or uint64.
//
// It returns (value, true) when v holds one of those types and the value fits
// in an int. Values that do not fit (for example a uint64 above the maximum
// int) return (0, false) rather than wrapping around to a different number.
// Any other dynamic type, including nil, returns (0, false).
func asInt(v any) (int, bool) {
	// Largest value an int can hold on this platform.
	const maxInt = uint64(^uint(0) >> 1)

	switch n := v.(type) {
	case int:
		return n, true
	case int8:
		return int(n), true
	case int16:
		return int(n), true
	case int32:
		return int(n), true
	case int64:
		// Only lossy where int is 32 bits wide.
		if int64(int(n)) != n {
			return 0, false
		}
		return int(n), true
	case uint:
		if uint64(n) > maxInt {
			return 0, false
		}
		return int(n), true
	case uint8:
		return int(n), true
	case uint16:
		return int(n), true
	case uint32:
		if uint64(n) > maxInt {
			return 0, false
		}
		return int(n), true
	case uint64:
		if n > maxInt {
			return 0, false
		}
		return int(n), true
	default:
		return 0, false
	}
}
| 426 | |