oidcauth.go

v1.3.6
Doc Versions Source
1
package oidcauth
2
3
import (
4
	"context"
5
	"crypto/hmac"
6
	"crypto/rand"
7
	"crypto/sha256"
8
	"encoding/base64"
9
	"encoding/json"
10
	"fmt"
11
	"net/http"
12
	"time"
13
14
	"github.com/coreos/go-oidc/v3/oidc"
15
	"golang.org/x/oauth2"
16
17
	"go.bigb.es/curator/internal/config"
18
)
19
20
// Cookie names and session lifetime shared by the OIDC handlers.
const (
	// stateCookieName names the short-lived cookie that carries the OAuth2
	// state value between the login redirect and the callback.
	stateCookieName = "curator_oidc_state"

	// sessionCookieName names the cookie that holds the signed session value.
	sessionCookieName = "curator_session"

	// sessionDuration is how long a session stays valid. It is used both for
	// the session cookie's MaxAge and for the expiry stored in the signed
	// payload.
	sessionDuration time.Duration = 8 * time.Hour
)
25
26
// Provider handles OIDC authentication for the admin UI.
27
type Provider struct {
28
	oauth2Cfg  oauth2.Config
29
	verifier   *oidc.IDTokenVerifier
30
	sessionKey []byte // HMAC key derived from client secret
31
}
32
33
// New creates an OIDC provider by performing discovery on the issuer.
34
func New(ctx context.Context, cfg *config.OIDCConfig, clientSecret string) (*Provider, error) {
35
	provider, err := oidc.NewProvider(ctx, cfg.Issuer)
36
	if err != nil {
37
		return nil, fmt.Errorf("oidc discovery: %w", err)
38
	}
39
40
	oauth2Cfg := oauth2.Config{
41
		ClientID:     cfg.ClientID,
42
		ClientSecret: clientSecret,
43
		RedirectURL:  cfg.RedirectURL,
44
		Endpoint:     provider.Endpoint(),
45
		Scopes:       []string{oidc.ScopeOpenID, "profile", "email"},
46
	}
47
48
	verifier := provider.Verifier(&oidc.Config{ClientID: cfg.ClientID})
49
50
	// Derive session signing key from client secret via SHA-256.
51
	key := sha256.Sum256([]byte("curator-session:" + clientSecret))
52
53
	return &Provider{
54
		oauth2Cfg:  oauth2Cfg,
55
		verifier:   verifier,
56
		sessionKey: key[:],
57
	}, nil
58
}
59
60
// LoginHandler redirects to the OIDC provider's authorization endpoint.
61
func (p *Provider) LoginHandler(w http.ResponseWriter, r *http.Request) {
62
	state := randomString(32)
63
64
	http.SetCookie(w, &http.Cookie{
65
		Name:     stateCookieName,
66
		Value:    state,
67
		Path:     "/-/oidc/",
68
		MaxAge:   300, // 5 minutes
69
		HttpOnly: true,
70
		SameSite: http.SameSiteLaxMode,
71
		Secure:   r.TLS != nil,
72
	})
73
74
	http.Redirect(w, r, p.oauth2Cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(state)), http.StatusFound)
