symmetric.go

v1.0.0
Doc Versions Source
1
package crypto
2
3
import (
4
	"crypto/aes"
5
	"crypto/cipher"
6
	"crypto/hmac"
7
	"crypto/rand"
8
	"crypto/sha256"
9
	"fmt"
10
)
11
12
// SymmetricKey is a symmetric key stored as two halves: EncKey is used as
// the AES key for CBC encryption and MacKey as the HMAC-SHA256 key. When
// built by NewSymmetricKey, each half is 32 bytes (64 bytes in total).
type SymmetricKey struct {
	EncKey []byte // AES key; 32 bytes when created by NewSymmetricKey
	MacKey []byte // HMAC-SHA256 key; 32 bytes when created by NewSymmetricKey
}
17
18
// NewSymmetricKey creates a SymmetricKey from a 64-byte key.
19
func NewSymmetricKey(key []byte) (*SymmetricKey, error) {
20
	if len(key) != 64 {
21
		return nil, fmt.Errorf("symmetric key must be 64 bytes, got %d", len(key))
22
	}
23
	sk := &SymmetricKey{
24
		EncKey: make([]byte, 32),
25
		MacKey: make([]byte, 32),
26
	}
27
	copy(sk.EncKey, key[:32])
28
	copy(sk.MacKey, key[32:])
29
	return sk, nil
30
}
31
32
// Encrypt encrypts plaintext with AES-256-CBC + HMAC-SHA256 (type 2).
33
func (sk *SymmetricKey) Encrypt(plaintext []byte) (*CipherString, error) {
34
	block, err := aes.NewCipher(sk.EncKey)
35
	if err != nil {
36
		return nil, fmt.Errorf("aes cipher: %w", err)
37
	}
38
39
	// PKCS7 padding
40
	blockSize := aes.BlockSize
41
	padding := blockSize - (len(plaintext) % blockSize)
42
	padded := make([]byte, len(plaintext)+padding)
43
	copy(padded, plaintext)
44
	for i := len(plaintext); i < len(padded); i++ {
45
		padded[i] = byte(padding)
46
	}
47
48
	// Random IV
49
	iv := make([]byte, blockSize)
50
	if _, err := rand.Read(iv); err != nil {
51
		return nil, fmt.Errorf("generate IV: %w", err)
52
	}
53
54
	// Encrypt
55
	ct := make([]byte, len(padded))
56
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(ct, padded)
57
58
	// HMAC
59
	mac := hmac.New(sha256.New, sk.MacKey)
60
	mac.Write(iv)
61
	mac.Write(ct)
62
	macSum := mac.Sum(nil)
63
64
	return &CipherString{
65
		Type: EncAesCbc256_HmacSha256_B64,
66
		IV:   iv,
67
		CT:   ct,
68
		MAC:  macSum,
69
	}, nil
70
}
71
72
// Decrypt decrypts a CipherString using this symmetric key.
73
func (sk *SymmetricKey) Decrypt(cs *CipherString) ([]byte, error) {
74
	switch cs.Type {
75
	case EncAesCbc256_HmacSha256_B64:
76
		return sk.decryptAesCbc256Hmac(cs)
77
	case EncAesCbc256_B64:
78
		return sk.decryptAesCbc256NoHmac(cs)
79
	default:
80
		return nil, fmt.Errorf("symmetric key cannot decrypt type %d", cs.Type)
81
	}
82
}
83
84
// decryptAesCbc256Hmac authenticates, then decrypts, an HMAC-protected
// CipherString. It recomputes HMAC-SHA256 under MacKey over IV || CT and
// compares it with cs.MAC using hmac.Equal. Only on a match is the ciphertext
// passed to decryptAesCbc; otherwise the error is returned before any
// decryption.
func (sk *SymmetricKey) decryptAesCbc256Hmac(cs *CipherString) ([]byte, error) {
	h := hmac.New(sha256.New, sk.MacKey)
	for _, part := range [][]byte{cs.IV, cs.CT} {
		h.Write(part)
	}
	if want := h.Sum(nil); !hmac.Equal(want, cs.MAC) {
		return nil, fmt.Errorf("HMAC verification failed")
	}
	return sk.decryptAesCbc(cs.IV, cs.CT)
}
96
97
// decryptAesCbc256NoHmac decrypts a MAC-less CipherString by handing its IV
// and ciphertext straight to decryptAesCbc. No authentication is performed;
// cs.MAC is not consulted.
func (sk *SymmetricKey) decryptAesCbc256NoHmac(cs *CipherString) ([]byte, error) {
	iv, ct := cs.IV, cs.CT
	return sk.decryptAesCbc(iv, ct)
}
100
101
// decryptAesCbc decrypts ct with AES-CBC under sk.EncKey and iv, then strips
// and validates PKCS#7 padding.
//
// It returns an error, rather than panicking, when:
//   - EncKey is not a valid AES key length;
//   - iv is not exactly one AES block;
//   - ct is not a whole number of blocks;
//   - the padding is malformed.
//
// An empty ct yields an empty plaintext and a nil error.
//
// No authentication is performed here; callers that need integrity must
// verify a MAC before calling it.
func (sk *SymmetricKey) decryptAesCbc(iv, ct []byte) ([]byte, error) {
	block, err := aes.NewCipher(sk.EncKey)
	if err != nil {
		return nil, fmt.Errorf("aes cipher: %w", err)
	}

	// cipher.NewCBCDecrypter panics on a wrong-length IV, and the IV comes
	// from parsed (possibly corrupt or hostile) input, so reject it here.
	if len(iv) != aes.BlockSize {
		return nil, fmt.Errorf("invalid IV length: got %d, want %d", len(iv), aes.BlockSize)
	}
	if len(ct)%aes.BlockSize != 0 {
		return nil, fmt.Errorf("ciphertext not multiple of block size")
	}

	plaintext := make([]byte, len(ct))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, ct)

	// Remove PKCS7 padding.
	if len(plaintext) == 0 {
		return plaintext, nil
	}
	padding := int(plaintext[len(plaintext)-1])
	if padding < 1 || padding > aes.BlockSize {
		return nil, fmt.Errorf("invalid PKCS7 padding: %d", padding)
	}
	// A non-empty plaintext is at least one block long, so padding <= BlockSize
	// guarantees this slice is in range.
	for _, b := range plaintext[len(plaintext)-padding:] {
		if b != byte(padding) {
			return nil, fmt.Errorf("invalid PKCS7 padding bytes")
		}
	}
	return plaintext[:len(plaintext)-padding], nil
}
129
130
// EncryptString is a convenience wrapper that encrypts a string and returns the CipherString notation.
131
func (sk *SymmetricKey) EncryptString(s string) (string, error) {
132
	cs, err := sk.Encrypt([]byte(s))
133
	if err != nil {
134
		return "", err
135
	}
136
	return cs.String(), nil
137
}
138
139
// DecryptString is a convenience wrapper that decrypts a CipherString notation to a string.
140
func DecryptString(s string, key *SymmetricKey) (string, error) {
141
	if s == "" {
142
		return "", nil
143
	}
144
	cs, err := ParseCipherString(s)
145
	if err != nil {
146
		return "", err
147
	}
148
	plaintext, err := key.Decrypt(cs)
149
	if err != nil {
150
		return "", err
151
	}
152
	return string(plaintext), nil
153
}
154

Source Files