models.go

v0.7.0
Doc Versions Source
1
package nvidia
2
3
import (
	"encoding/json"
	"fmt"
	"net/http"
	"slices"
	"sort"
	"strings"
	"time"
)
11
12
const (
	// modelsURL is the NVIDIA NIM endpoint that ListModels queries for the
	// list of available models.
	modelsURL = "https://integrate.api.nvidia.com/v1/models"

	// EnterByHandCategory is the special category label offered for typing a
	// model ID manually instead of picking one from the fetched list.
	EnterByHandCategory = "✏ Enter model ID"
)
16
17
// NIMModel describes one model as listed by the NVIDIA NIM models endpoint
// (an element of the response's "data" array). ID is usually of the form
// "owner/name". Created is the creation time reported by the API —
// presumably Unix seconds; TODO confirm.
type NIMModel struct {
	ID      string  `json:"id"`
	Object  string  `json:"object"`
	Created float64 `json:"created"`
	OwnedBy string  `json:"owned_by"`
}

// DisplayName returns the model ID with its leading "owner/" segment removed,
// e.g. "nvidia/llama-3.1-nemotron-ultra-253b-v1" becomes
// "llama-3.1-nemotron-ultra-253b-v1". Only the first "/" is considered, and an
// ID without any "/" is returned unchanged.
func (m NIMModel) DisplayName() string {
	if _, rest, found := strings.Cut(m.ID, "/"); found {
		return rest
	}
	return m.ID
}

// nonChatPatterns are substrings that, when found in the lower-cased model ID,
// mark the model as not meant for chat (embedding, reward, guard/safety,
// parsing and similar specialised models).
var nonChatPatterns = []string{
	"embed", "reward", "guard", "safety", "parse", "clip",
	"deplot", "neva", "streampetr", "paligemma", "kosmos", "vila",
}

// isChatModel reports whether m is likely a chat/instruct model. This is a
// name-based heuristic: an ID that ends in "-base" or contains any entry of
// nonChatPatterns (case-insensitively) is rejected; anything else passes.
func (m NIMModel) isChatModel() bool {
	id := strings.ToLower(m.ID)
	// Base checkpoints are only recognised by an explicit "-base" suffix;
	// IDs without "instruct"/"chat" markers are otherwise still accepted.
	if strings.HasSuffix(id, "-base") {
		return false
	}
	return !slices.ContainsFunc(nonChatPatterns, func(pattern string) bool {
		return strings.Contains(id, pattern)
	})
}

// featuredIDs lists the hand-picked models of the "Featured" category, in the
// order they are presented.
var featuredIDs = []string{
	// NVIDIA
	"nvidia/llama-3.1-nemotron-ultra-253b-v1",
	"nvidia/llama-3.3-nemotron-super-49b-v1.5",
	"nvidia/nemotron-3-nano-30b-a3b",
	"nvidia/nvidia-nemotron-nano-9b-v2",
	// DeepSeek
	"deepseek-ai/deepseek-v3.2",
	"deepseek-ai/deepseek-v3.1",
	// Qwen
	"qwen/qwen3-coder-480b-a35b-instruct",
	"qwen/qwen3.5-397b-a17b",
	// Meta
	"meta/llama-4-maverick-17b-128e-instruct",
	"meta/llama-3.3-70b-instruct",
	// Mistral
	"mistralai/mistral-large-3-675b-instruct-2512",
	"mistralai/devstral-2-123b-instruct-2512",
	// Others
	"minimaxai/minimax-m2.5",
	"moonshotai/kimi-k2.5",
}

// isFeatured reports whether m's ID appears verbatim in featuredIDs.
func (m NIMModel) isFeatured() bool {
	for _, id := range featuredIDs {
		if id == m.ID {
			return true
		}
	}
	return false
}

// popularPrefixes are the "owner/" ID prefixes whose models count as popular.
var popularPrefixes = []string{
	"nvidia/", "meta/", "deepseek-ai/", "qwen/", "mistralai/",
	"google/", "microsoft/", "minimaxai/", "moonshotai/",
}

// isPopular reports whether m's ID starts with one of popularPrefixes. The
// comparison is case-sensitive.
func (m NIMModel) isPopular() bool {
	return slices.ContainsFunc(popularPrefixes, func(prefix string) bool {
		return strings.HasPrefix(m.ID, prefix)
	})
}
98
99
// ListModels fetches available models from the NVIDIA NIM API.
100
func ListModels(apiKey string) ([]NIMModel, error) {
101
	req, err := http.NewRequest("GET", modelsURL, nil)
102
	if err != nil {
103
		return nil, err
104
	}
105
	if apiKey != "" {
106
		req.Header.Set("Authorization", "Bearer "+apiKey)
107
	}
108
	req.Header.Set("Accept", "application/json")
109
110
	resp, err := http.DefaultClient.Do(req)
111
	if err != nil {
112
		return nil, fmt.Errorf("fetching models: %w", err)
113
	}
114
	defer resp.Body.Close()
115
116
	if resp.StatusCode != http.StatusOK {
117
		return nil, fmt.Errorf("models endpoint returned %s", resp.Status)
118
	}
119
120
	var result struct {
121
		Data []NIMModel `json:"data"`
122
	}
123
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
124
		return nil, fmt.Errorf("parsing models: %w", err)
125
	}
