server.go

v0.5.1
Doc Versions Source
1
package proxy
2
3
import (
4
	"bufio"
5
	"bytes"
6
	"encoding/json"
7
	"fmt"
8
	"io"
9
	"maps"
10
	"net"
11
	"net/http"
12
	"net/url"
13
	"strings"
14
	"time"
15
)
16
17
// extractDomain returns the host part (without port) of rawURL.
//
// Scheme-less inputs such as "api.openai.com/v1" are parsed by net/url as a
// bare path with an empty host, and inputs such as "127.0.0.1:8080" are
// rejected outright. Both are retried with a "//" prefix, which forces
// net/url to treat the leading component as the authority. When no host can
// be determined either way, rawURL is returned unchanged.
func extractDomain(rawURL string) string {
	if u, err := url.Parse(rawURL); err == nil && u.Hostname() != "" {
		return u.Hostname()
	}
	if u, err := url.Parse("//" + rawURL); err == nil && u.Hostname() != "" {
		return u.Hostname()
	}
	return rawURL
}
23
24
// Start launches a local Anthropic-to-OpenAI translation proxy.
25
// upstreamURL is the OpenAI-compatible base URL (e.g. "https://api.openai.com/v1").
26
// Returns the proxy address (host:port), a shutdown function, and any error.
27
func Start(upstreamURL string, tokenFunc func() (string, error), extraHeaders map[string]string, proxyURL string, providerName string) (addr string, shutdown func(), err error) {
28
	upstream := strings.TrimRight(upstreamURL, "/")
29
	domain := extractDomain(upstream)
30
31
	mux := http.NewServeMux()
32
	mux.HandleFunc("/v1/messages", handler(upstream, tokenFunc, extraHeaders, proxyURL, providerName, domain))
33
	mux.HandleFunc("/v1/messages/count_tokens", countTokensHandler())
34
35
	ln, err := net.Listen("tcp", "127.0.0.1:0")
36
	if err != nil {
37
		return "", nil, err
38
	}
39
40
	srv := &http.Server{Handler: mux}
41
	go func() {
42
		if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
43
			plog().Error("proxy serve failed", "error", err)
44
		}
45
	}()
46
47
	return ln.Addr().String(), func() { srv.Close() }, nil
