server.go

v0.6.1
Doc Versions Source
1
package proxy
2
3
import (
4
	"bufio"
5
	"bytes"
6
	"encoding/json"
7
	"errors"
8
	"fmt"
9
	"io"
10
	"maps"
11
	"net"
12
	"net/http"
13
	"net/url"
14
	"strings"
15
	"time"
16
)
17
18
// extractDomain returns the host name (without port) of rawURL.
//
// Scheme-less inputs such as "api.example.com/v1" or "localhost:8080" either
// fail to parse or parse with an empty host, so the input is retried as a
// network-path reference ("//" + rawURL). If no host can be determined either
// way, rawURL is returned unchanged.
func extractDomain(rawURL string) string {
	for _, candidate := range []string{rawURL, "//" + rawURL} {
		u, err := url.Parse(candidate)
		if err != nil {
			continue
		}
		if host := u.Hostname(); host != "" {
			return host
		}
	}
	return rawURL
}
24
25
// Start launches a local Anthropic-to-OpenAI translation proxy.
26
// upstreamURL is the OpenAI-compatible base URL (e.g. "https://api.openai.com/v1").
27
// Returns the proxy address (host:port), a shutdown function, and any error.
28
func Start(upstreamURL string, tokenFunc func() (string, error), extraHeaders map[string]string, proxyURL string, providerName string) (addr string, shutdown func(), err error) {
29
	upstream := strings.TrimRight(upstreamURL, "/")
30
	domain := extractDomain(upstream)
31
32
	mux := http.NewServeMux()
33
	mux.HandleFunc("/v1/messages", handler(upstream, tokenFunc, extraHeaders, proxyURL, providerName, domain))
34
	mux.HandleFunc("/v1/messages/count_tokens", countTokensHandler())
35
36
	ln, err := net.Listen("tcp", "127.0.0.1:0")
37
	if err != nil {
38
		return "", nil, err
39
	}
40
41
	srv := &http.Server{Handler: mux}
42
	go func() {
43
		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
44
			plog().Error("proxy serve failed", "error", err)
45
		}
46
	}()
47
48
	return ln.Addr().String(), func() { srv.Close() }, nil
