webpush-client-go/rfc8291/rfc8291.go

225 lines
5.7 KiB
Go

package rfc8291
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/rand"
"crypto/sha256"
"fmt"
"hash"
"io"
"log"
)
// Byte lengths of the inputs that encrypt and decrypt validate before any
// key derivation takes place.
const (
	AUTH_SECRET_LEN = 16 // subscription auth secret
	SALT_LEN        = 16 // per-message salt
)

// AES_GCM_OVERHEAD is the byte size of the authentication tag that AES-GCM
// (cipher.NewGCM's default) appends to sealed data.
const AES_GCM_OVERHEAD = 16

// Output lengths for the HKDF-derived values: input keying material,
// content-encryption key and GCM nonce.
// NOTE(review): not referenced in this part of the file; presumably used by
// the EncodingScheme implementations.
const (
	HKDF_IKM_LEN   = 32
	HKDF_CEK_LEN   = 16
	HKDF_NONCE_LEN = 12
)
// Encoding names a Web Push Content-Encoding value.
type Encoding string

// Content encodings recognised by Scheme and Decrypt.
const (
	// EncodingAes128gcm is "aes128gcm": the crypto parameters travel inside
	// the message payload itself.
	EncodingAes128gcm = Encoding("aes128gcm")
	// EncodingAesgcm is "aesgcm": the crypto parameters travel in the
	// Encryption and Crypto-Key HTTP headers.
	EncodingAesgcm = Encoding("aesgcm")
)
// EncodingScheme bundles the steps that differ between Web Push content
// encodings. encrypt and decrypt drive an implementation as follows:
//
//	ikm := DeriveIKM(...)
//	cek, nonce := DeriveCEKAndNonce(...)
//	ciphertext := seal(Pad(plaintext))   // encrypt
//	plaintext := Unpad(open(ciphertext)) // decrypt
type EncodingScheme interface {
	// DeriveIKM combines the auth secret and the raw ECDH shared secret into
	// input keying material. uaKey is the receiver's (user agent) public key
	// and asKey is the sender's (application server) public key.
	DeriveIKM(newHash func() hash.Hash, authSecret, ecdhSecret []byte, uaKey, asKey *ecdh.PublicKey) ([]byte, error)

	// DeriveCEKAndNonce derives the AES-GCM content-encryption key and nonce
	// from the input keying material and the message salt.
	DeriveCEKAndNonce(newHash func() hash.Hash, ikm, salt []byte, uaKey, asKey *ecdh.PublicKey) (cek, nonce []byte, err error)

	// Pad applies the scheme's padding to plaintext before it is sealed.
	Pad(plaintext []byte) []byte

	// Unpad removes the scheme's padding from decrypted data. It may report
	// an error, e.g. for malformed padding.
	Unpad(data []byte) ([]byte, error)
}
// Scheme looks up the EncodingScheme that implements encoding. Encodings
// other than EncodingAes128gcm and EncodingAesgcm yield an error.
func Scheme(encoding Encoding) (EncodingScheme, error) {
	var impl EncodingScheme
	switch encoding {
	case EncodingAesgcm:
		impl = AesgcmScheme{}
	case EncodingAes128gcm:
		impl = Aes128gcmScheme{}
	}
	if impl == nil {
		return nil, fmt.Errorf("unsupported encoding: %s", encoding)
	}
	return impl, nil
}
// RFC8291 encrypts and decrypts Web Push message payloads. The zero value
// carries a nil hash function; use NewRFC8291 to get the SHA-256 default.
type RFC8291 struct {
	// hash constructs the hash handed to the scheme's key-derivation steps.
	hash func() hash.Hash
}

