server.go

v0.3.0
Doc Versions Source
1
package proxy
2
3
import (
4
	"bufio"
5
	"bytes"
6
	"encoding/json"
7
	"fmt"
8
	"io"
9
	"log"
10
	"net"
11
	"net/http"
12
	"net/url"
13
	"strings"
14
)
15
16
// Start launches a local Anthropic-to-OpenAI translation proxy.
17
// upstreamURL is the OpenAI-compatible base URL (e.g. "https://api.openai.com/v1").
18
// Returns the proxy address (host:port), a shutdown function, and any error.
19
func Start(upstreamURL string, tokenFunc func() (string, error), extraHeaders map[string]string, proxyURL string) (addr string, shutdown func(), err error) {
20
	upstream := strings.TrimRight(upstreamURL, "/")
21
22
	mux := http.NewServeMux()
23
	mux.HandleFunc("/v1/messages", handler(upstream, tokenFunc, extraHeaders, proxyURL))
24
	mux.HandleFunc("/v1/messages/count_tokens", countTokensHandler())
25
26
	ln, err := net.Listen("tcp", "127.0.0.1:0")
27
	if err != nil {
28
		return "", nil, err
29
	}
30
31
	srv := &http.Server{Handler: mux}
32
	go func() {
33
		if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
34
			log.Printf("proxy: %v", err)
35
		}
36
	}()
37
38
	return ln.Addr().String(), func() { srv.Close() }, nil
