| 1 | package copilot |
| 2 | |
| 3 | import ( |
| 4 | "encoding/json" |
| 5 | "fmt" |
| 6 | "net/http" |
| 7 | "net/url" |
| 8 | "strings" |
| 9 | "sync" |
| 10 | "time" |
| 11 | ) |
| 12 | |
// OAuth app identity and the GitHub endpoints used by the device flow and
// by the Copilot session-token exchange.
const (
	clientID        = "Iv1.b507a08c87ecfe98"
	deviceCodeURL   = "https://github.com/login/device/code"
	accessTokenURL  = "https://github.com/login/oauth/access_token"
	copilotTokenURL = "https://api.github.com/copilot_internal/v2/token"
)

// Timing knobs for authentication and token caching.
const (
	// deviceFlowTimeout caps how long DeviceAuth keeps polling for approval.
	deviceFlowTimeout = 15 * time.Minute

	// tokenGracePeriod is how long before its reported expiry a cached
	// session token is already treated as stale and refreshed.
	tokenGracePeriod = time.Minute
)

// Editor identification values sent to the Copilot API by Headers.
const (
	editorVersion = "vscode/1.85.0"
	pluginVersion = "copilot-chat/0.12.0"
	userAgent     = "GitHubCopilotChat/0.12.0"
)
| 25 | |
| 26 | // Headers returns editor identification headers required by the Copilot API. |
| 27 | func Headers() map[string]string { |
| 28 | return map[string]string{ |
| 29 | "User-Agent": userAgent, |
| 30 | "Editor-Version": editorVersion, |
| 31 | "Editor-Plugin-Version": pluginVersion, |
| 32 | "Copilot-Integration-Id": "vscode-chat", |
| 33 | "Openai-Organization": "github-copilot", |
| 34 | "Openai-Intent": "conversation-agent", |
| 35 | } |
| 36 | } |
| 37 | |
| 38 | // DeviceAuth runs the GitHub OAuth device flow and returns an access token. |
| 39 | func DeviceAuth() (string, error) { |
| 40 | data := url.Values{ |
| 41 | "client_id": {clientID}, |
| 42 | "scope": {"read:user"}, |
| 43 | } |
| 44 | req, err := http.NewRequest("POST", deviceCodeURL, strings.NewReader(data.Encode())) |
| 45 | if err != nil { |
| 46 | return "", err |
| 47 | } |
| 48 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 49 | req.Header.Set("Accept", "application/json") |
| 50 | |
| 51 | resp, err := http.DefaultClient.Do(req) |
| 52 | if err != nil { |
| 53 | return "", fmt.Errorf("requesting device code: %w", err) |
| 54 | } |
| 55 | defer resp.Body.Close() |
| 56 | |
| 57 | var dc struct { |
| 58 | DeviceCode string `json:"device_code"` |
| 59 | UserCode string `json:"user_code"` |
| 60 | VerificationURI string `json:"verification_uri"` |
| 61 | Interval int `json:"interval"` |
| 62 | ExpiresIn int `json:"expires_in"` |
| 63 | } |
| 64 | if err := json.NewDecoder(resp.Body).Decode(&dc); err != nil { |
| 65 | return "", fmt.Errorf("parsing device code response: %w", err) |
| 66 | } |
| 67 | |
| 68 | fmt.Printf("Please visit: %s\n", dc.VerificationURI) |
| 69 | fmt.Printf("Enter code: %s\n", dc.UserCode) |
| 70 | |
| 71 | interval := time.Duration(dc.Interval) * time.Second |
| 72 | if interval == 0 { |
| 73 | interval = 5 * time.Second |
| 74 | } |
| 75 | deadline := time.Now().Add(deviceFlowTimeout) |
| 76 | |
| 77 | for time.Now().Before(deadline) { |
| 78 | time.Sleep(interval) |
| 79 | |
| 80 | token, done, err := pollAccessToken(dc.DeviceCode) |
| 81 | if err != nil { |
| 82 | return "", err |
| 83 | } |
| 84 | if done { |
| 85 | return token, nil |
| 86 | } |
| 87 | } |
| 88 | |
| 89 | return "", fmt.Errorf("device flow timed out") |
| 90 | } |
| 91 | |
| 92 | func pollAccessToken(deviceCode string) (token string, done bool, err error) { |
| 93 | data := url.Values{ |
| 94 | "client_id": {clientID}, |
| 95 | "device_code": {deviceCode}, |
| 96 | "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, |
| 97 | } |
| 98 | |
| 99 | req, err := http.NewRequest("POST", accessTokenURL, strings.NewReader(data.Encode())) |
| 100 | if err != nil { |
| 101 | return "", false, err |
| 102 | } |
| 103 | req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 104 | req.Header.Set("Accept", "application/json") |
| 105 | |
| 106 | resp, err := http.DefaultClient.Do(req) |
| 107 | if err != nil { |
| 108 | return "", false, fmt.Errorf("polling access token: %w", err) |
| 109 | } |
| 110 | defer resp.Body.Close() |
| 111 | |
| 112 | var result struct { |
| 113 | AccessToken string `json:"access_token"` |
| 114 | Error string `json:"error"` |
| 115 | } |
| 116 | if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { |
| 117 | return "", false, fmt.Errorf("parsing token response: %w", err) |
| 118 | } |
| 119 | |
| 120 | switch result.Error { |
| 121 | case "": |
| 122 | return result.AccessToken, true, nil |
| 123 | case "authorization_pending": |
| 124 | return "", false, nil |
| 125 | case "slow_down": |
| 126 | time.Sleep(5 * time.Second) |
| 127 | return "", false, nil |
| 128 | default: |
| 129 | return "", false, fmt.Errorf("oauth error: %s", result.Error) |
| 130 | } |
| 131 | } |
| 132 | |
| 133 | // TokenSource manages Copilot session tokens, refreshing them as needed. |
| 134 | type TokenSource struct { |
| 135 | oauthToken string |
| 136 | |
| 137 | mu sync.Mutex |
| 138 | token string |
| 139 | expiresAt time.Time |
| 140 | } |
| 141 | |
| 142 | // NewTokenSource creates a TokenSource that uses the given OAuth token to |
| 143 | // obtain Copilot session tokens. |
| 144 | func NewTokenSource(oauthToken string) *TokenSource { |
| 145 | return &TokenSource{oauthToken: oauthToken} |
| 146 | } |
| 147 | |
| 148 | // Token returns a valid Copilot session token, refreshing if needed. |
| 149 | func (ts *TokenSource) Token() (string, error) { |
| 150 | ts.mu.Lock() |
| 151 | defer ts.mu.Unlock() |
| 152 | |
| 153 | if ts.token != "" && time.Now().Before(ts.expiresAt.Add(-tokenGracePeriod)) { |
| 154 | return ts.token, nil |
| 155 | } |
| 156 | |
| 157 | req, err := http.NewRequest("GET", copilotTokenURL, nil) |
| 158 | if err != nil { |
| 159 | return "", err |
| 160 | } |
| 161 | req.Header.Set("Authorization", "token "+ts.oauthToken) |
| 162 | req.Header.Set("Accept", "application/json") |
| 163 | for k, v := range Headers() { |
| 164 | req.Header.Set(k, v) |
| 165 | } |
| 166 | |
| 167 | resp, err := http.DefaultClient.Do(req) |
| 168 | if err != nil { |
| 169 | return "", fmt.Errorf("fetching copilot token: %w", err) |
| 170 | } |
| 171 | defer resp.Body.Close() |
| 172 | |
| 173 | if resp.StatusCode != http.StatusOK { |
| 174 | return "", fmt.Errorf("copilot token endpoint returned %s", resp.Status) |
| 175 | } |
| 176 | |
| 177 | var result struct { |
| 178 | Token string `json:"token"` |
| 179 | ExpiresAt int64 `json:"expires_at"` |
| 180 | } |
| 181 | if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { |
| 182 | return "", fmt.Errorf("parsing copilot token: %w", err) |
| 183 | } |
| 184 | |
| 185 | ts.token = result.Token |
| 186 | ts.expiresAt = time.Unix(result.ExpiresAt, 0) |
| 187 | return ts.token, nil |
| 188 | } |
| 189 | |