models.go

v0.3.0
Doc Versions Source
1
package nvidia
2
3
import (
	"encoding/json"
	"fmt"
	"net/http"
	"slices"
	"sort"
	"strings"
	"time"
)
11
12
const (
	// modelsURL is the NVIDIA NIM endpoint that ListModels sends its GET request to.
	modelsURL = "https://integrate.api.nvidia.com/v1/models"

	// EnterByHandCategory is the special category label for manual model ID entry.
	EnterByHandCategory = "✏ Enter model ID"
)
16
17
// NIMModel is one entry of the model list decoded from the NVIDIA NIM API
// response.
type NIMModel struct {
	// ID is the model identifier, typically "<owner>/<name>"
	// (e.g. "nvidia/llama-3.1-nemotron-ultra-253b-v1").
	ID string `json:"id"`
	// Object is decoded from the "object" JSON field.
	Object string `json:"object"`
	// Created is decoded from the "created" JSON field; larger values sort
	// first when models are ordered newest-first.
	Created float64 `json:"created"`
	// OwnedBy is decoded from the "owned_by" JSON field.
	OwnedBy string `json:"owned_by"`
}
24
25
// DisplayName returns a human-friendly name for the model.
26
func (m NIMModel) DisplayName() string {
27
	// Strip the owner prefix for display (e.g. "nvidia/llama-3.1-nemotron-ultra-253b-v1" → "llama-3.1-nemotron-ultra-253b-v1").
28
	if i := strings.Index(m.ID, "/"); i >= 0 {
29
		return m.ID[i+1:]
30
	}
31
	return m.ID
32
}
33
34
// nonChatPatterns lists lowercase substrings that indicate non-chat models
// (embedding, reward, safety, parsing, etc.). isChatModel rejects any model
// whose lowercased ID contains one of them.
var nonChatPatterns = []string{
	"embed",
	"reward",
	"guard",
	"safety",
	"parse",
	"clip",
	"deplot",
	"neva",
	"streampetr",
	"paligemma",
	"kosmos",
	"vila",
}
40
41
// isChatModel returns true if the model is likely a chat/instruct model.
42
func (m NIMModel) isChatModel() bool {
43
	lower := strings.ToLower(m.ID)
44
	for _, pat := range nonChatPatterns {
45
		if strings.Contains(lower, pat) {
46
			return false
47
		}
48
	}
49
	// Exclude base models (no "instruct", "chat", or "it" suffix pattern).
50
	// But only if it ends with "-base".
51
	if strings.HasSuffix(lower, "-base") {
52
		return false
53
	}
54
	return true
55
}
56
57
// featuredIDs lists the exact model IDs treated as featured. The slice order
// is significant: CategorizeModels sorts the Featured category by each
// model's position in this list.
var featuredIDs = []string{
	// NVIDIA
	"nvidia/llama-3.1-nemotron-ultra-253b-v1",
	"nvidia/llama-3.3-nemotron-super-49b-v1.5",
	"nvidia/nemotron-3-nano-30b-a3b",
	"nvidia/nvidia-nemotron-nano-9b-v2",

	// DeepSeek
	"deepseek-ai/deepseek-v3.2",
	"deepseek-ai/deepseek-v3.1",

	// Qwen
	"qwen/qwen3-coder-480b-a35b-instruct",
	"qwen/qwen3.5-397b-a17b",

	// Meta
	"meta/llama-4-maverick-17b-128e-instruct",
	"meta/llama-3.3-70b-instruct",

	// Mistral
	"mistralai/mistral-large-3-675b-instruct-2512",
	"mistralai/devstral-2-123b-instruct-2512",

	// MiniMax and Moonshot
	"minimaxai/minimax-m2.5",
	"moonshotai/kimi-k2.5",
}
73
74
// isFeatured reports whether the model's ID exactly matches an entry of
// featuredIDs.
func (m NIMModel) isFeatured() bool {
	for _, id := range featuredIDs {
		if id == m.ID {
			return true
		}
	}
	return false
}
77
78
// popularPrefixes lists the "<owner>/" ID prefixes that isPopular accepts.
// Each entry includes the trailing slash so that only a whole owner segment
// matches.
var popularPrefixes = []string{
	"nvidia/",
	"meta/",
	"deepseek-ai/",
	"qwen/",
	"mistralai/",
	"google/",
	"microsoft/",
	"minimaxai/",
	"moonshotai/",
}
89
90
// isPopular reports whether the model's ID starts with one of the owner
// prefixes in popularPrefixes. The comparison is case-sensitive.
func (m NIMModel) isPopular() bool {
	return slices.ContainsFunc(popularPrefixes, func(owner string) bool {
		return strings.HasPrefix(m.ID, owner)
	})
}
98
99
// ListModels fetches available models from the NVIDIA NIM API.
100
func ListModels(apiKey string) ([]NIMModel, error) {
101
	req, err := http.NewRequest("GET", modelsURL, nil)
102
	if err != nil {
103
		return nil, err
104
	}
105
	if apiKey != "" {
106
		req.Header.Set("Authorization", "Bearer "+apiKey)
107
	}
108
	req.Header.Set("Accept", "application/json")
109
110
	resp, err := http.DefaultClient.Do(req)
111
	if err != nil {
112
		return nil, fmt.Errorf("fetching models: %w", err)
113
	}
114
	defer resp.Body.Close()
115
116
	if resp.StatusCode != http.StatusOK {
117
		return nil, fmt.Errorf("models endpoint returned %s", resp.Status)
118
	}
119
120
	var result struct {
121
		Data []NIMModel `json:"data"`
122
	}
123
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
124
		return nil, fmt.Errorf("parsing models: %w", err)
125
	}
