transport.go

v1.0.0
Doc Versions Source
1
package api
2
3
import (
4
	"bytes"
5
	"context"
6
	"encoding/json"
7
	"fmt"
8
	"io"
9
	"net/http"
10
	"net/url"
11
	"sync"
12
	"time"
13
)
14
15
// Transport provides authenticated HTTP access to the Bitwarden/Vaultwarden server API.
16
type Transport struct {
17
	baseURL    string
18
	httpClient *http.Client
19
20
	mu           sync.RWMutex
21
	accessToken  string
22
	refreshToken string
23
	tokenExpiry  time.Time
24
25
	// refreshFunc is called when the access token has expired.
26
	// It receives the current refresh token and returns a new TokenResponse.
27
	refreshFunc func(ctx context.Context, refreshToken string) (*TokenResponse, error)
28
}
29
30
// NewTransport creates a new authenticated transport.
31
func NewTransport(baseURL string, httpClient *http.Client) *Transport {
32
	if httpClient == nil {
33
		httpClient = &http.Client{Timeout: 30 * time.Second}
34
	}
35
	return &Transport{
36
		baseURL:    baseURL,
37
		httpClient: httpClient,
38
	}
39
}
40
41
// SetTokens stores the access and refresh tokens from a login response.
42
func (t *Transport) SetTokens(accessToken, refreshToken string, expiresIn int) {
43
	t.mu.Lock()
44
	defer t.mu.Unlock()
45
	t.accessToken = accessToken
46
	t.refreshToken = refreshToken
47
	// Expire 30 seconds early to avoid edge cases.
48
	t.tokenExpiry = time.Now().Add(time.Duration(expiresIn)*time.Second - 30*time.Second)
49
}
50
51
// SetRefreshFunc sets the callback used to refresh expired tokens.
52
func (t *Transport) SetRefreshFunc(fn func(ctx context.Context, refreshToken string) (*TokenResponse, error)) {
53
	t.mu.Lock()
54
	defer t.mu.Unlock()
55
	t.refreshFunc = fn
56
}
57
58
// AccessToken returns the current access token, refreshing if expired.
59
func (t *Transport) AccessToken(ctx context.Context) (string, error) {
60
	t.mu.RLock()
61
	token := t.accessToken
62
	expiry := t.tokenExpiry
63
	t.mu.RUnlock()
64
65
	if token != "" && time.Now().Before(expiry) {
66
		return token, nil
67
	}
68
69
	return t.refreshAccessToken(ctx)
70
}
71
72
// refreshAccessToken exchanges the stored refresh token for a new access
// token via refreshFunc and stores the result.
//
// The write lock is held for the whole call, including the refreshFunc
// invocation, so concurrent callers wait for a single refresh. As a
// consequence refreshFunc must not call Transport methods that take the lock
// (AccessToken, SetTokens, ClearTokens, ...), or it will deadlock.
//
// It returns an error when no refresh token or callback is configured, when
// the callback fails, or when the callback yields no access token. On error
// the stored tokens are left untouched.
func (t *Transport) refreshAccessToken(ctx context.Context) (string, error) {
	t.mu.Lock()
	defer t.mu.Unlock()

	// Another goroutine may have refreshed while we waited for the lock.
	if t.accessToken != "" && time.Now().Before(t.tokenExpiry) {
		return t.accessToken, nil
	}

	if t.refreshToken == "" || t.refreshFunc == nil {
		return "", fmt.Errorf("no refresh token or refresh function available")
	}

	resp, err := t.refreshFunc(ctx, t.refreshToken)
	if err != nil {
		return "", fmt.Errorf("refresh token: %w", err)
	}
	// Guard against a callback that reports success without a usable token;
	// storing it would send an empty Bearer credential on the next request.
	if resp == nil || resp.AccessToken == "" {
		return "", fmt.Errorf("refresh token: empty token response")
	}

	t.accessToken = resp.AccessToken
	// A server may omit the refresh token when it does not rotate it; keep
	// the current one in that case instead of wiping it.
	if resp.RefreshToken != "" {
		t.refreshToken = resp.RefreshToken
	}
	// Expire 30 seconds early, matching SetTokens.
	t.tokenExpiry = time.Now().Add(time.Duration(resp.ExpiresIn)*time.Second - 30*time.Second)

	return t.accessToken, nil
}
96
97
// ClearTokens removes all stored tokens.
98
func (t *Transport) ClearTokens() {
99
	t.mu.Lock()
100
	defer t.mu.Unlock()
101
	t.accessToken = ""
102
	t.refreshToken = ""
103
	t.tokenExpiry = time.Time{}
104
}
105
106
// BaseURL returns the configured base URL.
107
func (t *Transport) BaseURL() string {
108
	return t.baseURL
109
}
110
111
// Get performs an authenticated GET request.
112
func (t *Transport) Get(ctx context.Context, path string, query url.Values) ([]byte, error) {
113
	return t.do(ctx, http.MethodGet, path, query, nil)
114
}
115
116
// Post performs an authenticated POST request with a JSON body.
117
func (t *Transport) Post(ctx context.Context, path string, body any) ([]byte, error) {
118
	return t.doJSON(ctx, http.MethodPost, path, body)
119
}
120
121
// Put performs an authenticated PUT request with a JSON body.
122
func (t *Transport) Put(ctx context.Context, path string, body any) ([]byte, error) {
123
	return t.doJSON(ctx, http.MethodPut, path, body)
124
}
125
126
// Delete performs an authenticated DELETE request.
127
func (t *Transport) Delete(ctx context.Context, path string) ([]byte, error) {
128
	return t.do(ctx, http.MethodDelete, path, nil, nil)
129
}
130
131
// PostForm performs an authenticated POST request with form-encoded body.
132
// Used for the token endpoint.
133
func (t *Transport) PostForm(ctx context.Context, fullURL string, form url.Values) ([]byte, error) {
134
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader([]byte(form.Encode())))
135
	if err != nil {
136
		return nil, fmt.Errorf("create request: %w", err)
137
	}