75
}
76
77
// CallbackHandler handles the OIDC callback, exchanges the code, and sets a session cookie.
78
func (p *Provider) CallbackHandler(w http.ResponseWriter, r *http.Request) {
79
	// Validate state.
80
	stateCookie, err := r.Cookie(stateCookieName)
81
	if err != nil || stateCookie.Value == "" {
82
		http.Error(w, "missing state cookie", http.StatusBadRequest)
83
		return
84
	}
85
	if r.URL.Query().Get("state") != stateCookie.Value {
86
		http.Error(w, "state mismatch", http.StatusBadRequest)
87
		return
88
	}
89
90
	// Clear state cookie.
91
	http.SetCookie(w, &http.Cookie{
92
		Name:   stateCookieName,
93
		Path:   "/-/oidc/",
94
		MaxAge: -1,
95
	})
96
97
	// Check for error response from provider.
98
	if errCode := r.URL.Query().Get("error"); errCode != "" {
99
		desc := r.URL.Query().Get("error_description")
100
		http.Error(w, fmt.Sprintf("oidc error: %s: %s", errCode, desc), http.StatusForbidden)
101
		return
102
	}
103
104
	// Exchange code for tokens.
105
	token, err := p.oauth2Cfg.Exchange(r.Context(), r.URL.Query().Get("code"),
106
		oauth2.VerifierOption(stateCookie.Value))
107
	if err != nil {
108
		http.Error(w, "token exchange failed", http.StatusInternalServerError)
109
		return
110
	}
111
112
	// Extract and verify ID token.
113
	rawIDToken, ok := token.Extra("id_token").(string)
114
	if !ok {
115
		http.Error(w, "missing id_token", http.StatusInternalServerError)
116
		return
117
	}
118
119
	idToken, err := p.verifier.Verify(r.Context(), rawIDToken)
120
	if err != nil {
121
		http.Error(w, "id_token verification failed", http.StatusForbidden)
122
		return
123
	}
124
125
	var claims struct {
126
		Sub   string `json:"sub"`
127
		Email string `json:"email"`
128
		Name  string `json:"name"`
129
	}
130
	if err := idToken.Claims(&claims); err != nil {
131
		http.Error(w, "failed to parse claims", http.StatusInternalServerError)
132
		return
133
	}
134
135
	// Create session cookie.
136
	sessionValue, err := p.signSession(claims.Sub, claims.Email)
137
	if err != nil {
138
		http.Error(w, "session creation failed", http.StatusInternalServerError)
139
		return
140
	}
141
142
	http.SetCookie(w, &http.Cookie{
143
		Name:     sessionCookieName,
144
		Value:    sessionValue,
145
		Path:     "/-/",
146
		MaxAge:   int(sessionDuration.Seconds()),
147
		HttpOnly: true,
148
		SameSite: http.SameSiteLaxMode,
149
		Secure:   r.TLS != nil,
150
	})
151
152
	http.Redirect(w, r, "/-/admin/", http.StatusFound)
153
}
154
155
// LogoutHandler clears the session cookie and redirects to the admin login.
156
func (p *Provider) LogoutHandler(w http.ResponseWriter, r *http.Request) {
157
	http.SetCookie(w, &http.Cookie{
158
		Name:   sessionCookieName,
159
		Path:   "/-/",
160
		MaxAge: -1,
161
	})
162
	http.Redirect(w, r, "/-/admin/", http.StatusFound)
163
}
164
165
// session is the JSON document that is signed and stored in the session
// cookie.
type session struct {
	// Sub is the subject identifier taken from the ID token.
	Sub string `json:"sub"`

	// Email is the e-mail claim taken from the ID token; it may be empty.
	Email string `json:"email"`

	// Exp is the expiry time as Unix seconds.
	Exp int64 `json:"exp"`
}
171
172
// signSession creates an HMAC-signed session cookie value.
173
func (p *Provider) signSession(sub, email string) (string, error) {
174
	s := session{
175
		Sub:   sub,
176
		Email: email,
177
		Exp:   time.Now().Add(sessionDuration).Unix(),
178
	}
179
180
	payload, err := json.Marshal(s)
181
	if err != nil {
182
		return "", err
183
	}
184
185
	mac := hmac.New(sha256.New, p.sessionKey)
186
	mac.Write(payload)
187
	sig := mac.Sum(nil)
188
189
	// payload.signature, both base64url encoded.
190
	return base64.RawURLEncoding.EncodeToString(payload) + "." +
191
		base64.RawURLEncoding.EncodeToString(sig), nil
192
}
193
194
// ValidateSession verifies an HMAC-signed session cookie value.
195
// Returns the email on success or an empty string on failure.
196
func (p *Provider) ValidateSession(value string) string {
197
	dot := -1
198
	for i := len(value) - 1; i >= 0; i-- {
199
		if value[i] == '.' {
200
			dot = i
201
			break
202
		}
203
	}
204
	if dot < 0 {
205
		return ""
206
	}
207
208
	payload, err := base64.RawURLEncoding.DecodeString(value[:dot])
209
	if err != nil {
210
		return ""
211
	}
212
	sig, err := base64.RawURLEncoding.DecodeString(value[dot+1:])
213
	if err != nil {
214
		return ""
215
	}
216
217
	mac := hmac.New(sha256.New, p.sessionKey)
218
	mac.Write(payload)
219
	if !hmac.Equal(sig, mac.Sum(nil)) {
220
		return ""
221
	}
222
223
	var s session
224
	if err := json.Unmarshal(payload, &s); err != nil {
225
		return ""
226
	}
227
	if time.Now().Unix() > s.Exp {
228
		return ""
229
	}
230
231
	if s.Email != "" {
232
		return s.Email
233
	}
234
	return s.Sub
235
}
236
237
// randomString returns n bytes read from crypto/rand, encoded as unpadded
// base64url text.
//
// It panics if the system random source fails: the result is used as an
// unguessable security token, and continuing with an unfilled buffer would
// hand out a predictable value.
func randomString(n int) string {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		panic("oidcauth: reading random bytes: " + err.Error())
	}
	return base64.RawURLEncoding.EncodeToString(b)
}
242

Source Files