server.go

v0.7.0
Doc Versions Source
1
package proxy
2
3
import (
4
	"bufio"
5
	"bytes"
6
	"encoding/json"
7
	"errors"
8
	"fmt"
9
	"io"
10
	"maps"
11
	"net"
12
	"net/http"
13
	"net/url"
14
	"strings"
15
	"time"
16
)
17
18
// extractDomain returns the host name (without port) contained in rawURL.
//
// Scheme-less inputs such as "api.example.com/v1" or "localhost:8080" either
// parse with an empty host or fail to parse, so they are retried as a
// network-path reference ("//" + rawURL). When no host can be determined,
// rawURL is returned unchanged.
func extractDomain(rawURL string) string {
	if u, err := url.Parse(rawURL); err == nil && u.Hostname() != "" {
		return u.Hostname()
	}
	// Without a scheme, url.Parse treats the host as a path (or as a scheme
	// when a port follows); prefixing "//" makes it parse as an authority.
	if !strings.Contains(rawURL, "://") {
		if u, err := url.Parse("//" + rawURL); err == nil && u.Hostname() != "" {
			return u.Hostname()
		}
	}
	return rawURL
}
24
25
// Start launches a local Anthropic-to-OpenAI translation proxy.
26
// upstreamURL is the OpenAI-compatible base URL (e.g. "https://api.openai.com/v1").
27
// Returns the proxy address (host:port), a shutdown function, and any error.
28
func Start(upstreamURL string, tokenFunc func() (string, error), extraHeaders map[string]string, proxyURL string, providerName string) (addr string, shutdown func(), err error) {
29
	upstream := strings.TrimRight(upstreamURL, "/")
30
	domain := extractDomain(upstream)
31
32
	mux := http.NewServeMux()
33
	mux.HandleFunc("/v1/messages", handler(upstream, tokenFunc, extraHeaders, proxyURL, providerName, domain))
34
	mux.HandleFunc("/v1/messages/count_tokens", countTokensHandler())
35
36
	ln, err := net.Listen("tcp", "127.0.0.1:0")
37
	if err != nil {
38
		return "", nil, err
39
	}
40
41
	srv := &http.Server{Handler: mux}
42
	go func() {
43
		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
44
			plog().Error("proxy serve failed", "error", err)
45
		}
46
	}()
47
48
	return ln.Addr().String(), func() { srv.Close() }, nil