48
}
49
50
func handler(upstream string, tokenFunc func() (string, error), extraHeaders map[string]string, proxyURL string, providerName, domain string) http.HandlerFunc {
51
	transport := http.DefaultTransport.(*http.Transport).Clone()
52
	if proxyURL != "" {
53
		if u, err := url.Parse(proxyURL); err == nil {
54
			transport.Proxy = http.ProxyURL(u)
55
		}
56
	} else {
57
		transport.Proxy = nil
58
	}
59
	client := &http.Client{Transport: transport}
60
61
	return func(w http.ResponseWriter, r *http.Request) {
62
		plog().Info("request", "method", r.Method, "path", r.URL.Path, "from", r.RemoteAddr)
63
64
		body, err := io.ReadAll(r.Body)
65
		if err != nil {
66
			plog().Error("reading body", "error", err)
67
			httpError(w, http.StatusBadRequest, "reading body: "+err.Error())
68
			return
69
		}
70
		r.Body = io.NopCloser(bytes.NewReader(body))
71
72
		rawLogRequest(r, body, "INCOMING REQUEST from Claude Code")
73
74
		var req anthropicRequest
75
		if err := json.Unmarshal(body, &req); err != nil {
76
			plog().Error("parsing request", "error", err)
77
			httpError(w, http.StatusBadRequest, "parsing request: "+err.Error())
78
			return
79
		}
80
81
		plog().Info("request details", "model", req.Model, "stream", req.Stream, "max_tokens", req.MaxTokens, "messages", len(req.Messages))
82
83
		// Intercept quota-probe requests (max_tokens=1) — return a mock
84
		// response instead of hitting the upstream API.
85
		if isQuotaProbe(&req) {
86
			plog().Info("intercepting quota probe")
87
			DefaultMetrics.Record(CallRecord{
88
				Provider:       providerName,
89
				UpstreamDomain: domain,
90
				Endpoint:       r.URL.Path,
91
				Model:          req.Model,
92
				QuotaProbe:     true,
93
			})
94
			antResp := &anthropicResponse{
95
				ID:      "msg_" + genHex(12),
96
				Type:    "message",
97
				Role:    "assistant",
98
				Content: []contentBlock{{Type: "text", Text: "ok"}},
99
				Model:   req.Model,
100
				Usage:   anthropicUsage{InputTokens: 1, OutputTokens: 1},
101
			}
102
			stopReason := "end_turn"
103
			antResp.StopReason = &stopReason
104
			w.Header().Set("Content-Type", "application/json")
105
			json.NewEncoder(w).Encode(antResp)
106
			return
107
		}
108
109
		oaiReq := convertRequest(&req)
110
		oaiBody, _ := json.Marshal(oaiReq)
111
112
		upReq, _ := http.NewRequestWithContext(r.Context(), "POST",
113
			upstream+"/chat/completions", bytes.NewReader(oaiBody))
114
		upReq.Header.Set("Content-Type", "application/json")
115
		token, err := tokenFunc()
116
		if err != nil {
117
			plog().Error("getting token", "error", err)
118
			httpError(w, http.StatusBadGateway, "token: "+err.Error())
119
			return
120
		}
121
		upReq.Header.Set("Authorization", "Bearer "+token)
122
		for k, v := range extraHeaders {
123
			upReq.Header.Set(k, v)
124
		}
125
126
		plog().Info("upstream request", "url", upstream+"/chat/completions")
127
		rawLogRequest(upReq, oaiBody, "OUTGOING REQUEST to upstream API")
128
129
		callStart := time.Now()
130
		resp, err := client.Do(upReq)
131
		if err != nil {
132
			plog().Error("upstream request failed", "error", err)
133
			DefaultMetrics.Record(CallRecord{
134
				Provider:       providerName,
135
				UpstreamDomain: domain,
136
				Endpoint:       r.URL.Path,
137
				Model:          req.Model,
138
				Stream:         req.Stream,
139
				Duration:       time.Since(callStart),
140
				Error:          true,
141
			})
142
			httpError(w, http.StatusBadGateway, "upstream: "+err.Error())
143
			return
144
		}
145
		defer resp.Body.Close()
146
147
		plog().Info("upstream response", "status", resp.StatusCode)
148
149
		// Read the response body for logging.
150
		respBody, err := io.ReadAll(resp.Body)
151
		if err != nil {
152
			plog().Error("reading response body", "error", err)
153
		}
154
		resp.Body = io.NopCloser(bytes.NewReader(respBody))
155
156
		rawLogResponse(resp, respBody, "UPSTREAM RESPONSE")
157
158
		if resp.StatusCode != http.StatusOK {
159
			plog().Error("upstream error", "status", resp.Status, "body", string(respBody))
160
			DefaultMetrics.Record(CallRecord{
161
				Provider:       providerName,
162
				UpstreamDomain: domain,
163
				Endpoint:       r.URL.Path,
164
				Model:          req.Model,
165
				StatusCode:     resp.StatusCode,
166
				Stream:         req.Stream,
167
				Duration:       time.Since(callStart),
168
				Error:          true,
169
			})
170
			// Translate to Anthropic error format so Claude Code understands it.
171
			w.Header().Set("Content-Type", "application/json")
172
			w.WriteHeader(resp.StatusCode)
173
			json.NewEncoder(w).Encode(map[string]any{
174
				"type": "error",
175
				"error": map[string]any{
176
					"type":    "api_error",
177
					"message": fmt.Sprintf("upstream %s: %s", resp.Status, respBody),
178
				},
179
			})
180
			return
181
		}
182
183
		plog().Info("processing response", "model", req.Model, "stream", req.Stream)
184
185
		rec := CallRecord{
186
			Provider:       providerName,
187
			UpstreamDomain: domain,
188
			Endpoint:       r.URL.Path,
189
			Model:          req.Model,
190
			StatusCode:     resp.StatusCode,
191
			Stream:         req.Stream,
192
		}
193
194
		if req.Stream {
195
			usage := handleStream(w, resp, req.Model)
196
			if usage != nil {
197
				rec.InputTokens = usage.PromptTokens
198
				rec.OutputTokens = usage.CompletionTokens
199
			}
200
		} else {
201
			usage := handleNonStream(w, resp)
202
			if usage != nil {
203
				rec.InputTokens = usage.PromptTokens
204
				rec.OutputTokens = usage.CompletionTokens
205
			}
206
		}
207
208
		rec.Duration = time.Since(callStart)
209
		DefaultMetrics.Record(rec)
210
	}
211
}
212
213
func handleNonStream(w http.ResponseWriter, resp *http.Response) *openaiUsage {
214
	var oaiResp openaiResponse
215
	if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
216
		plog().Error("decoding upstream response", "error", err)
217
		httpError(w, http.StatusBadGateway, "decoding upstream: "+err.Error())
218
		return nil
219
	}