// NewRFC8291 returns an RFC8291 whose key derivation uses hash.
// Passing nil selects sha256.New.
func NewRFC8291(hash func() hash.Hash) *RFC8291 {
	c := new(RFC8291)
	c.hash = hash
	if c.hash == nil {
		c.hash = sha256.New
	}
	return c
}
// NewSecrets returns fresh key material: a random AUTH_SECRET_LEN-byte auth
// secret, a random SALT_LEN-byte salt (read in that order from crypto/rand),
// and a new ECDH private key on curve. Failures of the random source or of key
// generation are fatal and reported through log.Panicln.
func NewSecrets(curve ecdh.Curve) (auth, salt []byte, key *ecdh.PrivateKey) {
	randomBytes := func(n int) []byte {
		buf := make([]byte, n)
		if _, err := io.ReadFull(rand.Reader, buf); err != nil {
			log.Panicln("failed to generate random secret", err)
		}
		return buf
	}
	auth = randomBytes(AUTH_SECRET_LEN)
	salt = randomBytes(SALT_LEN)

	priv, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		log.Panicln("failed to generate ecdh key", err)
	}
	return auth, salt, priv
}
// encrypt seals plaintext for the owner of receiverPublicKey using scheme.
//
// It works in these steps:
//   - validate the auth secret and salt lengths;
//   - compute the ECDH shared secret between senderPrivateKey and
//     receiverPublicKey;
//   - derive the IKM, then the CEK and nonce, through scheme, passing the
//     receiver key as the user-agent key and the sender's public key as the
//     application-server key;
//   - pad the plaintext with scheme.Pad and seal it with AES-GCM.
//
// The result is the GCM output only, with no salt or key-id header. Any
// framing is presumably added by the caller.
//
// The CEK and nonce are computed from the keys, auth secret and salt. Reusing
// a salt with the same key pair and auth secret would therefore reuse a GCM
// nonce. NOTE(review): callers are assumed to pass a fresh random salt per
// message.
//
// Errors cover:
//   - a wrong auth secret or salt length;
//   - nil keys;
//   - ECDH failure, wrapped with %w;
//   - scheme derivation errors, returned unchanged;
//   - a scheme-produced nonce whose length does not match the GCM nonce size.
func (c *RFC8291) encrypt(
	scheme EncodingScheme,
	plaintext []byte,
	salt []byte,
	authSecret []byte,
	receiverPublicKey *ecdh.PublicKey,
	senderPrivateKey *ecdh.PrivateKey,
) ([]byte, error) {
	if len(authSecret) != AUTH_SECRET_LEN {
		return nil, fmt.Errorf("auth_secret must be %d bytes", AUTH_SECRET_LEN)
	}
	if len(salt) != SALT_LEN {
		return nil, fmt.Errorf("salt must be %d bytes", SALT_LEN)
	}
	if receiverPublicKey == nil || senderPrivateKey == nil {
		// ECDH would dereference a nil key and panic; report it instead.
		return nil, fmt.Errorf("receiver public key and sender private key must not be nil")
	}
	ecdhSecret, err := senderPrivateKey.ECDH(receiverPublicKey)
	if err != nil {
		return nil, fmt.Errorf("calculate ecdh_secret failed: %w", err)
	}
	senderPublicKey := senderPrivateKey.PublicKey()
	ikm, err := scheme.DeriveIKM(c.hash, authSecret, ecdhSecret, receiverPublicKey, senderPublicKey)
	if err != nil {
		return nil, err
	}
	cek, nonce, err := scheme.DeriveCEKAndNonce(c.hash, ikm, salt, receiverPublicKey, senderPublicKey)
	if err != nil {
		return nil, err
	}
	gcm, err := c.gcm(cek)
	if err != nil {
		return nil, err
	}
	if len(nonce) != gcm.NonceSize() {
		// AEAD.Seal panics on a wrong-length nonce, and EncodingScheme is an
		// exported interface, so guard against a misbehaving implementation.
		return nil, fmt.Errorf("nonce must be %d bytes, got %d", gcm.NonceSize(), len(nonce))
	}
	return gcm.Seal(nil, nonce, scheme.Pad(plaintext), nil), nil
}
// decrypt opens ciphertext addressed to receiverPrivateKey using scheme.
//
// It mirrors encrypt:
//   - validate the auth secret and salt lengths;
//   - compute the ECDH shared secret between receiverPrivateKey and
//     senderPublicKey;
//   - derive the IKM, then the CEK and nonce, through scheme, passing the
//     receiver's public key as the user-agent key and senderPublicKey as the
//     application-server key;
//   - open the AES-GCM ciphertext and strip the padding with scheme.Unpad.
//
// ciphertext must be the raw GCM output with any framing already removed.
// ECDH and AES-GCM authentication failures are wrapped with %w. A nil key, or
// a scheme-produced nonce of the wrong length, is reported as an error
// rather than causing a panic.
func (c *RFC8291) decrypt(
	scheme EncodingScheme,
	ciphertext []byte,
	salt []byte,
	authSecret []byte,
	receiverPrivateKey *ecdh.PrivateKey,
	senderPublicKey *ecdh.PublicKey,
) ([]byte, error) {
	if len(authSecret) != AUTH_SECRET_LEN {
		return nil, fmt.Errorf("auth_secret must be %d bytes", AUTH_SECRET_LEN)
	}
	if len(salt) != SALT_LEN {
		return nil, fmt.Errorf("salt must be %d bytes", SALT_LEN)
	}
	if receiverPrivateKey == nil || senderPublicKey == nil {
		// ECDH would dereference a nil key and panic; report it instead.
		return nil, fmt.Errorf("receiver private key and sender public key must not be nil")
	}
	ecdhSecret, err := receiverPrivateKey.ECDH(senderPublicKey)
	if err != nil {
		return nil, fmt.Errorf("calculate ecdh_secret failed: %w", err)
	}
	receiverPublicKey := receiverPrivateKey.PublicKey()
	ikm, err := scheme.DeriveIKM(c.hash, authSecret, ecdhSecret, receiverPublicKey, senderPublicKey)
	if err != nil {
		return nil, err
	}
	cek, nonce, err := scheme.DeriveCEKAndNonce(c.hash, ikm, salt, receiverPublicKey, senderPublicKey)
	if err != nil {
		return nil, err
	}
	gcm, err := c.gcm(cek)
	if err != nil {
		return nil, err
	}
	if len(nonce) != gcm.NonceSize() {
		// AEAD.Open panics on a wrong-length nonce; guard against a
		// misbehaving EncodingScheme implementation.
		return nil, fmt.Errorf("nonce must be %d bytes, got %d", gcm.NonceSize(), len(nonce))
	}
	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return nil, fmt.Errorf("open ciphertext: %w", err)
	}
	return scheme.Unpad(plaintext)
}
// gcm returns an AES-GCM AEAD keyed with cek. The key length selects the AES
// variant (16, 24 or 32 bytes); any other length is rejected by
// aes.NewCipher. Underlying errors are wrapped with %w so callers can inspect
// them with errors.Is / errors.As, for example aes.KeySizeError.
func (c *RFC8291) gcm(cek []byte) (cipher.AEAD, error) {
	block, err := aes.NewCipher(cek)
	if err != nil {
		return nil, fmt.Errorf("create cipher block failed: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("create GCM failed: %w", err)
	}
	return aead, nil
}
// Decrypt decrypts a push notification, choosing the scheme from encoding.
//
// For aes128gcm, data is the framed payload parsed by Unmarshal. The salt,
// the ciphertext and the sender's public key (the key id) all come from it,
// and encryptionHeader and cryptoKeyHeader are ignored.
//
// For aesgcm, data is the ciphertext itself. The salt and the sender's public
// key are parsed from encryptionHeader and cryptoKeyHeader by
// ParseAesgcmHeaders.
//
// In both cases the sender key is interpreted on receiverPrivateKey's curve.
// A nil receiverPrivateKey or an unsupported encoding yields an error. Parse
// errors are wrapped with %w.
func (c *RFC8291) Decrypt(
	data []byte,
	encoding Encoding,
	encryptionHeader string,
	cryptoKeyHeader string,
	authSecret []byte,
	receiverPrivateKey *ecdh.PrivateKey,
) ([]byte, error) {
	if receiverPrivateKey == nil {
		// Curve() below would dereference nil and panic; report it instead.
		return nil, fmt.Errorf("receiver private key must not be nil")
	}
	curve := receiverPrivateKey.Curve()
	switch encoding {
	case EncodingAes128gcm:
		payload, err := Unmarshal(data)
		if err != nil {
			return nil, fmt.Errorf("unmarshal aes128gcm payload: %w", err)
		}
		senderPublicKey, err := curve.NewPublicKey(payload.KeyId)
		if err != nil {
			return nil, fmt.Errorf("parse sender public key: %w", err)
		}
		return c.DecryptAes128gcm(payload.CipherText, payload.Salt, authSecret, receiverPrivateKey, senderPublicKey)
	case EncodingAesgcm:
		params, err := ParseAesgcmHeaders(encryptionHeader, cryptoKeyHeader, curve)
		if err != nil {
			return nil, fmt.Errorf("parse aesgcm headers: %w", err)
		}
		return c.DecryptAesgcm(data, params.Salt, authSecret, receiverPrivateKey, params.SenderPublicKey)
	default:
		return nil, fmt.Errorf("unsupported encoding: %s", encoding)
	}
}