49
}
50
51
func handler(upstream string, tokenFunc func() (string, error), extraHeaders map[string]string, proxyURL string, providerName, domain string) http.HandlerFunc {
52
	transport := http.DefaultTransport.(*http.Transport).Clone()
53
	if proxyURL != "" {
54
		if u, err := url.Parse(proxyURL); err == nil {
55
			transport.Proxy = http.ProxyURL(u)
56
		}
57
	} else {
58
		transport.Proxy = nil
59
	}
60
	client := &http.Client{Transport: transport}
61
62
	return func(w http.ResponseWriter, r *http.Request) {
63
		plog().Info("request", "method", r.Method, "path", r.URL.Path, "from", r.RemoteAddr)
64
65
		body, err := io.ReadAll(r.Body)
66
		if err != nil {
67
			plog().Error("reading body", "error", err)
68
			httpError(w, http.StatusBadRequest, "reading body: "+err.Error())
69
			return
70
		}
71
		r.Body = io.NopCloser(bytes.NewReader(body))
72
73
		rawLogRequest(r, body, "INCOMING REQUEST from Claude Code")
74
75
		var req anthropicRequest
76
		if err := json.Unmarshal(body, &req); err != nil {
77
			plog().Error("parsing request", "error", err)
78
			httpError(w, http.StatusBadRequest, "parsing request: "+err.Error())
79
			return
80
		}
81
82
		plog().Info("request details", "model", req.Model, "stream", req.Stream, "max_tokens", req.MaxTokens, "messages", len(req.Messages))
83
84
		// Intercept quota-probe requests (max_tokens=1) — return a mock
85
		// response instead of hitting the upstream API.
86
		if isQuotaProbe(&req) {
87
			plog().Info("intercepting quota probe")
88
			DefaultMetrics.Record(CallRecord{
89
				Provider:       providerName,
90
				UpstreamDomain: domain,
91
				Endpoint:       r.URL.Path,
92
				Model:          req.Model,
93
				QuotaProbe:     true,
94
			})
95
			antResp := &anthropicResponse{
96
				ID:      "msg_" + genHex(12),
97
				Type:    "message",
98
				Role:    "assistant",
99
				Content: []contentBlock{{Type: "text", Text: "ok"}},
100
				Model:   req.Model,
101
				Usage:   anthropicUsage{InputTokens: 1, OutputTokens: 1},
102
			}
103
			stopReason := "end_turn"
104
			antResp.StopReason = &stopReason
105
			w.Header().Set("Content-Type", "application/json")
106
			json.NewEncoder(w).Encode(antResp)
107
			return
108
		}
109
110
		oaiReq := convertRequest(&req)
111
		oaiBody, _ := json.Marshal(oaiReq)
112
113
		upReq, _ := http.NewRequestWithContext(r.Context(), "POST",
114
			upstream+"/chat/completions", bytes.NewReader(oaiBody))
115
		upReq.Header.Set("Content-Type", "application/json")
116
		token, err := tokenFunc()
117
		if err != nil {
118
			plog().Error("getting token", "error", err)
119
			httpError(w, http.StatusBadGateway, "token: "+err.Error())
120
			return
121
		}
122
		upReq.Header.Set("Authorization", "Bearer "+token)
123
		for k, v := range extraHeaders {
124
			upReq.Header.Set(k, v)
125
		}
126
127
		plog().Info("upstream request", "url", upstream+"/chat/completions")
128
		rawLogRequest(upReq, oaiBody, "OUTGOING REQUEST to upstream API")
129
130
		callStart := time.Now()
131
		resp, err := client.Do(upReq)
132
		if err != nil {
133
			plog().Error("upstream request failed", "error", err)
134
			DefaultMetrics.Record(CallRecord{
135
				Provider:       providerName,
136
				UpstreamDomain: domain,
137
				Endpoint:       r.URL.Path,
138
				Model:          req.Model,
139
				Stream:         req.Stream,
140
				Duration:       time.Since(callStart),
141
				Error:          true,
142
			})
143
			httpError(w, http.StatusBadGateway, "upstream: "+err.Error())
144
			return
145
		}
146
		defer resp.Body.Close()
147
148
		plog().Info("upstream response", "status", resp.StatusCode)
149
150
		// Read the response body for logging.
151
		respBody, err := io.ReadAll(resp.Body)
152
		if err != nil {
153
			plog().Error("reading response body", "error", err)
154
		}
155
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
156
157
		rawLogResponse(resp, respBody, "UPSTREAM RESPONSE")
158
159
		if resp.StatusCode != http.StatusOK {
160
			plog().Error("upstream error", "status", resp.Status, "body", string(respBody))
161
			DefaultMetrics.Record(CallRecord{
162
				Provider:       providerName,
163
				UpstreamDomain: domain,
164
				Endpoint:       r.URL.Path,
165
				Model:          req.Model,
166
				StatusCode:     resp.StatusCode,
167
				Stream:         req.Stream,
168
				Duration:       time.Since(callStart),
169
				Error:          true,
170
			})
171
			// Translate to Anthropic error format so Claude Code understands it.
172
			w.Header().Set("Content-Type", "application/json")
173
			w.WriteHeader(resp.StatusCode)
174
			json.NewEncoder(w).Encode(map[string]any{
175
				"type": "error",
176
				"error": map[string]any{
177
					"type":    "api_error",
178
					"message": fmt.Sprintf("upstream %s: %s", resp.Status, respBody),
179
				},
180
			})
181
			return
182
		}
183
184
		plog().Info("processing response", "model", req.Model, "stream", req.Stream)
185
186
		rec := CallRecord{
187
			Provider:       providerName,
188
			UpstreamDomain: domain,
189
			Endpoint:       r.URL.Path,
190
			Model:          req.Model,
191
			StatusCode:     resp.StatusCode,
192
			Stream:         req.Stream,
193
		}
194
195
		if req.Stream {
196
			usage := handleStream(w, resp, req.Model)
197
			if usage != nil {
198
				rec.InputTokens = usage.PromptTokens
199
				rec.OutputTokens = usage.CompletionTokens
200
			}
201
		} else {
202
			usage := handleNonStream(w, resp)
203
			if usage != nil {
204
				rec.InputTokens = usage.PromptTokens
205
				rec.OutputTokens = usage.CompletionTokens
206
			}
207
		}
208
209
		rec.Duration = time.Since(callStart)
210
		DefaultMetrics.Record(rec)
211
	}
212
}
213
214
func handleNonStream(w http.ResponseWriter, resp *http.Response) *openaiUsage {
215
	var oaiResp openaiResponse
216
	if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
217
		plog().Error("decoding upstream response", "error", err)
218
		httpError(w, http.StatusBadGateway, "decoding upstream: "+err.Error())
219
		return nil
220
	}
221
	plog().Info("non-stream response", "input_tokens", oaiResp.Usage.PromptTokens, "output_tokens", oaiResp.Usage.CompletionTokens)