220
	plog().Info("non-stream response", "input_tokens", oaiResp.Usage.PromptTokens, "output_tokens", oaiResp.Usage.CompletionTokens)
221
222
	antResp := convertResponse(&oaiResp)
223
	antBody, _ := json.Marshal(antResp)
224
225
	rawLogResponse(&http.Response{
226
		Status:     "200 OK",
227
		StatusCode: 200,
228
		Header:     http.Header{"Content-Type": []string{"application/json"}},
229
	}, antBody, "RESPONSE TO CLAUDE CODE")
230
231
	w.Header().Set("Content-Type", "application/json")
232
	json.NewEncoder(w).Encode(antResp)
233
	return oaiResp.Usage
234
}
235
236
func handleStream(w http.ResponseWriter, resp *http.Response, model string) *openaiUsage {
237
	flusher, ok := w.(http.Flusher)
238
	if !ok {
239
		httpError(w, http.StatusInternalServerError, "streaming not supported")
240
		return nil
241
	}
242
243
	w.Header().Set("Content-Type", "text/event-stream")
244
	w.Header().Set("Cache-Control", "no-cache")
245
	w.Header().Set("Connection", "keep-alive")
246
247
	sw := &streamWriter{
248
		w:          w,
249
		flusher:    flusher,
250
		msgID:      "msg_" + genHex(12),
251
		model:      model,
252
		toolStarts: make(map[int]int),
253
	}
254
255
	plog().Info("stream start", "model", model)
256
257
	var streamLog strings.Builder
258
	scanner := bufio.NewScanner(resp.Body)
259
	scanner.Buffer(make([]byte, 0, 256*1024), 256*1024)
260
261
	chunkCount := 0
262
	for scanner.Scan() {
263
		line := scanner.Text()
264
		streamLog.WriteString(line + "\n")
265
		if !strings.HasPrefix(line, "data: ") {
266
			continue
267
		}
268
		data := strings.TrimPrefix(line, "data: ")
269
		if data == "[DONE]" {
270
			plog().Debug("stream DONE signal")
271
			streamLog.WriteString("\n[END OF STREAM]\n")
272
			break
273
		}
274
		var chunk openaiStreamChunk
275
		if json.Unmarshal([]byte(data), &chunk) != nil {
276
			plog().Error("unmarshal stream chunk", "data", data[:min(len(data), 200)])
277
			continue
278
		}
279
		sw.processChunk(chunk)
280
		chunkCount++
281
	}
282
283
	if err := scanner.Err(); err != nil {
284
		plog().Error("scanner error", "error", err)
285
	}
286
287
	rawLogStream("STREAM RESPONSE FROM UPSTREAM", streamLog.String(), chunkCount)
288
289
	plog().Info("stream finished", "chunks", chunkCount)
290
	sw.finish()
291
	return sw.usage
292
}
293
294
// streamWriter translates OpenAI SSE chunks into Anthropic SSE events.
295
type streamWriter struct {
296
	w          http.ResponseWriter
297
	flusher    http.Flusher
298
	msgID      string
299
	model      string
300
	started    bool
301
	blockIndex int
302
	inText     bool
303
	inTool     bool
304
	toolStarts map[int]int // openai tool index → anthropic block index
305
	usage      *openaiUsage
306
}
307
308
func (s *streamWriter) processChunk(chunk openaiStreamChunk) {
309
	if chunk.Usage != nil {
310
		s.usage = chunk.Usage
311
	}
312
	if len(chunk.Choices) == 0 {
313
		return
314
	}
315
316
	if !s.started {
317
		s.emitMessageStart()
318
		s.started = true
319
	}
320
321
	delta := chunk.Choices[0].Delta
322
323
	// Text content.
324
	if delta.Content != nil && *delta.Content != "" {
325
		if !s.inText {
326
			s.closeCurrentBlock()
327
			s.emitContentBlockStart("text", "", "")
328
			s.inText = true
329
		}
330
		s.emit("content_block_delta", fmt.Sprintf(
331
			`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":%s}}`,
332
			s.blockIndex, mustJSON(*delta.Content)))
333
	}
334
335
	// Tool calls.
336
	for _, tc := range delta.ToolCalls {
337
		if tc.ID != "" {
338
			// New tool call starting.
339
			s.closeCurrentBlock()
340
			s.emitContentBlockStart("tool_use", tc.ID, tc.Function.Name)
341
			s.toolStarts[tc.Index] = s.blockIndex
342
			s.inTool = true
343
		}
344
		if tc.Function != nil && tc.Function.Arguments != "" {
345
			idx, ok := s.toolStarts[tc.Index]
346
			if !ok {
347
				idx = s.blockIndex
348
			}
349
			s.emit("content_block_delta", fmt.Sprintf(
350
				`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":%s}}`,
351
				idx, mustJSON(tc.Function.Arguments)))
352
		}
353
	}
354
355
	// Finish reason.
356
	if chunk.Choices[0].FinishReason != nil {
357
		s.closeCurrentBlock()
358
	}
359
}
360
361
func (s *streamWriter) finish() {
362
	if !s.started {
363
		s.emitMessageStart()
364
	}
365
	s.closeCurrentBlock()
366
367
	stopReason := "end_turn"
368
	outTokens := 0
369
	if s.usage != nil {
370
		outTokens = s.usage.CompletionTokens
371
	}
372
373
	s.emit("message_delta", fmt.Sprintf(
374
		`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"output_tokens":%d}}`,
375
		stopReason, outTokens))
376
	s.emit("message_stop", `{"type":"message_stop"}`)
377
}
378
379
func (s *streamWriter) emitMessageStart() {
380
	inTokens := 0
381
	if s.usage != nil {
382
		inTokens = s.usage.PromptTokens
383
	}
384
	s.emit("message_start", fmt.Sprintf(
385
		`{"type":"message_start","message":{"id":"%s","type":"message","role":"assistant","content":[],"model":"%s","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":%d,"output_tokens":0}}}`,
386
		s.msgID, s.model, inTokens))
387
}
388
389
func (s *streamWriter) emitContentBlockStart(typ, id, name string) {
390
	switch typ {
391
	case "text":
392
		s.emit("content_block_start", fmt.Sprintf(
393
			`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`,
394
			s.blockIndex))
395
	case "tool_use":
396
		s.emit("content_block_start", fmt.Sprintf(
397
			`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"%s","name":"%s","input":{}}}`,
398
			s.blockIndex, id, name))
399
	}
400
}
401
402
func (s *streamWriter) closeCurrentBlock() {
403
	if s.inText || s.inTool {
404
		s.emit("content_block_stop", fmt.Sprintf(
405
			`{"type":"content_block_stop","index":%d}`, s.blockIndex))
406
		s.blockIndex++
407
		s.inText = false
408
		s.inTool = false
409
	}
410
}
411
412
// emit writes a single SSE frame (an "event:" line, a "data:" line, and a
// blank line) to the response and flushes it right away. Write errors are
// not reported.
func (s *streamWriter) emit(event, data string) {
	frame := fmt.Sprintf("event: %s\ndata: %s\n\n", event, data)
	_, _ = s.w.Write([]byte(frame))
	s.flusher.Flush()
}
416
417
// mustJSON returns s encoded as a JSON string literal, surrounding quotes
// included. The marshal error is discarded.
func mustJSON(s string) string {
	encoded, _ := json.Marshal(s)
	literal := string(encoded)
	return literal
}
421
422
// StaticToken wraps a fixed key in a token function: every call to the
// returned function yields key and a nil error.
func StaticToken(key string) func() (string, error) {
	provider := func() (string, error) {
		return key, nil
	}
	return provider
}
426
427
// StartPassthrough launches a logging-only proxy that forwards requests unchanged.
428
// Used for Anthropic-compatible providers to enable logging without translation.
429
func StartPassthrough(upstreamURL string, tokenFunc func() (string, error), proxyURL string, providerName string) (addr string, shutdown func(), err error) {
430
	upstream := strings.TrimRight(upstreamURL, "/")
431
	domain := extractDomain(upstream)
432
433
	transport := http.DefaultTransport.(*http.Transport).Clone()
434
	if proxyURL != "" {
435
		if u, err := url.Parse(proxyURL); err == nil {
436
			transport.Proxy = http.ProxyURL(u)
437
		}
438
	} else {
439
		transport.Proxy = nil
440
	}
441
	client := &http.Client{Transport: transport}
442
443
	handler := func(w http.ResponseWriter, r *http.Request) {
444
		plog().Info("passthrough request", "method", r.Method, "path", r.URL.Path, "from", r.RemoteAddr)
445
446
		body, err := io.ReadAll(r.Body)
447
		if err != nil {
448
			plog().Error("reading body", "error", err)
449
			httpError(w, http.StatusBadRequest, "reading body: "+err.Error())
450
			return
451
		}
452
		r.Body = io.NopCloser(bytes.NewReader(body))
453
454
		rawLogRequest(r, body, "INCOMING REQUEST (passthrough)")
455
456
		targetURL := upstream + r.URL.Path
457
		if r.URL.RawQuery != "" {
458
			targetURL += "?" + r.URL.RawQuery
459
		}
460
461
		upReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, bytes.NewReader(body))
462
		if err != nil {
463
			plog().Error("creating upstream request", "error", err)
464
			httpError(w, http.StatusInternalServerError, "creating request: "+err.Error())
465
			return
466
		}
467
468
		maps.Copy(upReq.Header, r.Header)
469
470
		if tokenFunc != nil {
471
			token, err := tokenFunc()
472
			if err != nil {
473
				plog().Error("getting token", "error", err)
474
				httpError(w, http.StatusBadGateway, "token: "+err.Error())
475
				return
476
			}
477
			if token != "" {
478
				upReq.Header.Set("Authorization", "Bearer "+token)
479
			}
480
		}
481
482
		plog().Info("upstream request", "method", r.Method, "url", targetURL)
483
		rawLogRequest(upReq, body, "OUTGOING REQUEST (passthrough)")
484
485
		callStart := time.Now()
486
		resp, err := client.Do(upReq)
487
		if err != nil {
488
			plog().Error("upstream request failed", "error", err)
489
			DefaultMetrics.Record(CallRecord{
490
				Provider:       providerName,
491
				UpstreamDomain: domain,
492
				Endpoint:       r.URL.Path,
493
				Duration:       time.Since(callStart),
494
				Error:          true,
495
			})
496
			httpError(w, http.StatusBadGateway, "upstream: "+err.Error())
497
			return
498
		}
499
		defer resp.Body.Close()
500
501
		plog().Info("upstream response", "status", resp.StatusCode)
502
503
		respBody, err := io.ReadAll(resp.Body)
504
		if err != nil {
505
			plog().Error("reading response body", "error", err)
506
			httpError(w, http.StatusBadGateway, "reading response: "+err.Error())
507
			return
508
		}
509
510
		rawLogResponse(resp, respBody, "UPSTREAM RESPONSE (passthrough)")
511
512
		maps.Copy(w.Header(), resp.Header)
513
		w.WriteHeader(resp.StatusCode)
514
		w.Write(respBody)
515
516
		rec := CallRecord{
517
			Provider:       providerName,
518
			UpstreamDomain: domain,
519
			Endpoint:       r.URL.Path,
520
			StatusCode:     resp.StatusCode,
521
			Duration:       time.Since(callStart),
522
			Error:          resp.StatusCode != http.StatusOK,
523
		}
524
525
		// Best-effort token extraction from Anthropic response.
526
		if resp.StatusCode == http.StatusOK {
527
			var usage struct {
528
				Usage struct {
529
					InputTokens  int `json:"input_tokens"`
530
					OutputTokens int `json:"output_tokens"`
531
				} `json:"usage"`
532
				Model string `json:"model"`
533
			}
534
			if json.Unmarshal(respBody, &usage) == nil {
535
				rec.InputTokens = usage.Usage.InputTokens
536
				rec.OutputTokens = usage.Usage.OutputTokens
537
				rec.Model = usage.Model
538
			}
539
		}
540
541
		DefaultMetrics.Record(rec)
542
		plog().Info("passthrough complete")
543
	}