138
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
139
140
	return t.executeRequest(req)
141
}
142
143
// PostFormNoAuth performs an unauthenticated POST with form-encoded body.
144
func (t *Transport) PostFormNoAuth(ctx context.Context, fullURL string, form url.Values) ([]byte, error) {
145
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader([]byte(form.Encode())))
146
	if err != nil {
147
		return nil, fmt.Errorf("create request: %w", err)
148
	}
149
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
150
151
	return t.executeRequestRaw(req)
152
}
153
154
// PostNoAuth performs an unauthenticated POST request with a JSON body.
155
func (t *Transport) PostNoAuth(ctx context.Context, path string, body any) ([]byte, error) {
156
	var reader io.Reader
157
	if body != nil {
158
		data, err := json.Marshal(body)
159
		if err != nil {
160
			return nil, fmt.Errorf("marshal request body: %w", err)
161
		}
162
		reader = bytes.NewReader(data)
163
	}
164
165
	u := t.baseURL + path
166
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, reader)
167
	if err != nil {
168
		return nil, fmt.Errorf("create request: %w", err)
169
	}
170
	if body != nil {
171
		req.Header.Set("Content-Type", "application/json")
172
	}
173
174
	return t.executeRequestRaw(req)
175
}
176
177
// doJSON JSON-encodes body (when non-nil) and forwards the request to do
// with no query string. A nil body results in a request without payload.
func (t *Transport) doJSON(ctx context.Context, method, path string, body any) ([]byte, error) {
	if body == nil {
		return t.do(ctx, method, path, nil, nil)
	}

	encoded, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("marshal request body: %w", err)
	}
	return t.do(ctx, method, path, nil, bytes.NewReader(encoded))
}
188
189
// do builds and executes an authenticated request against baseURL+path.
//
// A non-empty query is appended as the URL query string. When body is
// non-nil the request is sent with Content-Type application/json. The access
// token comes from AccessToken, which may trigger a refresh, and is sent as
// a Bearer Authorization header.
func (t *Transport) do(ctx context.Context, method, path string, query url.Values, body io.Reader) ([]byte, error) {
	endpoint := t.baseURL + path
	if len(query) != 0 {
		endpoint = endpoint + "?" + query.Encode()
	}

	req, err := http.NewRequestWithContext(ctx, method, endpoint, body)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}

	// Resolve the credential last so an invalid request never costs a refresh.
	bearer, err := t.AccessToken(ctx)
	if err != nil {
		return nil, fmt.Errorf("get access token: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+bearer)

	return t.executeRequest(req)
}
212
213
// executeRequest runs req through executeRequestRaw but, unlike it, never
// returns a body together with an error: on any failure the body is nil.
func (t *Transport) executeRequest(req *http.Request) ([]byte, error) {
	payload, err := t.executeRequestRaw(req)
	if err != nil {
		// Drop the error-response body the raw variant hands back.
		payload = nil
	}
	return payload, err
}
220
221
// executeRequestRaw sends req with the configured HTTP client and reads the
// whole response body.
//
// For an HTTP status >= 400 it returns the body together with the error
// produced by parseServerError, so callers can still inspect the payload.
// Transport and read failures yield a nil body.
func (t *Transport) executeRequestRaw(req *http.Request) ([]byte, error) {
	res, err := t.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("execute request: %w", err)
	}
	defer res.Body.Close()

	payload, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}

	code := res.StatusCode
	if code < 400 {
		return payload, nil
	}
	return payload, parseServerError(code, payload)
}
239
240
// parseServerError extracts a meaningful error from a server error response.
241
func parseServerError(statusCode int, body []byte) error {
242
	var errResp ErrorResponse
243
	if err := json.Unmarshal(body, &errResp); err == nil {
244
		msg := errResp.Message
245
		if msg == "" && errResp.ErrorModel != nil {
246
			msg = errResp.ErrorModel.Message
247
		}
248
		if msg == "" {
249
			msg = errResp.ErrorDescription
250
		}
251
		return &ServerError{StatusCode: statusCode, Msg: msg, Raw: string(body)}
252
	}
253
	return &ServerError{StatusCode: statusCode, Raw: string(body)}
254
}
255
256
// ServerError represents an error response from the Bitwarden/Vaultwarden server.
type ServerError struct {
	// StatusCode is the HTTP status code of the response.
	StatusCode int
	// Msg is the message extracted from the response body; it may be empty.
	Msg string
	// Raw is the unparsed response body.
	Raw string
}

// Error implements the error interface. The text always carries the HTTP
// status code and is followed by Msg when one is present. Raw is never
// included.
func (e *ServerError) Error() string {
	text := fmt.Sprintf("bitwarden server: HTTP %d", e.StatusCode)
	if e.Msg == "" {
		return text
	}
	return text + ": " + e.Msg
}
269

Source Files