webpush-client-go/rfc8291/aesgcm.go

185 lines
5.4 KiB
Go

package rfc8291
import (
"bytes"
"crypto/ecdh"
"encoding/base64"
"fmt"
"hash"
"io"
"strings"
"golang.org/x/crypto/hkdf"
)
// AesgcmScheme implements EncodingScheme for the aesgcm encoding.
//
// Unlike aes128gcm, aesgcm does not embed its crypto parameters (salt and
// sender public key) in the payload; they travel in HTTP headers instead (see
// EncryptResult and ParseAesgcmHeaders). The type has no fields, so the zero
// value AesgcmScheme{} is ready to use.
type AesgcmScheme struct{}
// DeriveIKM derives the aesgcm input keying material (IKM) from the ECDH
// shared secret and the subscription's auth secret.
//
// The PRK is HKDF-Extract(salt=authSecret, ikm=ecdhSecret). It is expanded with
// the bare info string "Content-Encoding: auth\0" into HKDF_IKM_LEN bytes.
// aesgcm does not mix the public keys into this step, so uaKey and asKey are
// accepted only to satisfy the shared scheme signature and are ignored.
func (s AesgcmScheme) DeriveIKM(hash func() hash.Hash, authSecret, ecdhSecret []byte, uaKey, asKey *ecdh.PublicKey) ([]byte, error) {
	const authInfo = "Content-Encoding: auth\x00"

	expander := hkdf.Expand(hash, hkdf.Extract(hash, ecdhSecret, authSecret), []byte(authInfo))
	out := make([]byte, HKDF_IKM_LEN)
	_, readErr := io.ReadFull(expander, out)
	if readErr != nil {
		return nil, fmt.Errorf("derive IKM failed: %v", readErr)
	}
	return out, nil
}
// DeriveCEKAndNonce derives the content-encryption key (CEK) and base nonce
// for the aesgcm scheme from the IKM produced by DeriveIKM.
//
// A PRK is extracted from ikm using salt and then expanded twice:
//
//	cek   = HKDF-Expand(PRK, "Content-Encoding: aesgcm\0" || context, HKDF_CEK_LEN)
//	nonce = HKDF-Expand(PRK, "Content-Encoding: nonce\0"  || context, HKDF_NONCE_LEN)
//
// where
//
//	context = "P-256\0" || len(uaKey) || uaKey || len(asKey) || asKey
//
// Each length is a 2-byte big-endian integer and each key is its
// PublicKey.Bytes() encoding. The "P-256" label is fixed regardless of the
// curve of the supplied keys.
//
// An error is returned if either public key is nil or if reading the HKDF
// output fails. HKDF errors are wrapped with %w.
func (s AesgcmScheme) DeriveCEKAndNonce(hash func() hash.Hash, ikm, salt []byte, uaKey, asKey *ecdh.PublicKey) (cek, nonce []byte, err error) {
	// Guard before calling Bytes(), which would dereference a nil key and panic.
	if uaKey == nil || asKey == nil {
		return nil, nil, fmt.Errorf("derive CEK and nonce failed: nil public key")
	}

	prk := hkdf.Extract(hash, ikm, salt)

	uaKeyBytes := uaKey.Bytes()
	asKeyBytes := asKey.Bytes()
	// Encode both bytes of each 16-bit length so that a key longer than 255
	// bytes is not silently truncated. For uncompressed P-256 keys (65 bytes)
	// the high byte is 0, which matches the previous encoding.
	keyContext := bytes.Join([][]byte{
		[]byte("P-256\000"),
		{byte(len(uaKeyBytes) >> 8), byte(len(uaKeyBytes))},
		uaKeyBytes,
		{byte(len(asKeyBytes) >> 8), byte(len(asKeyBytes))},
		asKeyBytes,
	}, nil)

	// []byte(string) conversions allocate fresh slices, so the two appends
	// cannot alias each other or keyContext.
	cekInfo := append([]byte("Content-Encoding: aesgcm\000"), keyContext...)
	nonceInfo := append([]byte("Content-Encoding: nonce\000"), keyContext...)

	cek = make([]byte, HKDF_CEK_LEN)
	if _, err = io.ReadFull(hkdf.Expand(hash, prk, cekInfo), cek); err != nil {
		return nil, nil, fmt.Errorf("derive CEK failed: %w", err)
	}

	nonce = make([]byte, HKDF_NONCE_LEN)
	if _, err = io.ReadFull(hkdf.Expand(hash, prk, nonceInfo), nonce); err != nil {
		return nil, nil, fmt.Errorf("derive nonce failed: %w", err)
	}

	return cek, nonce, nil
}
// Pad prepends the aesgcm padding header to plaintext.
//
// The header is a 2-byte big-endian count of padding bytes. This
// implementation never adds padding, so the header is always 0x0000 and is
// followed directly by a copy of plaintext. The returned slice is newly
// allocated (len == cap == 2+len(plaintext)) and does not alias plaintext.
func (s AesgcmScheme) Pad(plaintext []byte) []byte {
	// make zero-fills the 2-byte header. The exact capacity means append
	// never reallocates.
	padded := make([]byte, 2, 2+len(plaintext))
	return append(padded, plaintext...)
}
// Unpad strips the aesgcm padding header and padding from a decrypted record.
//
// The first two bytes of data are a big-endian padding length n. They are
// followed by n padding bytes that must all be zero. The rest of data is the
// payload.
//
// Errors:
//   - data is shorter than 2 bytes;
//   - n exceeds the bytes available after the header;
//   - any padding byte is non-zero (its position is reported as an index
//     into data).
//
// The returned slice is a sub-slice of data and shares its backing array.
func (s AesgcmScheme) Unpad(data []byte) ([]byte, error) {
	const headerLen = 2

	if len(data) < headerLen {
		return nil, fmt.Errorf("data too short for aesgcm padding")
	}

	padLen := int(data[0])*256 + int(data[1])
	payloadStart := headerLen + padLen
	if payloadStart > len(data) {
		return nil, fmt.Errorf("invalid padding length: %d (data length: %d)", padLen, len(data))
	}

	for offset, b := range data[headerLen:payloadStart] {
		if b != 0 {
			return nil, fmt.Errorf("invalid padding: non-zero byte at position %d", headerLen+offset)
		}
	}
	return data[payloadStart:], nil
}
// EncryptResult holds the result of aesgcm encryption.
//
// Unlike aes128gcm, which embeds its crypto parameters in the payload, aesgcm
// requires the salt and sender public key to be sent as HTTP headers.
// EncryptAesgcm populates all three fields. Salt is the caller's salt slice
// itself, not a copy.
type EncryptResult struct {
	Ciphertext      []byte // Encrypted data, sent as the request body.
	Salt            []byte // Goes in the Encryption header as salt=<base64url>.
	SenderPublicKey []byte // Sender's public key (PublicKey.Bytes()), sent in the Crypto-Key header as dh=<base64url>.
}
// CryptoParams holds the cryptographic parameters needed for aesgcm
// decryption, as extracted from HTTP headers by ParseAesgcmHeaders.
type CryptoParams struct {
	// Salt is the decoded "salt" parameter of the Encryption header.
	// ParseAesgcmHeaders only returns it when its length equals SALT_LEN.
	Salt []byte
	// SenderPublicKey is the sender's ECDH public key, decoded from the "dh"
	// parameter of the Crypto-Key header and validated by the curve's
	// NewPublicKey.
	SenderPublicKey *ecdh.PublicKey
}
// ParseAesgcmHeaders extracts the salt and sender public key from aesgcm HTTP
// headers.
//
// encryptionHeader is the Encryption header value, e.g.
// "salt=FiyMDLvlVl678odI9AWL3A". Its "salt" parameter must decode to exactly
// SALT_LEN bytes.
//
// cryptoKeyHeader is the Crypto-Key header value, e.g.
// "dh=BMLYo...;p256ecdsa=BF5o...". Its "dh" parameter is decoded and parsed
// as a public key on curve.
//
// Errors from the underlying decoding and key-parsing steps are wrapped with
// %w, so callers can inspect them with errors.Is / errors.As. A nil curve is
// reported as an error instead of panicking.
func ParseAesgcmHeaders(encryptionHeader, cryptoKeyHeader string, curve ecdh.Curve) (*CryptoParams, error) {
	salt, err := parseHeaderParam(encryptionHeader, "salt")
	if err != nil {
		return nil, fmt.Errorf("failed to parse salt: %w", err)
	}
	if len(salt) != SALT_LEN {
		return nil, fmt.Errorf("salt must be %d bytes, got %d", SALT_LEN, len(salt))
	}

	dhBytes, err := parseHeaderParam(cryptoKeyHeader, "dh")
	if err != nil {
		return nil, fmt.Errorf("failed to parse dh: %w", err)
	}

	// Calling a method on a nil interface would panic, so report it instead.
	// The check comes after header parsing so header errors are still
	// reported first, as before.
	if curve == nil {
		return nil, fmt.Errorf("failed to parse sender public key: nil curve")
	}
	senderPublicKey, err := curve.NewPublicKey(dhBytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse sender public key: %w", err)
	}

	return &CryptoParams{
		Salt:            salt,
		SenderPublicKey: senderPublicKey,
	}, nil
}
// parseHeaderParam extracts a base64url-encoded parameter value from a header
// string.
//
// The header is split into name=value pairs on ';' and ','. The first pair
// whose name equals paramName wins. Matching is case-insensitive, as HTTP
// parameter names are.
//
// Parsing is tolerant of common sender variations:
//   - whitespace around the name, the '=', and the value is ignored;
//   - surrounding double quotes on the value are removed;
//   - trailing '=' base64 padding is stripped before decoding with the
//     unpadded URL-safe alphabet.
//
// The standard base64 alphabet ('+', '/') is not accepted. An error is
// returned if the parameter is absent or its value is not valid base64url.
func parseHeaderParam(header, paramName string) ([]byte, error) {
	fields := strings.FieldsFunc(header, func(r rune) bool {
		return r == ';' || r == ','
	})
	for _, field := range fields {
		name, value, ok := strings.Cut(field, "=")
		if !ok || !strings.EqualFold(strings.TrimSpace(name), paramName) {
			continue
		}
		value = strings.Trim(strings.TrimSpace(value), "\"")
		// Some senders emit padded base64url. RawURLEncoding rejects '=',
		// so drop the padding first.
		value = strings.TrimRight(value, "=")
		return base64.RawURLEncoding.DecodeString(value)
	}
	return nil, fmt.Errorf("parameter %q not found in header", paramName)
}
// EncryptAesgcm encrypts plaintext using the aesgcm encoding scheme.
//
// Encryption is delegated to the shared encrypt routine, with AesgcmScheme
// supplying the aesgcm key derivation and padding. Because aesgcm carries its
// crypto parameters in HTTP headers rather than in the payload, the result
// also contains the salt and the sender's public key (PublicKey().Bytes() of
// senderPrivateKey) for the Encryption and Crypto-Key headers. The returned
// Salt is the caller's slice, not a copy. Any error from encryption is
// returned unchanged.
func (c *RFC8291) EncryptAesgcm(
	plaintext []byte,
	salt []byte,
	authSecret []byte,
	receiverPublicKey *ecdh.PublicKey,
	senderPrivateKey *ecdh.PrivateKey,
) (*EncryptResult, error) {
	var scheme AesgcmScheme
	body, encErr := c.encrypt(scheme, plaintext, salt, authSecret, receiverPublicKey, senderPrivateKey)
	if encErr != nil {
		return nil, encErr
	}

	result := &EncryptResult{Ciphertext: body, Salt: salt}
	result.SenderPublicKey = senderPrivateKey.PublicKey().Bytes()
	return result, nil
}
// DecryptAesgcm decrypts a message encrypted with the aesgcm encoding scheme.
//
// salt and senderPublicKey are the values carried in the Encryption and
// Crypto-Key headers. ParseAesgcmHeaders can extract them. Decryption is
// delegated to the shared decrypt routine, with AesgcmScheme supplying the
// aesgcm key derivation and unpadding. Its plaintext and error are returned
// exactly as produced.
func (c *RFC8291) DecryptAesgcm(
	ciphertext []byte,
	salt []byte,
	authSecret []byte,
	receiverPrivateKey *ecdh.PrivateKey,
	senderPublicKey *ecdh.PublicKey,
) ([]byte, error) {
	var scheme AesgcmScheme
	plaintext, decErr := c.decrypt(scheme, ciphertext, salt, authSecret, receiverPrivateKey, senderPublicKey)
	return plaintext, decErr
}