126
127
	return FilterChatModels(result.Data), nil
128
}
129
130
// FilterChatModels returns only models suitable for chat/coding use.
131
func FilterChatModels(models []NIMModel) []NIMModel {
132
	var out []NIMModel
133
	for _, m := range models {
134
		if m.isChatModel() {
135
			out = append(out, m)
136
		}
137
	}
138
	return out
139
}
140
141
// CategorizeModels groups models into Featured, Popular, Newest, and Enter by hand categories.
142
func CategorizeModels(models []NIMModel) ([]string, map[string][]NIMModel) {
143
	categories := []string{"Featured", "Popular", "Newest", "All Models", EnterByHandCategory}
144
	grouped := make(map[string][]NIMModel)
145
146
	var featured []NIMModel
147
	for _, m := range models {
148
		if m.isFeatured() {
149
			featured = append(featured, m)
150
		}
151
	}
152
	// Sort featured by the predefined order.
153
	sort.Slice(featured, func(i, j int) bool {
154
		ii := slices.Index(featuredIDs, featured[i].ID)
155
		ij := slices.Index(featuredIDs, featured[j].ID)
156
		return ii < ij
157
	})
158
159
	// Newest: sort by created desc, take top 40.
160
	newest := make([]NIMModel, len(models))
161
	copy(newest, models)
162
	sort.Slice(newest, func(i, j int) bool {
163
		if newest[i].Created != newest[j].Created {
164
			return newest[i].Created > newest[j].Created
165
		}
166
		return newest[i].ID < newest[j].ID
167
	})
168
	if len(newest) > 40 {
169
		newest = newest[:40]
170
	}
171
172
	// All Models: sorted alphabetically by ID.
173
	all := make([]NIMModel, len(models))
174
	copy(all, models)
175
	sort.Slice(all, func(i, j int) bool {
176
		return all[i].ID < all[j].ID
177
	})
178
179
	grouped["Featured"] = featured
180
	grouped["Newest"] = newest
181
	grouped["All Models"] = all
182
183
	return categories, grouped
184
}
185
186
var (
	// providerDisplayNames maps a model ID's owner prefix (the part before
	// "/") to the human-readable provider name shown in the Popular
	// drill-down.
	providerDisplayNames = map[string]string{
		"deepseek-ai": "DeepSeek",
		"google":      "Google",
		"meta":        "Meta",
		"microsoft":   "Microsoft",
		"minimaxai":   "MiniMax",
		"mistralai":   "Mistral",
		"moonshotai":  "Moonshot",
		"nvidia":      "NVIDIA",
		"qwen":        "Qwen",
	}

	// providerOrder fixes the order in which provider groups appear in the
	// Popular drill-down. Each entry is a key of providerDisplayNames.
	providerOrder = []string{
		"nvidia",
		"meta",
		"deepseek-ai",
		"qwen",
		"mistralai",
		"google",
		"microsoft",
		"minimaxai",
		"moonshotai",
	}
)
203
204
// GroupPopularByProvider groups popular models by their provider prefix.
205
func GroupPopularByProvider(models []NIMModel) ([]string, map[string][]NIMModel) {
206
	grouped := make(map[string][]NIMModel)
207
	for _, m := range models {
208
		if !m.isPopular() {
209
			continue
210
		}
211
		prefix := providerPrefix(m.ID)
212
		name := providerDisplayName(prefix)
213
		grouped[name] = append(grouped[name], m)
214
	}
215
	for name := range grouped {
216
		sort.Slice(grouped[name], func(i, j int) bool {
217
			return grouped[name][i].Created > grouped[name][j].Created
218
		})
219
	}
220
	var providers []string
221
	for _, prefix := range providerOrder {
222
		name := providerDisplayName(prefix)
223
		if _, ok := grouped[name]; ok {
224
			providers = append(providers, name)
225
		}
226
	}
227
	return providers, grouped
228
}
229
230
// providerPrefix returns the owner segment of a model ID, i.e. everything
// before the first "/". The whole ID is returned when it has no "/" or when
// the "/" is its first character.
func providerPrefix(id string) string {
	owner, _, found := strings.Cut(id, "/")
	if !found || owner == "" {
		return id
	}
	return owner
}
236
237
func providerDisplayName(prefix string) string {
238
	if name, ok := providerDisplayNames[prefix]; ok {
239
		return name
240
	}
241
	return prefix
242
}
243
244
// AllModelIDs returns a set of all model IDs for validation.
245
func AllModelIDs(models []NIMModel) map[string]bool {
246
	ids := make(map[string]bool, len(models))
247
	for _, m := range models {
248
		ids[m.ID] = true
249
	}
250
	return ids
251
}
252

Source Files