222
223
	antResp := convertResponse(&oaiResp)
224
	antBody, _ := json.Marshal(antResp)
225
226
	rawLogResponse(&http.Response{
227
		Status:     "200 OK",
228
		StatusCode: 200,
229
		Header:     http.Header{"Content-Type": []string{"application/json"}},
230
	}, antBody, "RESPONSE TO CLAUDE CODE")
231
232
	w.Header().Set("Content-Type", "application/json")
233
	json.NewEncoder(w).Encode(antResp)
234
	return oaiResp.Usage
235
}
236
237
func handleStream(w http.ResponseWriter, resp *http.Response, model string) *openaiUsage {
238
	flusher, ok := w.(http.Flusher)
239
	if !ok {
240
		httpError(w, http.StatusInternalServerError, "streaming not supported")
241
		return nil
242
	}
243
244
	w.Header().Set("Content-Type", "text/event-stream")
245
	w.Header().Set("Cache-Control", "no-cache")
246
	w.Header().Set("Connection", "keep-alive")
247
248
	sw := &streamWriter{
249
		w:          w,
250
		flusher:    flusher,
251
		msgID:      "msg_" + genHex(12),
252
		model:      model,
253
		toolStarts: make(map[int]int),
254
	}
255
256
	plog().Info("stream start", "model", model)
257
258
	var streamLog strings.Builder
259
	scanner := bufio.NewScanner(resp.Body)
260
	scanner.Buffer(make([]byte, 0, 256*1024), 256*1024)
261
262
	chunkCount := 0
263
	for scanner.Scan() {
264
		line := scanner.Text()
265
		streamLog.WriteString(line + "\n")
266
		if !strings.HasPrefix(line, "data: ") {
267
			continue
268
		}
269
		data := strings.TrimPrefix(line, "data: ")
270
		if data == "[DONE]" {
271
			plog().Debug("stream DONE signal")
272
			streamLog.WriteString("\n[END OF STREAM]\n")
273
			break
274
		}
275
		var chunk openaiStreamChunk
276
		if json.Unmarshal([]byte(data), &chunk) != nil {
277
			plog().Error("unmarshal stream chunk", "data", data[:min(len(data), 200)])
278
			continue
279
		}
280
		sw.processChunk(chunk)
281
		chunkCount++
282
	}
283
284
	if err := scanner.Err(); err != nil {
285
		plog().Error("scanner error", "error", err)
286
	}
287
288
	rawLogStream("STREAM RESPONSE FROM UPSTREAM", streamLog.String(), chunkCount)
289
290
	plog().Info("stream finished", "chunks", chunkCount)
291
	sw.finish()
292
	return sw.usage
293
}
294
295
// streamWriter translates OpenAI SSE chunks into Anthropic SSE events.
296
type streamWriter struct {
297
	w          http.ResponseWriter
298
	flusher    http.Flusher
299
	msgID      string
300
	model      string
301
	started    bool
302
	blockIndex int
303
	inText     bool
304
	inTool     bool
305
	toolStarts map[int]int // openai tool index → anthropic block index
306
	usage      *openaiUsage
307
}
308
309
func (s *streamWriter) processChunk(chunk openaiStreamChunk) {
310
	if chunk.Usage != nil {
311
		s.usage = chunk.Usage
312
	}
313
	if len(chunk.Choices) == 0 {
314
		return
315
	}
316
317
	if !s.started {
318
		s.emitMessageStart()
319
		s.started = true
320
	}
321
322
	delta := chunk.Choices[0].Delta
323
324
	// Text content.
325
	if delta.Content != nil && *delta.Content != "" {
326
		if !s.inText {
327
			s.closeCurrentBlock()
328
			s.emitContentBlockStart("text", "", "")
329
			s.inText = true
330
		}
331
		s.emit("content_block_delta", fmt.Sprintf(
332
			`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":%s}}`,
333
			s.blockIndex, mustJSON(*delta.Content)))
334
	}
335
336
	// Tool calls.
337
	for _, tc := range delta.ToolCalls {
338
		if tc.ID != "" {
339
			// New tool call starting.
340
			s.closeCurrentBlock()
341
			s.emitContentBlockStart("tool_use", tc.ID, tc.Function.Name)
342
			s.toolStarts[tc.Index] = s.blockIndex
343
			s.inTool = true
344
		}
345
		if tc.Function != nil && tc.Function.Arguments != "" {
346
			idx, ok := s.toolStarts[tc.Index]
347
			if !ok {
348
				idx = s.blockIndex
349
			}
350
			s.emit("content_block_delta", fmt.Sprintf(
351
				`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":%s}}`,
352
				idx, mustJSON(tc.Function.Arguments)))
353
		}
354
	}
355
356
	// Finish reason.
357
	if chunk.Choices[0].FinishReason != nil {
358
		s.closeCurrentBlock()
359
	}
360
}
361
362
func (s *streamWriter) finish() {
363
	if !s.started {
364
		s.emitMessageStart()
365
	}
366
	s.closeCurrentBlock()
367
368
	stopReason := "end_turn"
369
	outTokens := 0
370
	if s.usage != nil {
371
		outTokens = s.usage.CompletionTokens
372
	}
373
374
	s.emit("message_delta", fmt.Sprintf(
375
		`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"output_tokens":%d}}`,
376
		stopReason, outTokens))
377
	s.emit("message_stop", `{"type":"message_stop"}`)
378
}
379
380
func (s *streamWriter) emitMessageStart() {
381
	inTokens := 0
382
	if s.usage != nil {
383
		inTokens = s.usage.PromptTokens
384
	}
385
	s.emit("message_start", fmt.Sprintf(
386
		`{"type":"message_start","message":{"id":"%s","type":"message","role":"assistant","content":[],"model":"%s","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":%d,"output_tokens":0}}}`,
387
		s.msgID, s.model, inTokens))
388
}
389
390
func (s *streamWriter) emitContentBlockStart(typ, id, name string) {
391
	switch typ {
392
	case "text":
393
		s.emit("content_block_start", fmt.Sprintf(
394
			`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`,
395
			s.blockIndex))
396
	case "tool_use":
397
		s.emit("content_block_start", fmt.Sprintf(
398
			`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"%s","name":"%s","input":{}}}`,
399
			s.blockIndex, id, name))
400
	}
401
}
402
403
func (s *streamWriter) closeCurrentBlock() {
404
	if s.inText || s.inTool {
405
		s.emit("content_block_stop", fmt.Sprintf(
406
			`{"type":"content_block_stop","index":%d}`, s.blockIndex))
407
		s.blockIndex++
408
		s.inText = false
409
		s.inTool = false
410
	}
411
}
412
413
// emit writes a single SSE frame (an "event:" line, a "data:" line and a
// blank line) to the response and flushes it immediately.
func (s *streamWriter) emit(event, data string) {
	frame := "event: " + event + "\ndata: " + data + "\n\n"
	fmt.Fprint(s.w, frame)
	s.flusher.Flush()
}
417
418
// mustJSON returns s encoded as a JSON string literal, including the
// surrounding quotes. If encoding fails, the result is the empty string.
func mustJSON(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		return ""
	}
	return string(encoded)
}
422
423
// StaticToken wraps a fixed key in a token function: every call yields key
// and a nil error.
func StaticToken(key string) func() (string, error) {
	fixed := func() (string, error) {
		return key, nil
	}
	return fixed
}
427
428
// StartPassthrough launches a logging-only proxy that forwards requests unchanged.
429
// Used for Anthropic-compatible providers to enable logging without translation.
430
func StartPassthrough(upstreamURL string, tokenFunc func() (string, error), proxyURL string, providerName string) (addr string, shutdown func(), err error) {
431
	upstream := strings.TrimRight(upstreamURL, "/")
432
	domain := extractDomain(upstream)
433
434
	transport := http.DefaultTransport.(*http.Transport).Clone()
435
	if proxyURL != "" {
436
		if u, err := url.Parse(proxyURL); err == nil {
437
			transport.Proxy = http.ProxyURL(u)
438
		}
439
	} else {
440
		transport.Proxy = nil
441
	}
442
	client := &http.Client{Transport: transport}
443
444
	handler := func(w http.ResponseWriter, r *http.Request) {
445
		plog().Info("passthrough request", "method", r.Method, "path", r.URL.Path, "from", r.RemoteAddr)
446
447
		body, err := io.ReadAll(r.Body)
448
		if err != nil {
449
			plog().Error("reading body", "error", err)
450
			httpError(w, http.StatusBadRequest, "reading body: "+err.Error())
451
			return
452
		}
453
		r.Body = io.NopCloser(bytes.NewReader(body))
454
455
		rawLogRequest(r, body, "INCOMING REQUEST (passthrough)")
456
457
		targetURL := upstream + r.URL.Path
458
		if r.URL.RawQuery != "" {
459
			targetURL += "?" + r.URL.RawQuery
460
		}
461
462
		upReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, bytes.NewReader(body))
463
		if err != nil {
464
			plog().Error("creating upstream request", "error", err)
465
			httpError(w, http.StatusInternalServerError, "creating request: "+err.Error())
466
			return
467
		}
468
469
		maps.Copy(upReq.Header, r.Header)
470
471
		if tokenFunc != nil {
472
			token, err := tokenFunc()
473
			if err != nil {
474
				plog().Error("getting token", "error", err)
475
				httpError(w, http.StatusBadGateway, "token: "+err.Error())
476
				return
477
			}
478
			if token != "" {
479
				upReq.Header.Set("Authorization", "Bearer "+token)
480
			}
481
		}
482
483
		plog().Info("upstream request", "method", r.Method, "url", targetURL)
484
		rawLogRequest(upReq, body, "OUTGOING REQUEST (passthrough)")
485
486
		callStart := time.Now()
487
		resp, err := client.Do(upReq)
488
		if err != nil {
489
			plog().Error("upstream request failed", "error", err)
490
			DefaultMetrics.Record(CallRecord{
491
				Provider:       providerName,
492
				UpstreamDomain: domain,
493
				Endpoint:       r.URL.Path,
494
				Duration:       time.Since(callStart),
495
				Error:          true,
496
			})
497
			httpError(w, http.StatusBadGateway, "upstream: "+err.Error())
498
			return
499
		}
500
		defer resp.Body.Close()
501
502
		plog().Info("upstream response", "status", resp.StatusCode)
503
504
		respBody, err := io.ReadAll(resp.Body)
505
		if err != nil {
506
			plog().Error("reading response body", "error", err)
507
			httpError(w, http.StatusBadGateway, "reading response: "+err.Error())
508
			return
509
		}
510
511
		rawLogResponse(resp, respBody, "UPSTREAM RESPONSE (passthrough)")
512
513
		maps.Copy(w.Header(), resp.Header)
514
		w.WriteHeader(resp.StatusCode)
515
		w.Write(respBody)
516
517
		rec := CallRecord{
518
			Provider:       providerName,
519
			UpstreamDomain: domain,
520
			Endpoint:       r.URL.Path,
521
			StatusCode:     resp.StatusCode,
522
			Duration:       time.Since(callStart),
523
			Error:          resp.StatusCode != http.StatusOK,
524
		}
525
526
		// Best-effort token extraction from Anthropic response.
527
		if resp.StatusCode == http.StatusOK {
528
			var usage struct {
529
				Usage struct {
530
					InputTokens  int `json:"input_tokens"`
531
					OutputTokens int `json:"output_tokens"`
532
				} `json:"usage"`
533
				Model string `json:"model"`
534
			}
535
			if json.Unmarshal(respBody, &usage) == nil {
536
				rec.InputTokens = usage.Usage.InputTokens
537
				rec.OutputTokens = usage.Usage.OutputTokens
538
				rec.Model = usage.Model
539
			}
540
		}
541
542
		DefaultMetrics.Record(rec)
543
		plog().Info("passthrough complete")
544
	}
