| 1 | package server |
| 2 | |
| 3 | import ( |
| 4 | "net/http" |
| 5 | "strings" |
| 6 | ) |
| 7 | |
| 8 | // responseWriter wraps http.ResponseWriter to capture the status code. |
| 9 | type responseWriter struct { |
| 10 | http.ResponseWriter |
| 11 | code int |
| 12 | } |
| 13 | |
| 14 | func (rw *responseWriter) WriteHeader(code int) { |
| 15 | rw.code = code |
| 16 | rw.ResponseWriter.WriteHeader(code) |
| 17 | } |
| 18 | |
| 19 | // SumdbAuthMiddleware wraps the sumdb handler to enforce access control on lookups. |
| 20 | func (s *Server) SumdbAuthMiddleware(next http.Handler) http.Handler { |
| 21 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 22 | if strings.HasPrefix(r.URL.Path, "/lookup/") { |
| 23 | lookupPath := strings.TrimPrefix(r.URL.Path, "/lookup/") |
| 24 | modVer := strings.TrimPrefix(lookupPath, s.Cfg.Host+"/") |
| 25 | modName, _, _ := strings.Cut(modVer, "@") |
| 26 | |
| 27 | if !s.CanAccessModule(modName, r) { |
| 28 | http.NotFound(w, r) |
| 29 | return |
| 30 | } |
| 31 | } |
| 32 | |
| 33 | next.ServeHTTP(w, r) |
| 34 | }) |
| 35 | } |
| 36 | |