49
}
50
51
func handler(upstream string, tokenFunc func() (string, error), extraHeaders map[string]string, proxyURL string, providerName, domain string) http.HandlerFunc {
52
	transport := http.DefaultTransport.(*http.Transport).Clone()
53
	if proxyURL != "" {
54
		if u, err := url.Parse(proxyURL); err == nil {
55
			transport.Proxy = http.ProxyURL(u)
56
		}
57
	} else {
58
		transport.Proxy = nil
59
	}
60
	client := &http.Client{Transport: transport}
61
62
	return func(w http.ResponseWriter, r *http.Request) {
63
		plog().Info("request", "method", r.Method, "path", r.URL.Path, "from", r.RemoteAddr)
64
65
		body, err := io.ReadAll(r.Body)
66
		if err != nil {
67
			plog().Error("reading body", "error", err)
68
			httpError(w, http.StatusBadRequest, "reading body: "+err.Error())
69
			return
70
		}
71
		r.Body = io.NopCloser(bytes.NewReader(body))
72
73
		rawLogRequest(r, body, "INCOMING REQUEST from Claude Code")
74
75
		var req anthropicRequest
76
		if err := json.Unmarshal(body, &req); err != nil {
77
			plog().Error("parsing request", "error", err)
78
			httpError(w, http.StatusBadRequest, "parsing request: "+err.Error())
79
			return
80
		}
81
82
		plog().Info("request details", "model", req.Model, "stream", req.Stream, "max_tokens", req.MaxTokens, "messages", len(req.Messages))
83
84
		// Intercept quota-probe requests (max_tokens=1) — return a mock
85
		// response instead of hitting the upstream API.
86
		if isQuotaProbe(&req) {
87
			plog().Info("intercepting quota probe")
88
			DefaultMetrics.Record(CallRecord{
89
				Provider:       providerName,
90
				UpstreamDomain: domain,
91
				Endpoint:       r.URL.Path,
92
				Model:          req.Model,
93
				QuotaProbe:     true,
94
			})
95
			antResp := &anthropicResponse{
96
				ID:      "msg_" + genHex(12),
97
				Type:    "message",
98
				Role:    "assistant",
99
				Content: []contentBlock{{Type: "text", Text: "ok"}},
100
				Model:   req.Model,
101
				Usage:   anthropicUsage{InputTokens: 1, OutputTokens: 1},
102
			}
103
			stopReason := "end_turn"
104
			antResp.StopReason = &stopReason
105
			w.Header().Set("Content-Type", "application/json")
106
			json.NewEncoder(w).Encode(antResp)
107
			return
108
		}
109
110
		oaiReq := convertRequest(&req)
111
		oaiBody, _ := json.Marshal(oaiReq)
112
113
		upReq, _ := http.NewRequestWithContext(r.Context(), "POST",
114
			upstream+"/chat/completions", bytes.NewReader(oaiBody))
115
		upReq.Header.Set("Content-Type", "application/json")
116
		token, err := tokenFunc()
117
		if err != nil {
118
			plog().Error("getting token", "error", err)
119
			httpError(w, http.StatusBadGateway, "token: "+err.Error())
120
			return
121
		}
122
		upReq.Header.Set("Authorization", "Bearer "+token)
123
		for k, v := range extraHeaders {
124
			upReq.Header.Set(k, v)
125
		}
126
127
		plog().Info("upstream request", "url", upstream+"/chat/completions")
128
		rawLogRequest(upReq, oaiBody, "OUTGOING REQUEST to upstream API")
129
130
		callStart := time.Now()
131
		resp, err := client.Do(upReq)
132
		if err != nil {
133
			plog().Error("upstream request failed", "error", err)
134
			DefaultMetrics.Record(CallRecord{
135
				Provider:       providerName,
136
				UpstreamDomain: domain,
137
				Endpoint:       r.URL.Path,
138
				Model:          req.Model,
139
				Stream:         req.Stream,
140
				Duration:       time.Since(callStart),
141
				Error:          true,
142
			})
143
			httpError(w, http.StatusBadGateway, "upstream: "+err.Error())
144
			return
145
		}
146
		defer resp.Body.Close()
147
148
		plog().Info("upstream response", "status", resp.StatusCode)
149
150
		// Read the response body for logging.
151
		respBody, err := io.ReadAll(resp.Body)
152
		if err != nil {
153
			plog().Error("reading response body", "error", err)
154
		}
155
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
156
157
		rawLogResponse(resp, respBody, "UPSTREAM RESPONSE")
158
159
		if resp.StatusCode != http.StatusOK {
160
			plog().Error("upstream error", "status", resp.Status, "body", string(respBody))
161
			DefaultMetrics.Record(CallRecord{
162
				Provider:       providerName,
163
				UpstreamDomain: domain,
164
				Endpoint:       r.URL.Path,
165
				Model:          req.Model,
166
				StatusCode:     resp.StatusCode,
167
				Stream:         req.Stream,
168
				Duration:       time.Since(callStart),
169
				Error:          true,
170
			})
171
			// Translate to Anthropic error format so Claude Code understands it.
172
			w.Header().Set("Content-Type", "application/json")
173
			w.WriteHeader(resp.StatusCode)
174
			json.NewEncoder(w).Encode(map[string]any{
175
				"type": "error",
176
				"error": map[string]any{
177
					"type":    "api_error",
178
					"message": fmt.Sprintf("upstream %s: %s", resp.Status, respBody),
179
				},
180
			})
181
			return
182
		}
183
184
		plog().Info("processing response", "model", req.Model, "stream", req.Stream)
185
186
		rec := CallRecord{
187
			Provider:       providerName,
188
			UpstreamDomain: domain,
189
			Endpoint:       r.URL.Path,
190
			Model:          req.Model,
191
			StatusCode:     resp.StatusCode,
192
			Stream:         req.Stream,
193
		}
194
195
		if req.Stream {
196
			usage := handleStream(w, resp, req.Model)
197
			if usage != nil {
198
				rec.InputTokens = usage.PromptTokens
199
				rec.OutputTokens = usage.CompletionTokens
200
				if usage.PromptTokensDetails != nil {
201
					rec.CacheRead = usage.PromptTokensDetails.CachedTokens
202
				}
203
			}
204
		} else {
205
			usage := handleNonStream(w, resp)
206
			if usage != nil {
207
				rec.InputTokens = usage.PromptTokens
208
				rec.OutputTokens = usage.CompletionTokens
209
				if usage.PromptTokensDetails != nil {
210
					rec.CacheRead = usage.PromptTokensDetails.CachedTokens
211
				}
212
			}
213
		}
214
215
		rec.Duration = time.Since(callStart)
216
		DefaultMetrics.Record(rec)
217
	}
218
}
219
220
func handleNonStream(w http.ResponseWriter, resp *http.Response) *openaiUsage {
221
	var oaiResp openaiResponse
222
	if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
223
		plog().Error("decoding upstream response", "error", err)
224
		httpError(w, http.StatusBadGateway, "decoding upstream: "+err.Error())
225
		return nil
226
	}
