| 1 | package proxy |
| 2 | |
| 3 | import ( |
| 4 | "crypto/rand" |
| 5 | "encoding/json" |
| 6 | "fmt" |
| 7 | "strings" |
| 8 | ) |
| 9 | |
| 10 | // convertRequest translates an Anthropic Messages request to an OpenAI Chat Completions request. |
| 11 | func convertRequest(req *anthropicRequest) *openaiRequest { |
| 12 | oai := &openaiRequest{ |
| 13 | Model: req.Model, |
| 14 | MaxTokens: &req.MaxTokens, |
| 15 | Temperature: req.Temperature, |
| 16 | TopP: req.TopP, |
| 17 | Stream: req.Stream, |
| 18 | Stop: req.StopSequences, |
| 19 | } |
| 20 | if req.Stream { |
| 21 | oai.StreamOptions = &streamOptions{IncludeUsage: true} |
| 22 | } |
| 23 | |
| 24 | // System message. |
| 25 | if sys := extractSystemText(req.System); sys != "" { |
| 26 | oai.Messages = append(oai.Messages, openaiMessage{Role: "system", Content: sys}) |
| 27 | } |
| 28 | |
| 29 | // Convert messages. |
| 30 | for _, msg := range req.Messages { |
| 31 | oai.Messages = append(oai.Messages, convertMessages(msg)...) |
| 32 | } |
| 33 | |
| 34 | // Tools. |
| 35 | for _, t := range req.Tools { |
| 36 | oai.Tools = append(oai.Tools, openaiTool{ |
| 37 | Type: "function", |
| 38 | Function: openaiFunction{ |
| 39 | Name: t.Name, |
| 40 | Description: t.Description, |
| 41 | Parameters: t.InputSchema, |
| 42 | }, |
| 43 | }) |
| 44 | } |
| 45 | |
| 46 | // Tool choice. |
| 47 | if len(req.ToolChoice) > 0 { |
| 48 | var tc anthropicToolChoice |
| 49 | if json.Unmarshal(req.ToolChoice, &tc) == nil { |
| 50 | switch tc.Type { |
| 51 | case "auto": |
| 52 | oai.ToolChoice = "auto" |
| 53 | case "any": |
| 54 | oai.ToolChoice = "required" |
| 55 | case "none": |
| 56 | oai.ToolChoice = "none" |
| 57 | case "tool": |
| 58 | oai.ToolChoice = map[string]any{ |
| 59 | "type": "function", |
| 60 | "function": map[string]string{"name": tc.Name}, |
| 61 | } |
| 62 | } |
| 63 | } |
| 64 | } |
| 65 | |
| 66 | return oai |
| 67 | } |
| 68 | |
| 69 | // convertResponse translates an OpenAI Chat Completions response to an Anthropic Messages response. |
| 70 | func convertResponse(oai *openaiResponse) *anthropicResponse { |
| 71 | resp := &anthropicResponse{ |
| 72 | ID: "msg_" + genHex(12), |
| 73 | Type: "message", |
| 74 | Role: "assistant", |
| 75 | } |
| 76 | if oai.Model != "" { |
| 77 | resp.Model = oai.Model |
| 78 | } |
| 79 | if oai.Usage != nil { |
| 80 | resp.Usage = anthropicUsage{ |
| 81 | InputTokens: oai.Usage.PromptTokens, |
| 82 | OutputTokens: oai.Usage.CompletionTokens, |
| 83 | } |
| 84 | if oai.Usage.PromptTokensDetails != nil { |
| 85 | resp.Usage.CacheReadInputTokens = oai.Usage.PromptTokensDetails.CachedTokens |
| 86 | } |
| 87 | } |
| 88 | |
| 89 | if len(oai.Choices) > 0 { |
| 90 | ch := oai.Choices[0] |
| 91 | resp.StopReason = mapStopReason(ch.FinishReason) |
| 92 | |
| 93 | // Text content. |
| 94 | if s, ok := ch.Message.Content.(string); ok && s != "" { |
| 95 | resp.Content = append(resp.Content, contentBlock{Type: "text", Text: s}) |
| 96 | } |
| 97 | |
| 98 | // Tool calls. |
| 99 | for _, tc := range ch.Message.ToolCalls { |
| 100 | var input json.RawMessage |
| 101 | if tc.Function.Arguments != "" { |
| 102 | input = json.RawMessage(tc.Function.Arguments) |
| 103 | } else { |
| 104 | input = json.RawMessage(`{}`) |
| 105 | } |
| 106 | resp.Content = append(resp.Content, contentBlock{ |
| 107 | Type: "tool_use", |
| 108 | ID: tc.ID, |
| 109 | Name: tc.Function.Name, |
| 110 | Input: input, |
| 111 | }) |
| 112 | } |
| 113 | } |
| 114 | |
| 115 | if len(resp.Content) == 0 { |
| 116 | resp.Content = []contentBlock{{Type: "text", Text: ""}} |
| 117 | } |
| 118 | return resp |
| 119 | } |
| 120 | |
| 121 | // ---------- helpers ---------- |
| 122 | |
| 123 | func extractSystemText(raw json.RawMessage) string { |
| 124 | if len(raw) == 0 { |
| 125 | return "" |
| 126 | } |
| 127 | // Try string first. |
| 128 | var s string |
| 129 | if json.Unmarshal(raw, &s) == nil { |
| 130 | return s |
| 131 | } |
| 132 | // Try array of {type:"text", text:"..."}. |
| 133 | var blocks []contentBlock |
| 134 | if json.Unmarshal(raw, &blocks) == nil { |
| 135 | var parts []string |
| 136 | for _, b := range blocks { |
| 137 | if b.Type == "text" && b.Text != "" { |
| 138 | parts = append(parts, b.Text) |
| 139 | } |
| 140 | } |
| 141 | return strings.Join(parts, "\n") |
| 142 | } |
| 143 | return "" |
| 144 | } |
| 145 | |
| 146 | // parseContent returns content blocks from the polymorphic content field. |
| 147 | func parseContent(raw json.RawMessage) []contentBlock { |
| 148 | if len(raw) == 0 { |
| 149 | return nil |
| 150 | } |
| 151 | var s string |
| 152 | if json.Unmarshal(raw, &s) == nil { |
| 153 | if s == "" { |
| 154 | return nil |
| 155 | } |
| 156 | return []contentBlock{{Type: "text", Text: s}} |
| 157 | } |
| 158 | var blocks []contentBlock |
| 159 | if err := json.Unmarshal(raw, &blocks); err != nil { |
| 160 | return nil |
| 161 | } |
| 162 | return blocks |
| 163 | } |
| 164 | |
| 165 | // convertMessages converts a single Anthropic message into one or more OpenAI messages. |
| 166 | func convertMessages(msg anthropicMessage) []openaiMessage { |
| 167 | blocks := parseContent(msg.Content) |
| 168 | |
| 169 | switch msg.Role { |
| 170 | case "user": |
| 171 | return convertUserMessage(blocks) |
| 172 | case "assistant": |
| 173 | return convertAssistantMessage(blocks) |
| 174 | default: |
| 175 | // Pass through as-is. |
| 176 | text := extractText(blocks) |
| 177 | return []openaiMessage{{Role: msg.Role, Content: text}} |
| 178 | } |
| 179 | } |
| 180 | |
| 181 | func convertUserMessage(blocks []contentBlock) []openaiMessage { |
| 182 | var msgs []openaiMessage |
| 183 | var textParts []string |
| 184 | |
| 185 | for _, b := range blocks { |
| 186 | switch b.Type { |
| 187 | case "text": |
| 188 | textParts = append(textParts, b.Text) |
| 189 | case "tool_result": |
| 190 | content := extractToolResultContent(b) |
| 191 | msgs = append(msgs, openaiMessage{ |
| 192 | Role: "tool", |
| 193 | Content: content, |
| 194 | ToolCallID: b.ToolUseID, |
| 195 | }) |
| 196 | } |
| 197 | } |
| 198 | |
| 199 | // Put text before tool results so ordering stays sensible. |
| 200 | if len(textParts) > 0 { |
| 201 | out := []openaiMessage{{Role: "user", Content: strings.Join(textParts, "\n")}} |
| 202 | return append(out, msgs...) |
| 203 | } |
| 204 | return msgs |
| 205 | } |
| 206 | |
| 207 | func convertAssistantMessage(blocks []contentBlock) []openaiMessage { |
| 208 | var textParts []string |
| 209 | var toolCalls []openaiToolCall |
| 210 | |
| 211 | for _, b := range blocks { |
| 212 | switch b.Type { |
| 213 | case "text": |
| 214 | textParts = append(textParts, b.Text) |
| 215 | case "tool_use": |
| 216 | args := "{}" |
| 217 | if len(b.Input) > 0 { |
| 218 | args = string(b.Input) |
| 219 | } |
| 220 | toolCalls = append(toolCalls, openaiToolCall{ |
| 221 | ID: b.ID, |
| 222 | Type: "function", |
| 223 | Function: openaiCallFunc{ |
| 224 | Name: b.Name, |
| 225 | Arguments: args, |
| 226 | }, |
| 227 | }) |
| 228 | } |
| 229 | } |
| 230 | |
| 231 | msg := openaiMessage{Role: "assistant"} |
| 232 | if len(textParts) > 0 { |
| 233 | msg.Content = strings.Join(textParts, "\n") |
| 234 | } |
| 235 | if len(toolCalls) > 0 { |
| 236 | msg.ToolCalls = toolCalls |
| 237 | } |
| 238 | // Some providers (e.g. NVIDIA NIM) reject empty content on assistant messages. |
| 239 | if msg.Content == nil && len(msg.ToolCalls) == 0 { |
| 240 | msg.Content = " " |
| 241 | } |
| 242 | return []openaiMessage{msg} |
| 243 | } |
| 244 | |
| 245 | func extractText(blocks []contentBlock) string { |
| 246 | var parts []string |
| 247 | for _, b := range blocks { |
| 248 | if b.Type == "text" { |
| 249 | parts = append(parts, b.Text) |
| 250 | } |
| 251 | } |
| 252 | return strings.Join(parts, "\n") |
| 253 | } |
| 254 | |
| 255 | func extractToolResultContent(b contentBlock) string { |
| 256 | if len(b.Content) == 0 { |
| 257 | return "" |
| 258 | } |
| 259 | var s string |
| 260 | if json.Unmarshal(b.Content, &s) == nil { |
| 261 | return s |
| 262 | } |
| 263 | var blocks []contentBlock |
| 264 | if json.Unmarshal(b.Content, &blocks) == nil { |
| 265 | return extractText(blocks) |
| 266 | } |
| 267 | return string(b.Content) |
| 268 | } |
| 269 | |
// mapStopReason maps an OpenAI finish_reason to an Anthropic stop_reason.
//
// "length" becomes "max_tokens", "tool_calls" becomes "tool_use", and every
// other value (including "stop") becomes "end_turn". A nil input yields nil.
// Each call returns a freshly allocated string.
func mapStopReason(fr *string) *string {
	if fr == nil {
		return nil
	}
	reason := "end_turn"
	switch *fr {
	case "length":
		reason = "max_tokens"
	case "tool_calls":
		reason = "tool_use"
	}
	return &reason
}
| 287 | |
// genHex returns a lowercase hexadecimal string of 2*n characters encoding
// n bytes read from crypto/rand.
func genHex(n int) string {
	raw := make([]byte, n)
	// The error is deliberately discarded: in this file the result only seeds
	// message IDs, and a crypto/rand failure is not expected in practice. If
	// it did fail, the bytes may not be fully random.
	_, _ = rand.Read(raw)
	return fmt.Sprintf("%x", raw)
}
| 293 | |