544
545
	mux := http.NewServeMux()
546
	mux.HandleFunc("/v1/messages", handler)
547
	mux.HandleFunc("/v1/messages/count_tokens", countTokensHandler())
548
549
	ln, err := net.Listen("tcp", "127.0.0.1:0")
550
	if err != nil {
551
		return "", nil, err
552
	}
553
554
	srv := &http.Server{Handler: mux}
555
	go func() {
556
		if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
557
			plog().Error("passthrough proxy serve failed", "error", err)
558
		}
559
	}()
560
561
	return ln.Addr().String(), func() { srv.Close() }, nil
562
}
563
564
// countTokensHandler serves /v1/messages/count_tokens with a fixed mock
// reply of {"input_tokens": 0}; the request itself is never inspected.
// Claude Code calls this endpoint; without it the proxy returns 404 which may
// cause "unable to connect to api" errors.
func countTokensHandler() http.HandlerFunc {
	reply := func(w http.ResponseWriter, _ *http.Request) {
		payload := map[string]any{"input_tokens": 0}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	}
	return reply
}
575
576
// isQuotaProbe detects Claude Code's quota-check requests (max_tokens=1 with
577
// short messages) and returns a mock response to avoid hitting the upstream.
578
func isQuotaProbe(req *anthropicRequest) bool {
579
	return req.MaxTokens == 1
580
}
581
582
// httpError replies with the given status code and a JSON error envelope of
// the form {"type":"error","error":{"type":"proxy_error","message":msg}}.
func httpError(w http.ResponseWriter, code int, msg string) {
	detail := map[string]any{
		"type":    "proxy_error",
		"message": msg,
	}
	envelope := map[string]any{
		"type":  "error",
		"error": detail,
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	json.NewEncoder(w).Encode(envelope)
}
593

Source Files