package rfc8291

import (
	"bytes"
	"crypto/ecdh"
	"encoding/base64"
	"fmt"
	"hash"
	"io"
	"strings"

	"golang.org/x/crypto/hkdf"
)

// AesgcmScheme implements EncodingScheme for the aesgcm encoding.
type AesgcmScheme struct{}

// DeriveIKM derives the input keying material for the aesgcm encoding.
//
// It runs HKDF-Extract with ecdhSecret as the secret and authSecret as the
// salt, then HKDF-Expand with the info string "Content-Encoding: auth\0",
// reading HKDF_IKM_LEN bytes. uaKey and asKey are accepted for interface
// compatibility but are not used by this scheme.
//
// A failure to read the expanded key material is returned wrapped, so callers
// can inspect the cause with errors.Is / errors.As.
func (s AesgcmScheme) DeriveIKM(hash func() hash.Hash, authSecret, ecdhSecret []byte, uaKey, asKey *ecdh.PublicKey) ([]byte, error) {
	prkKey := hkdf.Extract(hash, ecdhSecret, authSecret)

	// aesgcm: just "Content-Encoding: auth\0" without keys
	keyInfo := []byte("Content-Encoding: auth\000")

	ikm := make([]byte, HKDF_IKM_LEN)
	if _, err := io.ReadFull(hkdf.Expand(hash, prkKey, keyInfo), ikm); err != nil {
		return nil, fmt.Errorf("derive IKM failed: %w", err)
	}
	return ikm, nil
}

// DeriveCEKAndNonce derives the content-encryption key and the nonce.
//
// It runs HKDF-Extract with ikm as the secret and salt as the salt, then two
// HKDF-Expand calls whose info strings are "Content-Encoding: aesgcm\0" and
// "Content-Encoding: nonce\0", each followed by the shared key context:
//
//	"P-256\0" || len(uaKey) || uaKey || len(asKey) || asKey
//
// where each length is a 2-octet big-endian integer. HKDF_CEK_LEN and
// HKDF_NONCE_LEN bytes are read for cek and nonce respectively.
//
// It returns an error if either public key is nil or if reading the expanded
// key material fails; read failures are wrapped.
func (s AesgcmScheme) DeriveCEKAndNonce(hash func() hash.Hash, ikm, salt []byte, uaKey, asKey *ecdh.PublicKey) (cek, nonce []byte, err error) {
	// Both keys are mixed into the info strings, so reject nil keys instead
	// of panicking when their bytes are read.
	if uaKey == nil || asKey == nil {
		return nil, nil, fmt.Errorf("derive CEK and nonce failed: missing public key")
	}

	prk := hkdf.Extract(hash, ikm, salt)

	// aesgcm: info includes "P-256\0" and length-prefixed public keys
	uaKeyBytes := uaKey.Bytes()
	asKeyBytes := asKey.Bytes()
	context := bytes.Join([][]byte{
		[]byte("P-256\000"),
		{byte(len(uaKeyBytes) >> 8), byte(len(uaKeyBytes))},
		uaKeyBytes,
		{byte(len(asKeyBytes) >> 8), byte(len(asKeyBytes))},
		asKeyBytes,
	}, nil)

	// Each label is converted to a fresh slice, so the two appends never
	// share a backing array.
	cekInfo := append([]byte("Content-Encoding: aesgcm\000"), context...)
	nonceInfo := append([]byte("Content-Encoding: nonce\000"), context...)

	cek = make([]byte, HKDF_CEK_LEN)
	if _, err = io.ReadFull(hkdf.Expand(hash, prk, cekInfo), cek); err != nil {
		return nil, nil, fmt.Errorf("derive CEK failed: %w", err)
	}

	nonce = make([]byte, HKDF_NONCE_LEN)
	if _, err = io.ReadFull(hkdf.Expand(hash, prk, nonceInfo), nonce); err != nil {
		return nil, nil, fmt.Errorf("derive nonce failed: %w", err)
	}

	return cek, nonce, nil
}

// Pad prepends the aesgcm padding header to plaintext.
//
// The result is a new slice of len(plaintext)+2 bytes: a 2-byte big-endian
// padding length of zero (no padding bytes are added), followed by a copy of
// plaintext.
func (s AesgcmScheme) Pad(plaintext []byte) []byte {
	// aesgcm: 2-byte big-endian padding length prefix (0 padding)
	result := make([]byte, 2+len(plaintext))
	// First two bytes are 0 (no padding), already zero-initialized
	copy(result[2:], plaintext)
	return result
}