39
}
40
41
func handler(upstream string, tokenFunc func() (string, error), extraHeaders map[string]string, proxyURL string) http.HandlerFunc {
42
	transport := http.DefaultTransport.(*http.Transport).Clone()
43
	if proxyURL != "" {
44
		if u, err := url.Parse(proxyURL); err == nil {
45
			transport.Proxy = http.ProxyURL(u)
46
		}
47
	} else {
48
		transport.Proxy = nil
49
	}
50
	client := &http.Client{Transport: transport}
51
52
	return func(w http.ResponseWriter, r *http.Request) {
53
		body, err := io.ReadAll(r.Body)
54
		if err != nil {
55
			httpError(w, http.StatusBadRequest, "reading body: "+err.Error())
56
			return
57
		}
58
59
		var req anthropicRequest
60
		if err := json.Unmarshal(body, &req); err != nil {
61
			httpError(w, http.StatusBadRequest, "parsing request: "+err.Error())
62
			return
63
		}
64
65
		// Intercept quota-probe requests (max_tokens=1) — return a mock
66
		// response instead of hitting the upstream API.
67
		if isQuotaProbe(&req) {
68
			antResp := &anthropicResponse{
69
				ID:      "msg_" + genHex(12),
70
				Type:    "message",
71
				Role:    "assistant",
72
				Content: []contentBlock{{Type: "text", Text: "ok"}},
73
				Model:   req.Model,
74
				Usage:   anthropicUsage{InputTokens: 1, OutputTokens: 1},
75
			}
76
			stopReason := "end_turn"
77
			antResp.StopReason = &stopReason
78
			w.Header().Set("Content-Type", "application/json")
79
			json.NewEncoder(w).Encode(antResp)
80
			return
81
		}
82
83
		oaiReq := convertRequest(&req)
84
		oaiBody, _ := json.Marshal(oaiReq)
85
86
		upReq, _ := http.NewRequestWithContext(r.Context(), "POST",
87
			upstream+"/chat/completions", bytes.NewReader(oaiBody))
88
		upReq.Header.Set("Content-Type", "application/json")
89
		token, err := tokenFunc()
90
		if err != nil {
91
			httpError(w, http.StatusBadGateway, "token: "+err.Error())
92
			return
93
		}
94
		upReq.Header.Set("Authorization", "Bearer "+token)
95
		for k, v := range extraHeaders {
96
			upReq.Header.Set(k, v)
97
		}
98
99
		resp, err := client.Do(upReq)
100
		if err != nil {
101
			httpError(w, http.StatusBadGateway, "upstream: "+err.Error())
102
			return
103
		}
104
		defer resp.Body.Close()
105
106
		if resp.StatusCode != http.StatusOK {
107
			errBody, _ := io.ReadAll(resp.Body)
108
			log.Printf("proxy: upstream %s: %s", resp.Status, errBody)
109
			// Translate to Anthropic error format so Claude Code understands it.
110
			w.Header().Set("Content-Type", "application/json")
111
			w.WriteHeader(resp.StatusCode)
112
			json.NewEncoder(w).Encode(map[string]any{
113
				"type": "error",
114
				"error": map[string]any{
115
					"type":    "api_error",
116
					"message": fmt.Sprintf("upstream %s: %s", resp.Status, errBody),
117
				},
118
			})
119
			return
120
		}
121
122
		if req.Stream {
123
			handleStream(w, resp, req.Model)
124
		} else {
125
			handleNonStream(w, resp)
126
		}
127
	}
128
}
129
130
func handleNonStream(w http.ResponseWriter, resp *http.Response) {
131
	var oaiResp openaiResponse
132
	if err := json.NewDecoder(resp.Body).Decode(&oaiResp); err != nil {
133
		httpError(w, http.StatusBadGateway, "decoding upstream: "+err.Error())
134
		return
135
	}
136
	antResp := convertResponse(&oaiResp)
137
	w.Header().Set("Content-Type", "application/json")
138
	json.NewEncoder(w).Encode(antResp)
139
}
140
141
func handleStream(w http.ResponseWriter, resp *http.Response, model string) {
142
	flusher, ok := w.(http.Flusher)
143
	if !ok {
144
		httpError(w, http.StatusInternalServerError, "streaming not supported")
145
		return
146
	}
147
148
	w.Header().Set("Content-Type", "text/event-stream")
149
	w.Header().Set("Cache-Control", "no-cache")
150
	w.Header().Set("Connection", "keep-alive")
151
152
	sw := &streamWriter{
153
		w:          w,
154
		flusher:    flusher,
155
		msgID:      "msg_" + genHex(12),
156
		model:      model,
157
		toolStarts: make(map[int]int),
158
	}
159
160
	scanner := bufio.NewScanner(resp.Body)
161
	scanner.Buffer(make([]byte, 0, 256*1024), 256*1024)
162
163
	for scanner.Scan() {
164
		line := scanner.Text()
165
		if !strings.HasPrefix(line, "data: ") {
166
			continue
167
		}
168
		data := strings.TrimPrefix(line, "data: ")
169
		if data == "[DONE]" {
170
			break
171
		}
172
		var chunk openaiStreamChunk
173
		if json.Unmarshal([]byte(data), &chunk) != nil {
174
			continue
175
		}
176
		sw.processChunk(chunk)
177
	}
178
179
	sw.finish()
180
}
181
182
// streamWriter translates OpenAI SSE chunks into Anthropic SSE events.
183
type streamWriter struct {
184
	w          http.ResponseWriter
185
	flusher    http.Flusher
186
	msgID      string
187
	model      string
188
	started    bool
189
	blockIndex int
190
	inText     bool
191
	inTool     bool
192
	toolStarts map[int]int // openai tool index → anthropic block index
193
	usage      *openaiUsage
194
}
195
196
func (s *streamWriter) processChunk(chunk openaiStreamChunk) {
197
	if chunk.Usage != nil {
198
		s.usage = chunk.Usage
199
	}
200
	if len(chunk.Choices) == 0 {
201
		return
202
	}
203
204
	if !s.started {
205
		s.emitMessageStart()
206
		s.started = true
207
	}
208
209
	delta := chunk.Choices[0].Delta
210
211
	// Text content.
212
	if delta.Content != nil && *delta.Content != "" {
213
		if !s.inText {
214
			s.closeCurrentBlock()
215
			s.emitContentBlockStart("text", "", "")
216
			s.inText = true
217
		}
218
		s.emit("content_block_delta", fmt.Sprintf(
219
			`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":%s}}`,
220
			s.blockIndex, mustJSON(*delta.Content)))
221
	}
222
223
	// Tool calls.
224
	for _, tc := range delta.ToolCalls {
225
		if tc.ID != "" {
226
			// New tool call starting.
227
			s.closeCurrentBlock()
228
			s.emitContentBlockStart("tool_use", tc.ID, tc.Function.Name)
229
			s.toolStarts[tc.Index] = s.blockIndex
230
			s.inTool = true
231
		}
232
		if tc.Function != nil && tc.Function.Arguments != "" {
233
			idx, ok := s.toolStarts[tc.Index]
234
			if !ok {
235
				idx = s.blockIndex
236
			}
237
			s.emit("content_block_delta", fmt.Sprintf(
238
				`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":%s}}`,
239
				idx, mustJSON(tc.Function.Arguments)))
240
		}
241
	}
242
243
	// Finish reason.
244
	if chunk.Choices[0].FinishReason != nil {
245
		s.closeCurrentBlock()
246
	}
247
}
248
249
func (s *streamWriter) finish() {
250
	if !s.started {
251
		s.emitMessageStart()
252
	}
253
	s.closeCurrentBlock()
254
255
	stopReason := "end_turn"
256
	outTokens := 0
257
	if s.usage != nil {
258
		outTokens = s.usage.CompletionTokens
259
	}
260
261
	s.emit("message_delta", fmt.Sprintf(
262
		`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"output_tokens":%d}}`,
263
		stopReason, outTokens))
264
	s.emit("message_stop", `{"type":"message_stop"}`)
265
}
266
267
func (s *streamWriter) emitMessageStart() {
268
	inTokens := 0
269
	if s.usage != nil {
270
		inTokens = s.usage.PromptTokens
271
	}
272
	s.emit("message_start", fmt.Sprintf(
273
		`{"type":"message_start","message":{"id":"%s","type":"message","role":"assistant","content":[],"model":"%s","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":%d,"output_tokens":0}}}`,
274
		s.msgID, s.model, inTokens))
275
}
276
277
func (s *streamWriter) emitContentBlockStart(typ, id, name string) {
278
	switch typ {
279
	case "text":
280
		s.emit("content_block_start", fmt.Sprintf(
281
			`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`,
282
			s.blockIndex))
283
	case "tool_use":
284
		s.emit("content_block_start", fmt.Sprintf(
285
			`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"%s","name":"%s","input":{}}}`,
286
			s.blockIndex, id, name))
287
	}
288
}
289
290
func (s *streamWriter) closeCurrentBlock() {
291
	if s.inText || s.inTool {
292
		s.emit("content_block_stop", fmt.Sprintf(
293
			`{"type":"content_block_stop","index":%d}`, s.blockIndex))
294
		s.blockIndex++
295
		s.inText = false
296
		s.inTool = false
297
	}
298
}
299
300
// emit writes one SSE frame (an "event:" line, a "data:" line and a blank
// line) to the response and flushes it immediately.
func (s *streamWriter) emit(event, data string) {
	frame := "event: " + event + "\ndata: " + data + "\n\n"
	io.WriteString(s.w, frame)
	s.flusher.Flush()
}
304
305
// mustJSON returns s encoded as a JSON string literal, including the
// surrounding double quotes. If encoding fails, it returns the empty string.
func mustJSON(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		return ""
	}
	return string(encoded)
}
309
310
// StaticToken adapts a fixed key to the token-function signature: the
// returned function yields key and a nil error on every call.
func StaticToken(key string) func() (string, error) {
	fixed := func() (string, error) {
		return key, nil
	}
	return fixed
}
314
315
// countTokensHandler serves /v1/messages/count_tokens with a fixed mock
// answer: a JSON object whose input_tokens is always 0. The request itself is
// never inspected.
//
// Claude Code calls this endpoint; without it the proxy returns 404 which may
// cause "unable to connect to api" errors.
func countTokensHandler() http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		payload := map[string]any{"input_tokens": 0}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	})
}
326
327
// isQuotaProbe detects Claude Code's quota-check requests (max_tokens=1 with
328
// short messages) and returns a mock response to avoid hitting the upstream.
329
func isQuotaProbe(req *anthropicRequest) bool {
330
	return req.MaxTokens == 1
331
}
332
333
// httpError writes an Anthropic-style JSON error envelope to w with HTTP
// status code. The envelope's error type is always "proxy_error" and msg is
// used as its message.
func httpError(w http.ResponseWriter, code int, msg string) {
	detail := map[string]any{
		"type":    "proxy_error",
		"message": msg,
	}
	envelope := map[string]any{
		"type":  "error",
		"error": detail,
	}

	// Headers must be set before WriteHeader sends the status line.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	json.NewEncoder(w).Encode(envelope)
}
344

Source Files