227
	plog().Info("non-stream response", "input_tokens", oaiResp.Usage.PromptTokens, "output_tokens", oaiResp.Usage.CompletionTokens)
228
229
	antResp := convertResponse(&oaiResp)
230
	antBody, _ := json.Marshal(antResp)
231
232
	rawLogResponse(&http.Response{
233
		Status:     "200 OK",
234
		StatusCode: 200,
235
		Header:     http.Header{"Content-Type": []string{"application/json"}},
236
	}, antBody, "RESPONSE TO CLAUDE CODE")
237
238
	w.Header().Set("Content-Type", "application/json")
239
	json.NewEncoder(w).Encode(antResp)
240
	return oaiResp.Usage
241
}
242
243
func handleStream(w http.ResponseWriter, resp *http.Response, model string) *openaiUsage {
244
	flusher, ok := w.(http.Flusher)
245
	if !ok {
246
		httpError(w, http.StatusInternalServerError, "streaming not supported")
247
		return nil
248
	}
249
250
	w.Header().Set("Content-Type", "text/event-stream")
251
	w.Header().Set("Cache-Control", "no-cache")
252
	w.Header().Set("Connection", "keep-alive")
253
254
	sw := &streamWriter{
255
		w:          w,
256
		flusher:    flusher,
257
		msgID:      "msg_" + genHex(12),
258
		model:      model,
259
		toolStarts: make(map[int]int),
260
	}
261
262
	plog().Info("stream start", "model", model)
263
264
	var streamLog strings.Builder
265
	scanner := bufio.NewScanner(resp.Body)
266
	scanner.Buffer(make([]byte, 0, 256*1024), 256*1024)
267
268
	chunkCount := 0
269
	for scanner.Scan() {
270
		line := scanner.Text()
271
		streamLog.WriteString(line + "\n")
272
		if !strings.HasPrefix(line, "data: ") {
273
			continue
274
		}
275
		data := strings.TrimPrefix(line, "data: ")
276
		if data == "[DONE]" {
277
			plog().Debug("stream DONE signal")
278
			streamLog.WriteString("\n[END OF STREAM]\n")
279
			break
280
		}
281
		var chunk openaiStreamChunk
282
		if json.Unmarshal([]byte(data), &chunk) != nil {
283
			plog().Error("unmarshal stream chunk", "data", data[:min(len(data), 200)])
284
			continue
285
		}
286
		sw.processChunk(chunk)
287
		chunkCount++
288
	}
289
290
	if err := scanner.Err(); err != nil {
291
		plog().Error("scanner error", "error", err)
292
	}
293
294
	rawLogStream("STREAM RESPONSE FROM UPSTREAM", streamLog.String(), chunkCount)
295
296
	plog().Info("stream finished", "chunks", chunkCount)
297
	sw.finish()
298
	return sw.usage
299
}
300
301
// streamWriter translates OpenAI SSE chunks into Anthropic SSE events.
302
type streamWriter struct {
303
	w          http.ResponseWriter
304
	flusher    http.Flusher
305
	msgID      string
306
	model      string
307
	started    bool
308
	blockIndex int
309
	inText     bool
310
	inTool     bool
311
	toolStarts map[int]int // openai tool index → anthropic block index
312
	usage      *openaiUsage
313
}
314
315
func (s *streamWriter) processChunk(chunk openaiStreamChunk) {
316
	if chunk.Usage != nil {
317
		s.usage = chunk.Usage
318
	}
319
	if len(chunk.Choices) == 0 {
320
		return
321
	}
322
323
	if !s.started {
324
		s.emitMessageStart()
325
		s.started = true
326
	}
327
328
	delta := chunk.Choices[0].Delta
329
330
	// Text content.
331
	if delta.Content != nil && *delta.Content != "" {
332
		if !s.inText {
333
			s.closeCurrentBlock()
334
			s.emitContentBlockStart("text", "", "")
335
			s.inText = true
336
		}
337
		s.emit("content_block_delta", fmt.Sprintf(
338
			`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":%s}}`,
339
			s.blockIndex, mustJSON(*delta.Content)))
340
	}
341
342
	// Tool calls.
343
	for _, tc := range delta.ToolCalls {
344
		if tc.ID != "" {
345
			// New tool call starting.
346
			s.closeCurrentBlock()
347
			s.emitContentBlockStart("tool_use", tc.ID, tc.Function.Name)
348
			s.toolStarts[tc.Index] = s.blockIndex
349
			s.inTool = true
350
		}
351
		if tc.Function != nil && tc.Function.Arguments != "" {
352
			idx, ok := s.toolStarts[tc.Index]
353
			if !ok {
354
				idx = s.blockIndex
355
			}
356
			s.emit("content_block_delta", fmt.Sprintf(
357
				`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":%s}}`,
358
				idx, mustJSON(tc.Function.Arguments)))
359
		}
360
	}
361
362
	// Finish reason.
363
	if chunk.Choices[0].FinishReason != nil {
364
		s.closeCurrentBlock()
365
	}
366
}
367
368
func (s *streamWriter) finish() {
369
	if !s.started {
370
		s.emitMessageStart()
371
	}
372
	s.closeCurrentBlock()
373
374
	stopReason := "end_turn"
375
	outTokens := 0
376
	cacheRead := 0
377
	if s.usage != nil {
378
		outTokens = s.usage.CompletionTokens
379
		if s.usage.PromptTokensDetails != nil {
380
			cacheRead = s.usage.PromptTokensDetails.CachedTokens
381
		}
382
	}
383
384
	usage := fmt.Sprintf(`"output_tokens":%d`, outTokens)
385
	if cacheRead > 0 {
386
		usage += fmt.Sprintf(`,"cache_read_input_tokens":%d`, cacheRead)
387
	}
388
	s.emit("message_delta", fmt.Sprintf(
389
		`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{%s}}`,
390
		stopReason, usage))
391
	s.emit("message_stop", `{"type":"message_stop"}`)
392
}
393
394
func (s *streamWriter) emitMessageStart() {
395
	inTokens := 0
396
	cacheRead := 0
397
	if s.usage != nil {
398
		inTokens = s.usage.PromptTokens
399
		if s.usage.PromptTokensDetails != nil {
400
			cacheRead = s.usage.PromptTokensDetails.CachedTokens
401
		}
402
	}
403
	usage := fmt.Sprintf(`"input_tokens":%d,"output_tokens":0`, inTokens)
404
	if cacheRead > 0 {
405
		usage += fmt.Sprintf(`,"cache_read_input_tokens":%d`, cacheRead)
406
	}
407
	s.emit("message_start", fmt.Sprintf(
408
		`{"type":"message_start","message":{"id":"%s","type":"message","role":"assistant","content":[],"model":"%s","stop_reason":null,"stop_sequence":null,"usage":{%s}}}`,
409
		s.msgID, s.model, usage))
410
}
411
412
func (s *streamWriter) emitContentBlockStart(typ, id, name string) {
413
	switch typ {
414
	case "text":
415
		s.emit("content_block_start", fmt.Sprintf(
416
			`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`,
417
			s.blockIndex))
418
	case "tool_use":
419
		s.emit("content_block_start", fmt.Sprintf(
420
			`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"%s","name":"%s","input":{}}}`,
421
			s.blockIndex, id, name))
422
	}
423
}
424
425
func (s *streamWriter) closeCurrentBlock() {
426
	if s.inText || s.inTool {
427
		s.emit("content_block_stop", fmt.Sprintf(
428
			`{"type":"content_block_stop","index":%d}`, s.blockIndex))
429
		s.blockIndex++
430
		s.inText = false
431
		s.inTool = false
432
	}
433
}
434
435
// emit writes a single server-sent event with the given event name and data
// payload to the response, then flushes it. Write errors are not checked.
func (s *streamWriter) emit(event, data string) {
	frame := "event: " + event + "\ndata: " + data + "\n\n"
	_, _ = io.WriteString(s.w, frame)
	s.flusher.Flush()
}
439
440
// mustJSON returns s encoded as a JSON string literal, including the
// surrounding quotes. If encoding fails, the empty string is returned.
func mustJSON(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		return ""
	}
	return string(encoded)
}
444
445
// StaticToken wraps a fixed key in a token function: every call yields key
// and a nil error.
func StaticToken(key string) func() (string, error) {
	fixed := func() (string, error) {
		return key, nil
	}
	return fixed
}
449
450
// StartPassthrough launches a logging-only proxy that forwards requests unchanged.
451
// Used for Anthropic-compatible providers to enable logging without translation.
452
func StartPassthrough(upstreamURL string, tokenFunc func() (string, error), proxyURL string, providerName string) (addr string, shutdown func(), err error) {
453
	upstream := strings.TrimRight(upstreamURL, "/")
454
	domain := extractDomain(upstream)
455
456
	transport := http.DefaultTransport.(*http.Transport).Clone()
457
	if proxyURL != "" {
458
		if u, err := url.Parse(proxyURL); err == nil {
459
			transport.Proxy = http.ProxyURL(u)
460
		}
461
	} else {
462
		transport.Proxy = nil
463
	}
464
	client := &http.Client{Transport: transport}
465
466
	handler := func(w http.ResponseWriter, r *http.Request) {
467
		plog().Info("passthrough request", "method", r.Method, "path", r.URL.Path, "from", r.RemoteAddr)
468
469
		body, err := io.ReadAll(r.Body)
470
		if err != nil {
471
			plog().Error("reading body", "error", err)
472
			httpError(w, http.StatusBadRequest, "reading body: "+err.Error())
473
			return
474
		}
475
		r.Body = io.NopCloser(bytes.NewReader(body))
476
477
		rawLogRequest(r, body, "INCOMING REQUEST (passthrough)")
478
479
		targetURL := upstream + r.URL.Path
480
		if r.URL.RawQuery != "" {
481
			targetURL += "?" + r.URL.RawQuery
482
		}
483
484
		upReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, bytes.NewReader(body))
485
		if err != nil {
486
			plog().Error("creating upstream request", "error", err)
487
			httpError(w, http.StatusInternalServerError, "creating request: "+err.Error())
488
			return
489
		}
490
491
		maps.Copy(upReq.Header, r.Header)
492
493
		if tokenFunc != nil {
494
			token, err := tokenFunc()
495
			if err != nil {
496
				plog().Error("getting token", "error", err)
497
				httpError(w, http.StatusBadGateway, "token: "+err.Error())
498
				return
499
			}
500
			if token != "" {
501
				upReq.Header.Set("Authorization", "Bearer "+token)
502
			}
503
		}
504
505
		plog().Info("upstream request", "method", r.Method, "url", targetURL)
506
		rawLogRequest(upReq, body, "OUTGOING REQUEST (passthrough)")
507
508
		callStart := time.Now()
509
		resp, err := client.Do(upReq)
510
		if err != nil {
511
			plog().Error("upstream request failed", "error", err)
512
			DefaultMetrics.Record(CallRecord{
513
				Provider:       providerName,
514
				UpstreamDomain: domain,
515
				Endpoint:       r.URL.Path,
516
				Duration:       time.Since(callStart),
517
				Error:          true,
518
			})
519
			httpError(w, http.StatusBadGateway, "upstream: "+err.Error())
520
			return
521
		}
522
		defer resp.Body.Close()
523
524
		plog().Info("upstream response", "status", resp.StatusCode)
525
526
		respBody, err := io.ReadAll(resp.Body)
527
		if err != nil {
528
			plog().Error("reading response body", "error", err)
529
			httpError(w, http.StatusBadGateway, "reading response: "+err.Error())
530
			return
531
		}
532
533
		rawLogResponse(resp, respBody, "UPSTREAM RESPONSE (passthrough)")
534
535
		maps.Copy(w.Header(), resp.Header)
536
		w.WriteHeader(resp.StatusCode)
537
		w.Write(respBody)
538
539
		rec := CallRecord{
540
			Provider:       providerName,
541
			UpstreamDomain: domain,
542
			Endpoint:       r.URL.Path,
543
			StatusCode:     resp.StatusCode,
544
			Duration:       time.Since(callStart),
545
			Error:          resp.StatusCode != http.StatusOK,
546
		}
547
548
		// Best-effort token extraction from Anthropic response.
549
		if resp.StatusCode == http.StatusOK {
550
			var usage struct {
551
				Usage struct {
552
					InputTokens              int `json:"input_tokens"`
553
					OutputTokens             int `json:"output_tokens"`
554
					CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
555
					CacheReadInputTokens     int `json:"cache_read_input_tokens"`
556
				} `json:"usage"`
557
				Model string `json:"model"`
558
			}
559
			if json.Unmarshal(respBody, &usage) == nil {
560
				rec.InputTokens = usage.Usage.InputTokens
561
				rec.OutputTokens = usage.Usage.OutputTokens
562
				rec.CacheCreation = usage.Usage.CacheCreationInputTokens
563
				rec.CacheRead = usage.Usage.CacheReadInputTokens
564
				rec.Model = usage.Model
565
			}
566
		}
567
568
		DefaultMetrics.Record(rec)
569
		plog().Info("passthrough complete")
570
	}
