package rfc8291 import ( "bytes" "crypto/ecdh" "encoding/binary" "errors" "fmt" "hash" "io" "golang.org/x/crypto/hkdf" ) // Aes128gcmScheme implements EncodingScheme for the aes128gcm encoding. type Aes128gcmScheme struct{} func (s Aes128gcmScheme) DeriveIKM(hash func() hash.Hash, authSecret, ecdhSecret []byte, uaKey, asKey *ecdh.PublicKey) ([]byte, error) { prkKey := hkdf.Extract(hash, ecdhSecret, authSecret) // aes128gcm: "WebPush: info\0" + receiver key + sender key keyInfo := bytes.Join([][]byte{ []byte("WebPush: info\000"), uaKey.Bytes(), asKey.Bytes(), }, nil) ikm := make([]byte, HKDF_IKM_LEN) if _, err := io.ReadFull(hkdf.Expand(hash, prkKey, keyInfo), ikm); err != nil { return nil, fmt.Errorf("derive IKM failed: %v", err) } return ikm, nil } func (s Aes128gcmScheme) DeriveCEKAndNonce(hash func() hash.Hash, ikm, salt []byte, uaKey, asKey *ecdh.PublicKey) (cek, nonce []byte, err error) { prk := hkdf.Extract(hash, ikm, salt) // aes128gcm: simple info strings without keys cekInfo := []byte("Content-Encoding: aes128gcm\000") nonceInfo := []byte("Content-Encoding: nonce\000") cek = make([]byte, HKDF_CEK_LEN) if _, err := io.ReadFull(hkdf.Expand(hash, prk, cekInfo), cek); err != nil { return nil, nil, fmt.Errorf("derive CEK failed: %v", err) } nonce = make([]byte, HKDF_NONCE_LEN) if _, err := io.ReadFull(hkdf.Expand(hash, prk, nonceInfo), nonce); err != nil { return nil, nil, fmt.Errorf("derive nonce failed: %v", err) } return cek, nonce, nil } func (s Aes128gcmScheme) Pad(plaintext []byte) []byte { // aes128gcm: append 0x02 delimiter for final record return append(plaintext, 0x02) } func (s Aes128gcmScheme) Unpad(data []byte) ([]byte, error) { if len(data) == 0 { return nil, fmt.Errorf("data is empty") } // aes128gcm: remove trailing 0x01 or 0x02 delimiter last := data[len(data)-1] if last == 0x01 || last == 0x02 { return data[:len(data)-1], nil } return data, nil } // Payload represents the aes128gcm message format with embedded crypto parameters. type Payload struct { RS uint32 Salt []byte KeyId []byte CipherText []byte } const ( baseHeaderLen = 21 ) // Marshal serializes a Payload into the aes128gcm binary format. func Marshal(p Payload) (data []byte) { rs := make([]byte, 4) binary.BigEndian.PutUint32(rs, p.RS) return bytes.Join([][]byte{ p.Salt, rs, {uint8(len(p.KeyId))}, p.KeyId, p.CipherText, }, nil) } // Unmarshal parses the aes128gcm binary format into a Payload. func Unmarshal(data []byte) (p Payload, err error) { if len(data) < baseHeaderLen { return p, errors.New("data is too short") } p.Salt = data[:16] p.RS = binary.BigEndian.Uint32(data[16:20]) idlen := int(data[20]) if len(data) < baseHeaderLen+idlen { return p, errors.New("data is too short") } if idlen > 0 { p.KeyId = data[baseHeaderLen : baseHeaderLen+idlen] } p.CipherText = data[baseHeaderLen+idlen:] return p, nil } // EncryptAes128gcm encrypts a message using the aes128gcm encoding scheme. // Returns the complete payload with embedded crypto parameters. func (c *RFC8291) EncryptAes128gcm( plaintext []byte, salt []byte, authSecret []byte, receiverPublicKey *ecdh.PublicKey, senderPrivateKey *ecdh.PrivateKey, ) ([]byte, error) { ciphertext, err := c.encrypt(Aes128gcmScheme{}, plaintext, salt, authSecret, receiverPublicKey, senderPrivateKey) if err != nil { return nil, err } rs := uint32(len(plaintext) + 1 + AES_GCM_OVERHEAD) return Marshal(Payload{ RS: rs, Salt: salt, KeyId: senderPrivateKey.PublicKey().Bytes(), CipherText: ciphertext, }), nil } // DecryptAes128gcm decrypts a message encrypted with the aes128gcm encoding scheme. func (c *RFC8291) DecryptAes128gcm( ciphertext []byte, salt []byte, authSecret []byte, receiverPrivateKey *ecdh.PrivateKey, senderPublicKey *ecdh.PublicKey, ) ([]byte, error) { return c.decrypt(Aes128gcmScheme{}, ciphertext, salt, authSecret, receiverPrivateKey, senderPublicKey) }