545
546
	mux := http.NewServeMux()
547
	mux.HandleFunc("/v1/messages", handler)
548
	mux.HandleFunc("/v1/messages/count_tokens", countTokensHandler())
549
550
	ln, err := net.Listen("tcp", "127.0.0.1:0")
551
	if err != nil {
552
		return "", nil, err
553
	}
554
555
	srv := &http.Server{Handler: mux}
556
	go func() {
557
		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
558
			plog().Error("passthrough proxy serve failed", "error", err)
559
		}
560
	}()
561
562
	return ln.Addr().String(), func() { srv.Close() }, nil
563
}
564
565
// countTokensHandler serves /v1/messages/count_tokens with a fixed mock
// answer of zero input tokens. Claude Code calls this endpoint; without it
// the proxy returns 404 which may cause "unable to connect to api" errors.
func countTokensHandler() http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		mock := map[string]any{"input_tokens": 0}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(mock)
	})
}
576
577
// isQuotaProbe detects Claude Code's quota-check requests (max_tokens=1 with
578
// short messages) and returns a mock response to avoid hitting the upstream.
579
func isQuotaProbe(req *anthropicRequest) bool {
580
	return req.MaxTokens == 1
581
}
582
583
// httpError replies with status code and a JSON error envelope of the form
// {"type":"error","error":{"type":"proxy_error","message":msg}}.
func httpError(w http.ResponseWriter, code int, msg string) {
	detail := map[string]any{
		"type":    "proxy_error",
		"message": msg,
	}
	envelope := map[string]any{
		"type":  "error",
		"error": detail,
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	json.NewEncoder(w).Encode(envelope)
}
594

Source Files