571
572
	mux := http.NewServeMux()
573
	mux.HandleFunc("/v1/messages", handler)
574
	mux.HandleFunc("/v1/messages/count_tokens", countTokensHandler())
575
576
	ln, err := net.Listen("tcp", "127.0.0.1:0")
577
	if err != nil {
578
		return "", nil, err
579
	}
580
581
	srv := &http.Server{Handler: mux}
582
	go func() {
583
		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
584
			plog().Error("passthrough proxy serve failed", "error", err)
585
		}
586
	}()
587
588
	return ln.Addr().String(), func() { srv.Close() }, nil
589
}
590
591
// countTokensHandler serves /v1/messages/count_tokens with a fixed JSON body
// reporting zero input tokens. Claude Code calls this endpoint; leaving it
// unregistered would make the proxy answer 404, which may surface as
// "unable to connect to api" errors.
func countTokensHandler() http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		payload := map[string]any{"input_tokens": 0}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	})
}
602
603
// isQuotaProbe detects Claude Code's quota-check requests (max_tokens=1 with
604
// short messages) and returns a mock response to avoid hitting the upstream.
605
func isQuotaProbe(req *anthropicRequest) bool {
606
	return req.MaxTokens == 1
607
}
608
609
// httpError writes an Anthropic-style JSON error envelope to w with the given
// HTTP status code. The inner error has type "proxy_error" and carries msg as
// its message.
func httpError(w http.ResponseWriter, code int, msg string) {
	detail := map[string]any{
		"type":    "proxy_error",
		"message": msg,
	}
	envelope := map[string]any{
		"type":  "error",
		"error": detail,
	}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(code)
	json.NewEncoder(w).Encode(envelope)
}
620

Source Files