package rfc8291 import ( "crypto/aes" "crypto/cipher" "crypto/ecdh" "crypto/rand" "crypto/sha256" "fmt" "hash" "io" "log" ) const ( AUTH_SECRET_LEN = 16 SALT_LEN = 16 AES_GCM_OVERHEAD = 16 HKDF_IKM_LEN = 32 HKDF_CEK_LEN = 16 HKDF_NONCE_LEN = 12 ) // Encoding represents the Content-Encoding type for WebPush messages. type Encoding string const ( EncodingAes128gcm Encoding = "aes128gcm" EncodingAesgcm Encoding = "aesgcm" ) // EncodingScheme defines the encoding-specific operations for WebPush encryption. type EncodingScheme interface { DeriveIKM(hash func() hash.Hash, authSecret, ecdhSecret []byte, uaKey, asKey *ecdh.PublicKey) ([]byte, error) DeriveCEKAndNonce(hash func() hash.Hash, ikm, salt []byte, uaKey, asKey *ecdh.PublicKey) (cek, nonce []byte, err error) Pad(plaintext []byte) []byte Unpad(data []byte) ([]byte, error) } // Scheme returns the EncodingScheme implementation for the given encoding type. func Scheme(encoding Encoding) (EncodingScheme, error) { switch encoding { case EncodingAes128gcm: return Aes128gcmScheme{}, nil case EncodingAesgcm: return AesgcmScheme{}, nil default: return nil, fmt.Errorf("unsupported encoding: %s", encoding) } } // RFC8291 implements WebPush message encryption and decryption. type RFC8291 struct { hash func() hash.Hash } // NewRFC8291 creates a new RFC8291 instance. Default hash is SHA256. func NewRFC8291(hash func() hash.Hash) *RFC8291 { if hash == nil { hash = sha256.New } return &RFC8291{hash: hash} } // NewSecrets generates new random auth secret, salt, and ECDH private key. func NewSecrets(curve ecdh.Curve) (auth, salt []byte, key *ecdh.PrivateKey) { auth = make([]byte, AUTH_SECRET_LEN) salt = make([]byte, SALT_LEN) for _, b := range [][]byte{auth, salt} { _, err := io.ReadFull(rand.Reader, b) if err != nil { log.Panicln("failed to generate random secret", err) } } key, err := curve.GenerateKey(rand.Reader) if err != nil { log.Panicln("failed to generate ecdh key", err) } return auth, salt, key } // encrypt performs encryption using the specified encoding scheme. func (c *RFC8291) encrypt( scheme EncodingScheme, plaintext []byte, salt []byte, authSecret []byte, receiverPublicKey *ecdh.PublicKey, senderPrivateKey *ecdh.PrivateKey, ) ([]byte, error) { if len(authSecret) != AUTH_SECRET_LEN { return nil, fmt.Errorf("auth_secret must be %d bytes", AUTH_SECRET_LEN) } if len(salt) != SALT_LEN { return nil, fmt.Errorf("salt must be %d bytes", SALT_LEN) } ecdhSecret, err := senderPrivateKey.ECDH(receiverPublicKey) if err != nil { return nil, fmt.Errorf("calculate ecdh_secret failed: %v", err) } ikm, err := scheme.DeriveIKM(c.hash, authSecret, ecdhSecret, receiverPublicKey, senderPrivateKey.PublicKey()) if err != nil { return nil, err } cek, nonce, err := scheme.DeriveCEKAndNonce(c.hash, ikm, salt, receiverPublicKey, senderPrivateKey.PublicKey()) if err != nil { return nil, err } gcm, err := c.gcm(cek) if err != nil { return nil, err } paddedPlaintext := scheme.Pad(plaintext) ciphertext := gcm.Seal(nil, nonce, paddedPlaintext, nil) return ciphertext, nil } // decrypt performs decryption using the specified encoding scheme. func (c *RFC8291) decrypt( scheme EncodingScheme, ciphertext []byte, salt []byte, authSecret []byte, receiverPrivateKey *ecdh.PrivateKey, senderPublicKey *ecdh.PublicKey, ) ([]byte, error) { if len(authSecret) != AUTH_SECRET_LEN { return nil, fmt.Errorf("auth_secret must be %d bytes", AUTH_SECRET_LEN) } if len(salt) != SALT_LEN { return nil, fmt.Errorf("salt must be %d bytes", SALT_LEN) } ecdhSecret, err := receiverPrivateKey.ECDH(senderPublicKey) if err != nil { return nil, fmt.Errorf("calculate ecdh_secret failed: %v", err) } ikm, err := scheme.DeriveIKM(c.hash, authSecret, ecdhSecret, receiverPrivateKey.PublicKey(), senderPublicKey) if err != nil { return nil, err } cek, nonce, err := scheme.DeriveCEKAndNonce(c.hash, ikm, salt, receiverPrivateKey.PublicKey(), senderPublicKey) if err != nil { return nil, err } gcm, err := c.gcm(cek) if err != nil { return nil, err } plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, err } return scheme.Unpad(plaintext) } func (c *RFC8291) gcm(cek []byte) (cipher.AEAD, error) { block, err := aes.NewCipher(cek) if err != nil { return nil, fmt.Errorf("create cipher block failed: %v", err) } gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("create GCM failed: %v", err) } return gcm, nil } // Decrypt decrypts a push notification, automatically selecting the correct // encoding scheme based on the encoding parameter. // // For aes128gcm: crypto params are extracted from the data payload. // For aesgcm: crypto params are extracted from the HTTP headers. func (c *RFC8291) Decrypt( data []byte, encoding Encoding, encryptionHeader string, cryptoKeyHeader string, authSecret []byte, receiverPrivateKey *ecdh.PrivateKey, ) ([]byte, error) { switch encoding { case EncodingAes128gcm: payload, err := Unmarshal(data) if err != nil { return nil, fmt.Errorf("unmarshal aes128gcm payload: %v", err) } senderPublicKey, err := receiverPrivateKey.Curve().NewPublicKey(payload.KeyId) if err != nil { return nil, fmt.Errorf("parse sender public key: %v", err) } return c.DecryptAes128gcm(payload.CipherText, payload.Salt, authSecret, receiverPrivateKey, senderPublicKey) case EncodingAesgcm: params, err := ParseAesgcmHeaders(encryptionHeader, cryptoKeyHeader, receiverPrivateKey.Curve()) if err != nil { return nil, fmt.Errorf("parse aesgcm headers: %v", err) } return c.DecryptAesgcm(data, params.Salt, authSecret, receiverPrivateKey, params.SenderPublicKey) default: return nil, fmt.Errorf("unsupported encoding: %s", encoding) } }