126
127
	return FilterChatModels(result.Data), nil
128
}
129
130
// FilterChatModels returns only models suitable for chat/coding use.
131
func FilterChatModels(models []NIMModel) []NIMModel {
132
	var out []NIMModel
133
	for _, m := range models {
134
		if m.isChatModel() {
135
			out = append(out, m)
136
		}
137
	}
138
	return out
139
}
140
141
// CategorizeModels groups models into Featured, Popular, Newest, and Enter by hand categories.
142
func CategorizeModels(models []NIMModel) ([]string, map[string][]NIMModel) {
143
	categories := []string{"Featured", "Popular", "Newest", EnterByHandCategory}
144
	grouped := make(map[string][]NIMModel)
145
146
	var featured []NIMModel
147
	for _, m := range models {
148
		if m.isFeatured() {
149
			featured = append(featured, m)
150
		}
151
	}
152
	// Sort featured by the predefined order.
153
	sort.Slice(featured, func(i, j int) bool {
154
		ii := slices.Index(featuredIDs, featured[i].ID)
155
		ij := slices.Index(featuredIDs, featured[j].ID)
156
		return ii < ij
157
	})
158
159
	// Newest: sort by created desc, take top 40.
160
	newest := make([]NIMModel, len(models))
161
	copy(newest, models)
162
	sort.Slice(newest, func(i, j int) bool {
163
		if newest[i].Created != newest[j].Created {
164
			return newest[i].Created > newest[j].Created
165
		}
166
		return newest[i].ID < newest[j].ID
167
	})
168
	if len(newest) > 40 {
169
		newest = newest[:40]
170
	}
171
172
	grouped["Featured"] = featured
173
	grouped["Newest"] = newest
174
175
	return categories, grouped
176
}
177
178
// Provider display names and ordering for the Popular drill-down.
var (
	// providerDisplayNames maps a provider prefix (the part of a model ID
	// before the "/") to the label used for that provider's group.
	providerDisplayNames = map[string]string{
		"nvidia":      "NVIDIA",
		"meta":        "Meta",
		"deepseek-ai": "DeepSeek",
		"qwen":        "Qwen",
		"mistralai":   "Mistral",
		"google":      "Google",
		"microsoft":   "Microsoft",
		"minimaxai":   "MiniMax",
		"moonshotai":  "Moonshot",
	}

	// providerOrder lists provider prefixes in the order in which
	// GroupPopularByProvider reports their groups.
	providerOrder = []string{
		"nvidia",
		"meta",
		"deepseek-ai",
		"qwen",
		"mistralai",
		"google",
		"microsoft",
		"minimaxai",
		"moonshotai",
	}
)
195
196
// GroupPopularByProvider groups popular models by their provider prefix.
197
func GroupPopularByProvider(models []NIMModel) ([]string, map[string][]NIMModel) {
198
	grouped := make(map[string][]NIMModel)
199
	for _, m := range models {
200
		if !m.isPopular() {
201
			continue
202
		}
203
		prefix := providerPrefix(m.ID)
204
		name := providerDisplayName(prefix)
205
		grouped[name] = append(grouped[name], m)
206
	}
207
	for name := range grouped {
208
		sort.Slice(grouped[name], func(i, j int) bool {
209
			return grouped[name][i].Created > grouped[name][j].Created
210
		})
211
	}
212
	var providers []string
213
	for _, prefix := range providerOrder {
214
		name := providerDisplayName(prefix)
215
		if _, ok := grouped[name]; ok {
216
			providers = append(providers, name)
217
		}
218
	}
219
	return providers, grouped
220
}
221
222
// providerPrefix returns the part of id before its first "/". When id has no
// "/" or starts with one (so the part before it would be empty), id is
// returned unchanged.
func providerPrefix(id string) string {
	owner, _, hasSlash := strings.Cut(id, "/")
	if !hasSlash || owner == "" {
		return id
	}
	return owner
}
228
229
func providerDisplayName(prefix string) string {
230
	if name, ok := providerDisplayNames[prefix]; ok {
231
		return name
232
	}
233
	return prefix
234
}
235
236
// AllModelIDs returns a set of all model IDs for validation.
237
func AllModelIDs(models []NIMModel) map[string]bool {
238
	ids := make(map[string]bool, len(models))
239
	for _, m := range models {
240
		ids[m.ID] = true
241
	}
242
	return ids
243
}
244

Source Files