models.go

v0.5.1
Doc Versions Source
1
package copilot
2
3
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"
	"time"
)
10
11
// modelsURL is the Copilot API endpoint that ListModels queries for the list of models.
const modelsURL = "https://api.githubcopilot.com/models"
12
13
// CopilotModel represents a model available via the Copilot API.
type CopilotModel struct {
	// ID is the model identifier reported by the API.
	ID string `json:"id"`

	// Name is the model's human-readable name. It may be empty, in which
	// case DisplayName falls back to ID.
	Name string `json:"name"`

	// Version is the version string reported by the API.
	Version string `json:"version"`

	// Vendor is the vendor string reported by the API (for example
	// "anthropic" or "openai"); the ordering and grouping helpers in this
	// file compare it case-insensitively.
	Vendor string `json:"vendor"`

	// ContextWindow is the maximum context window in tokens as filled in by
	// ListModels from capabilities.limits.max_context_window_tokens; it is
	// zero when the API omits that value.
	ContextWindow int `json:"context_window"`
}

// DisplayName returns the model's Name, or its ID when no Name is set.
func (m CopilotModel) DisplayName() string {
	if m.Name == "" {
		return m.ID
	}
	return m.Name
}
29
30
// ListModels fetches available models from the Copilot API using the given OAuth token.
//
// The OAuth token is exchanged for a Copilot API token through NewTokenSource,
// and a GET request is sent to modelsURL. Entries whose capabilities.type is
// set to anything other than "chat" are dropped; entries that omit the type are
// kept. The result is sorted by vendorOrder, then case-insensitively by ID, with
// the exact ID as a final tie-break so the order is deterministic.
//
// The request is bounded by a timeout. A non-200 response yields an error that
// includes the HTTP status and, when present, a short excerpt of the body.
func ListModels(oauthToken string) ([]CopilotModel, error) {
	ts := NewTokenSource(oauthToken)
	token, err := ts.Token()
	if err != nil {
		return nil, fmt.Errorf("getting copilot token: %w", err)
	}

	req, err := http.NewRequest(http.MethodGet, modelsURL, nil)
	if err != nil {
		return nil, fmt.Errorf("building models request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Accept", "application/json")
	// Entries from Headers() are set after the defaults above, so they take
	// precedence on any conflicting key.
	for k, v := range Headers() {
		req.Header.Set(k, v)
	}

	// http.DefaultClient has no timeout; an unresponsive endpoint would block
	// forever. A nil Transport still means http.DefaultTransport, so connection
	// pooling is shared with other default-transport clients.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("fetching models: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort: the body may explain the failure. A read error here is
		// ignored because the status alone is still a useful error.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		if msg := strings.TrimSpace(string(snippet)); msg != "" {
			return nil, fmt.Errorf("models endpoint returned %s: %s", resp.Status, msg)
		}
		return nil, fmt.Errorf("models endpoint returned %s", resp.Status)
	}

	// Shape of the models endpoint response as decoded here.
	var result struct {
		Data []struct {
			ID      string `json:"id"`
			Name    string `json:"name"`
			Version string `json:"version"`
			// ModelPickerEnabled indicates the model is selectable.
			// NOTE(review): decoded but not used for filtering below — confirm
			// whether models with this unset should be hidden.
			ModelPickerEnabled bool `json:"model_picker_enabled"`
			// Capabilities lists what the model can do.
			Capabilities struct {
				Type   string `json:"type"`
				Limits struct {
					MaxContextWindowTokens int `json:"max_context_window_tokens"`
				} `json:"limits"`
			} `json:"limits"`
			Vendor string `json:"vendor"`
		} `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("parsing models: %w", err)
	}

	var models []CopilotModel
	for _, m := range result.Data {
		// Only include chat-capable models; entries without a type are kept.
		if m.Capabilities.Type != "" && m.Capabilities.Type != "chat" {
			continue
		}
		models = append(models, CopilotModel{
			ID:            m.ID,
			Name:          m.Name,
			Version:       m.Version,
			Vendor:        m.Vendor,
			ContextWindow: m.Capabilities.Limits.MaxContextWindowTokens,
		})
	}

	// Stable sort with a final exact-ID comparison: IDs that differ only in
	// case would otherwise compare equal and land in an arbitrary order.
	sort.SliceStable(models, func(i, j int) bool {
		a, b := models[i], models[j]
		if va, vb := vendorOrder(a.Vendor), vendorOrder(b.Vendor); va != vb {
			return va < vb
		}
		if la, lb := strings.ToLower(a.ID), strings.ToLower(b.ID); la != lb {
			return la < lb
		}
		return a.ID < b.ID
	})

	return models, nil
}
104
105
// VendorFamilies groups models by family (as computed by familyName) and
// returns the family names in display order together with that grouping.
//
// The order lists Claude, GPT and Gemini first (each only if present), then
// every other family sorted lexically. Within a family, models keep their
// relative order from the input slice.
func VendorFamilies(models []CopilotModel) ([]string, map[string][]CopilotModel) {
	byFamily := make(map[string][]CopilotModel)
	for _, model := range models {
		key := familyName(model.Vendor)
		byFamily[key] = append(byFamily[key], model)
	}

	preferred := []string{"Claude", "GPT", "Gemini"}
	isPreferred := func(name string) bool {
		for _, p := range preferred {
			if p == name {
				return true
			}
		}
		return false
	}

	// Preferred families first, in their fixed order.
	var families []string
	for _, name := range preferred {
		if _, present := byFamily[name]; present {
			families = append(families, name)
		}
	}

	// Remaining families sorted so the output does not depend on map order.
	var rest []string
	for name := range byFamily {
		if !isPreferred(name) {
			rest = append(rest, name)
		}
	}
	sort.Strings(rest)

	return append(families, rest...), byFamily
}
129
130
// familyName maps a vendor string to the family label used for grouping.
// The vendors anthropic, openai and google (matched case-insensitively) map to
// "Claude", "GPT" and "Gemini". Any other non-empty vendor is returned
// unchanged, and an empty vendor becomes "Other".
func familyName(vendor string) string {
	known := ""
	switch strings.ToLower(vendor) {
	case "anthropic":
		known = "Claude"
	case "openai":
		known = "GPT"
	case "google":
		known = "Gemini"
	}

	switch {
	case known != "":
		return known
	case vendor == "":
		return "Other"
	default:
		return vendor
	}
}
145
146
// vendorOrder returns the sort rank of a vendor. The vendors anthropic, openai
// and google (case-insensitive) rank 0, 1 and 2; every other vendor, including
// the empty string, ranks 3.
func vendorOrder(v string) int {
	lowered := strings.ToLower(v)
	for rank, name := range [...]string{"anthropic", "openai", "google"} {
		if lowered == name {
			return rank
		}
	}
	return 3
}
158

Source Files