// Unpad strips aesgcm padding from data.
//
// data must start with a 2-byte big-endian padding length, followed by that
// many zero bytes. The returned slice is the remainder of data and shares its
// backing array with the input.
//
// It returns an error if data is shorter than 2 bytes, if the declared padding
// length exceeds the available data, or if any padding byte is non-zero.
func (s AesgcmScheme) Unpad(data []byte) ([]byte, error) {
	if len(data) < 2 {
		return nil, fmt.Errorf("data too short for aesgcm padding")
	}

	padLen := int(data[0])<<8 | int(data[1])
	if 2+padLen > len(data) {
		return nil, fmt.Errorf("invalid padding length: %d (data length: %d)", padLen, len(data))
	}

	// Verify padding bytes are all zeros
	for i := 2; i < 2+padLen; i++ {
		if data[i] != 0 {
			return nil, fmt.Errorf("invalid padding: non-zero byte at position %d", i)
		}
	}

	return data[2+padLen:], nil
}

// EncryptResult holds the result of aesgcm encryption.
// Unlike aes128gcm which embeds crypto params in the payload,
// aesgcm requires these to be sent as HTTP headers.
type EncryptResult struct {
	Ciphertext      []byte // The encrypted data (for request body)
	Salt            []byte // For Encryption header: salt=
	SenderPublicKey []byte // For Crypto-Key header: dh=
}

// CryptoParams holds the extracted cryptographic parameters for decryption.
type CryptoParams struct {
	Salt            []byte
	SenderPublicKey *ecdh.PublicKey
}

// ParseAesgcmHeaders extracts salt and sender public key from aesgcm HTTP headers.
// encryptionHeader: e.g., "salt=FiyMDLvlVl678odI9AWL3A"
// cryptoKeyHeader: e.g., "dh=BMLYo...;p256ecdsa=BF5o..."
func ParseAesgcmHeaders(encryptionHeader, cryptoKeyHeader string, curve ecdh.Curve) (*CryptoParams, error) { salt, err := parseHeaderParam(encryptionHeader, "salt") if err != nil { return nil, fmt.Errorf("failed to parse salt: %v", err) } if len(salt) != SALT_LEN { return nil, fmt.Errorf("salt must be %d bytes, got %d", SALT_LEN, len(salt)) } dhBytes, err := parseHeaderParam(cryptoKeyHeader, "dh") if err != nil { return nil, fmt.Errorf("failed to parse dh: %v", err) } senderPublicKey, err := curve.NewPublicKey(dhBytes) if err != nil { return nil, fmt.Errorf("failed to parse sender public key: %v", err) } return &CryptoParams{ Salt: salt, SenderPublicKey: senderPublicKey, }, nil } // parseHeaderParam extracts a base64url-encoded parameter value from a header string. func parseHeaderParam(header, paramName string) ([]byte, error) { parts := strings.FieldsFunc(header, func(r rune) bool { return r == ';' || r == ',' }) prefix := paramName + "=" for _, part := range parts { part = strings.TrimSpace(part) if strings.HasPrefix(part, prefix) { value := strings.TrimPrefix(part, prefix) value = strings.Trim(value, "\"") return base64.RawURLEncoding.DecodeString(value) } } return nil, fmt.Errorf("parameter %q not found in header", paramName) } // EncryptAesgcm encrypts a message using the aesgcm encoding scheme. // Returns the ciphertext and crypto parameters needed for HTTP headers. func (c *RFC8291) EncryptAesgcm( plaintext []byte, salt []byte, authSecret []byte, receiverPublicKey *ecdh.PublicKey, senderPrivateKey *ecdh.PrivateKey, ) (*EncryptResult, error) { ciphertext, err := c.encrypt(AesgcmScheme{}, plaintext, salt, authSecret, receiverPublicKey, senderPrivateKey) if err != nil { return nil, err } return &EncryptResult{ Ciphertext: ciphertext, Salt: salt, SenderPublicKey: senderPrivateKey.PublicKey().Bytes(), }, nil } // DecryptAesgcm decrypts a message encrypted with the aesgcm encoding scheme. 
//
// All arguments are forwarded unchanged to c.decrypt with AesgcmScheme
// selected as the encoding scheme, and its result is returned as-is.
func (c *RFC8291) DecryptAesgcm(
	ciphertext []byte,
	salt []byte,
	authSecret []byte,
	receiverPrivateKey *ecdh.PrivateKey,
	senderPublicKey *ecdh.PublicKey,
) ([]byte, error) {
	scheme := AesgcmScheme{}
	return c.decrypt(
		scheme,
		ciphertext,
		salt,
		authSecret,
		receiverPrivateKey,
		senderPublicKey,
	)
}