| 1 | package crypto |
| 2 | |
| 3 | import ( |
| 4 | "crypto/aes" |
| 5 | "crypto/cipher" |
| 6 | "crypto/hmac" |
| 7 | "crypto/rand" |
| 8 | "crypto/sha256" |
| 9 | "fmt" |
| 10 | ) |
| 11 | |
// SymmetricKey is a 64-byte symmetric key split into two independent 32-byte
// halves: EncKey (used as the AES-256 key) and MacKey (used as the
// HMAC-SHA256 key).
type SymmetricKey struct {
	EncKey []byte // first 32 bytes of the source key
	MacKey []byte // last 32 bytes of the source key
}

// NewSymmetricKey splits a 64-byte key into its encryption and MAC halves.
//
// Both halves are copied into freshly allocated slices, so later changes to
// key do not affect the returned SymmetricKey. An error is returned when key
// is not exactly 64 bytes long.
func NewSymmetricKey(key []byte) (*SymmetricKey, error) {
	if n := len(key); n != 64 {
		return nil, fmt.Errorf("symmetric key must be 64 bytes, got %d", n)
	}
	encHalf := append(make([]byte, 0, 32), key[:32]...)
	macHalf := append(make([]byte, 0, 32), key[32:]...)
	return &SymmetricKey{EncKey: encHalf, MacKey: macHalf}, nil
}
| 31 | |
| 32 | // Encrypt encrypts plaintext with AES-256-CBC + HMAC-SHA256 (type 2). |
| 33 | func (sk *SymmetricKey) Encrypt(plaintext []byte) (*CipherString, error) { |
| 34 | block, err := aes.NewCipher(sk.EncKey) |
| 35 | if err != nil { |
| 36 | return nil, fmt.Errorf("aes cipher: %w", err) |
| 37 | } |
| 38 | |
| 39 | // PKCS7 padding |
| 40 | blockSize := aes.BlockSize |
| 41 | padding := blockSize - (len(plaintext) % blockSize) |
| 42 | padded := make([]byte, len(plaintext)+padding) |
| 43 | copy(padded, plaintext) |
| 44 | for i := len(plaintext); i < len(padded); i++ { |
| 45 | padded[i] = byte(padding) |
| 46 | } |
| 47 | |
| 48 | // Random IV |
| 49 | iv := make([]byte, blockSize) |
| 50 | if _, err := rand.Read(iv); err != nil { |
| 51 | return nil, fmt.Errorf("generate IV: %w", err) |
| 52 | } |
| 53 | |
| 54 | // Encrypt |
| 55 | ct := make([]byte, len(padded)) |
| 56 | cipher.NewCBCEncrypter(block, iv).CryptBlocks(ct, padded) |
| 57 | |
| 58 | // HMAC |
| 59 | mac := hmac.New(sha256.New, sk.MacKey) |
| 60 | mac.Write(iv) |
| 61 | mac.Write(ct) |
| 62 | macSum := mac.Sum(nil) |
| 63 | |
| 64 | return &CipherString{ |
| 65 | Type: EncAesCbc256_HmacSha256_B64, |
| 66 | IV: iv, |
| 67 | CT: ct, |
| 68 | MAC: macSum, |
| 69 | }, nil |
| 70 | } |
| 71 | |
| 72 | // Decrypt decrypts a CipherString using this symmetric key. |
| 73 | func (sk *SymmetricKey) Decrypt(cs *CipherString) ([]byte, error) { |
| 74 | switch cs.Type { |
| 75 | case EncAesCbc256_HmacSha256_B64: |
| 76 | return sk.decryptAesCbc256Hmac(cs) |
| 77 | case EncAesCbc256_B64: |
| 78 | return sk.decryptAesCbc256NoHmac(cs) |
| 79 | default: |
| 80 | return nil, fmt.Errorf("symmetric key cannot decrypt type %d", cs.Type) |
| 81 | } |
| 82 | } |
| 83 | |
| 84 | func (sk *SymmetricKey) decryptAesCbc256Hmac(cs *CipherString) ([]byte, error) { |
| 85 | // Verify HMAC first |
| 86 | mac := hmac.New(sha256.New, sk.MacKey) |
| 87 | mac.Write(cs.IV) |
| 88 | mac.Write(cs.CT) |
| 89 | expectedMAC := mac.Sum(nil) |
| 90 | if !hmac.Equal(expectedMAC, cs.MAC) { |
| 91 | return nil, fmt.Errorf("HMAC verification failed") |
| 92 | } |
| 93 | |
| 94 | return sk.decryptAesCbc(cs.IV, cs.CT) |
| 95 | } |
| 96 | |
| 97 | func (sk *SymmetricKey) decryptAesCbc256NoHmac(cs *CipherString) ([]byte, error) { |
| 98 | return sk.decryptAesCbc(cs.IV, cs.CT) |
| 99 | } |
| 100 | |
| 101 | func (sk *SymmetricKey) decryptAesCbc(iv, ct []byte) ([]byte, error) { |
| 102 | block, err := aes.NewCipher(sk.EncKey) |
| 103 | if err != nil { |
| 104 | return nil, fmt.Errorf("aes cipher: %w", err) |
| 105 | } |
| 106 | |
| 107 | if len(ct)%aes.BlockSize != 0 { |
| 108 | return nil, fmt.Errorf("ciphertext not multiple of block size") |
| 109 | } |
| 110 | |
| 111 | plaintext := make([]byte, len(ct)) |
| 112 | cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ct) |
| 113 | |
| 114 | // Remove PKCS7 padding |
| 115 | if len(plaintext) == 0 { |
| 116 | return plaintext, nil |
| 117 | } |
| 118 | padding := int(plaintext[len(plaintext)-1]) |
| 119 | if padding < 1 || padding > aes.BlockSize { |
| 120 | return nil, fmt.Errorf("invalid PKCS7 padding: %d", padding) |
| 121 | } |
| 122 | for i := len(plaintext) - padding; i < len(plaintext); i++ { |
| 123 | if plaintext[i] != byte(padding) { |
| 124 | return nil, fmt.Errorf("invalid PKCS7 padding bytes") |
| 125 | } |
| 126 | } |
| 127 | return plaintext[:len(plaintext)-padding], nil |
| 128 | } |
| 129 | |
| 130 | // EncryptString is a convenience wrapper that encrypts a string and returns the CipherString notation. |
| 131 | func (sk *SymmetricKey) EncryptString(s string) (string, error) { |
| 132 | cs, err := sk.Encrypt([]byte(s)) |
| 133 | if err != nil { |
| 134 | return "", err |
| 135 | } |
| 136 | return cs.String(), nil |
| 137 | } |
| 138 | |
| 139 | // DecryptString is a convenience wrapper that decrypts a CipherString notation to a string. |
| 140 | func DecryptString(s string, key *SymmetricKey) (string, error) { |
| 141 | if s == "" { |
| 142 | return "", nil |
| 143 | } |
| 144 | cs, err := ParseCipherString(s) |
| 145 | if err != nil { |
| 146 | return "", err |
| 147 | } |
| 148 | plaintext, err := key.Decrypt(cs) |
| 149 | if err != nil { |
| 150 | return "", err |
| 151 | } |
| 152 | return string(plaintext), nil |
| 153 | } |
| 154 | |