225 lines
5.7 KiB
Go
225 lines
5.7 KiB
Go
package rfc8291
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/ecdh"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"hash"
|
|
"io"
|
|
"log"
|
|
)
|
|
|
|
// Required byte lengths of the random inputs validated by encrypt/decrypt.
const (
	// AUTH_SECRET_LEN is the required length, in bytes, of the auth secret.
	AUTH_SECRET_LEN = 16
	// SALT_LEN is the required length, in bytes, of the salt.
	SALT_LEN = 16
)

// Sizes related to AES-GCM content encryption and its HKDF key schedule.
const (
	// AES_GCM_OVERHEAD is the number of bytes AES-GCM adds to a sealed
	// record (the size of its authentication tag).
	AES_GCM_OVERHEAD = 16

	// HKDF_IKM_LEN is the length, in bytes, of the derived input keying
	// material.
	HKDF_IKM_LEN = 32
	// HKDF_CEK_LEN is the length, in bytes, of the derived content
	// encryption key; a 16-byte key selects AES-128.
	HKDF_CEK_LEN = 16
	// HKDF_NONCE_LEN is the length, in bytes, of the derived AES-GCM nonce.
	HKDF_NONCE_LEN = 12
)
|
|
|
|
// Encoding names the HTTP Content-Encoding used for a WebPush message body.
// The values accepted by this package are declared below.
type Encoding string

// Content encodings understood by Scheme and Decrypt.
const (
	// EncodingAes128gcm is the "aes128gcm" encoding, whose salt and sender
	// key travel inside the message payload.
	EncodingAes128gcm Encoding = "aes128gcm"
	// EncodingAesgcm is the "aesgcm" encoding, whose salt and sender key
	// travel in the Encryption and Crypto-Key HTTP headers.
	EncodingAesgcm Encoding = "aesgcm"
)
|
|
|
|
// EncodingScheme defines the encoding-specific operations for WebPush encryption.
|
|
type EncodingScheme interface {
|
|
DeriveIKM(hash func() hash.Hash, authSecret, ecdhSecret []byte, uaKey, asKey *ecdh.PublicKey) ([]byte, error)
|
|
DeriveCEKAndNonce(hash func() hash.Hash, ikm, salt []byte, uaKey, asKey *ecdh.PublicKey) (cek, nonce []byte, err error)
|
|
Pad(plaintext []byte) []byte
|
|
Unpad(data []byte) ([]byte, error)
|
|
}
|
|
|
|
// Scheme returns the EncodingScheme implementation for the given encoding type.
|
|
func Scheme(encoding Encoding) (EncodingScheme, error) {
|
|
switch encoding {
|
|
case EncodingAes128gcm:
|
|
return Aes128gcmScheme{}, nil
|
|
case EncodingAesgcm:
|
|
return AesgcmScheme{}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported encoding: %s", encoding)
|
|
}
|
|
}
|
|
|
|
// RFC8291 implements WebPush message encryption and decryption.
|
|
type RFC8291 struct {
|
|
hash func() hash.Hash
|
|
}
|
|
|
|
// NewRFC8291 creates a new RFC8291 instance. Default hash is SHA256.
|
|
func NewRFC8291(hash func() hash.Hash) *RFC8291 {
|
|
if hash == nil {
|
|
hash = sha256.New
|
|
}
|
|
return &RFC8291{hash: hash}
|
|
}
|
|
|
|
// NewSecrets returns a fresh auth secret (AUTH_SECRET_LEN bytes), a fresh salt
// (SALT_LEN bytes), and a newly generated ECDH private key on curve. All
// randomness comes from crypto/rand.
//
// It panics, via log.Panicln, if reading random bytes or generating the key
// fails.
func NewSecrets(curve ecdh.Curve) (auth, salt []byte, key *ecdh.PrivateKey) {
	randomBytes := func(n int) []byte {
		buf := make([]byte, n)
		if _, err := io.ReadFull(rand.Reader, buf); err != nil {
			log.Panicln("failed to generate random secret", err)
		}
		return buf
	}
	auth = randomBytes(AUTH_SECRET_LEN)
	salt = randomBytes(SALT_LEN)

	generated, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		log.Panicln("failed to generate ecdh key", err)
	}
	return auth, salt, generated
}
|
|
|
|
// encrypt seals plaintext for a receiver using the given encoding scheme.
//
// Steps:
//   - Derive the ECDH secret between senderPrivateKey and receiverPublicKey.
//   - Expand it, with authSecret and salt, into a content-encryption key and
//     nonce via scheme. The receiver's key is uaKey; the sender's is asKey.
//   - Pad the plaintext with scheme.Pad and seal it with AES-GCM.
//
// The result is only the GCM output (ciphertext plus tag). The salt and the
// sender's public key are not included and must be conveyed by the caller.
//
// authSecret must be AUTH_SECRET_LEN bytes and salt SALT_LEN bytes. A nil
// scheme, nil keys, or a scheme-derived nonce of the wrong size are reported
// as errors instead of panicking. ECDH failures are wrapped with %w.
func (c *RFC8291) encrypt(
	scheme EncodingScheme,
	plaintext []byte,
	salt []byte,
	authSecret []byte,
	receiverPublicKey *ecdh.PublicKey,
	senderPrivateKey *ecdh.PrivateKey,
) ([]byte, error) {
	if len(authSecret) != AUTH_SECRET_LEN {
		return nil, fmt.Errorf("auth_secret must be %d bytes", AUTH_SECRET_LEN)
	}
	if len(salt) != SALT_LEN {
		return nil, fmt.Errorf("salt must be %d bytes", SALT_LEN)
	}
	if scheme == nil {
		return nil, fmt.Errorf("encoding scheme must not be nil")
	}
	// ECDH and PublicKey dereference their receivers/arguments and would
	// panic on nil.
	if receiverPublicKey == nil || senderPrivateKey == nil {
		return nil, fmt.Errorf("receiver public key and sender private key must not be nil")
	}

	ecdhSecret, err := senderPrivateKey.ECDH(receiverPublicKey)
	if err != nil {
		return nil, fmt.Errorf("calculate ecdh_secret failed: %w", err)
	}

	senderPublicKey := senderPrivateKey.PublicKey()
	ikm, err := scheme.DeriveIKM(c.hash, authSecret, ecdhSecret, receiverPublicKey, senderPublicKey)
	if err != nil {
		return nil, err
	}
	cek, nonce, err := scheme.DeriveCEKAndNonce(c.hash, ikm, salt, receiverPublicKey, senderPublicKey)
	if err != nil {
		return nil, err
	}

	aead, err := c.gcm(cek)
	if err != nil {
		return nil, err
	}
	// Seal panics on a nonce of the wrong length; schemes are pluggable, so
	// turn that into an error.
	if len(nonce) != aead.NonceSize() {
		return nil, fmt.Errorf("nonce must be %d bytes, got %d", aead.NonceSize(), len(nonce))
	}

	return aead.Seal(nil, nonce, scheme.Pad(plaintext), nil), nil
}
|
|
|
|
// decrypt opens a ciphertext using the given encoding scheme.
//
// Steps:
//   - Derive the ECDH secret between receiverPrivateKey and senderPublicKey.
//   - Derive the content-encryption key and nonce via scheme. The receiver's
//     public key is uaKey; senderPublicKey is asKey.
//   - Authenticate and decrypt the data with AES-GCM.
//   - Strip the padding with scheme.Unpad.
//
// authSecret must be AUTH_SECRET_LEN bytes and salt SALT_LEN bytes. A nil
// scheme, nil keys, or a scheme-derived nonce of the wrong size are reported
// as errors instead of panicking. ECDH failures are wrapped with %w. GCM
// authentication failures are returned unchanged.
func (c *RFC8291) decrypt(
	scheme EncodingScheme,
	ciphertext []byte,
	salt []byte,
	authSecret []byte,
	receiverPrivateKey *ecdh.PrivateKey,
	senderPublicKey *ecdh.PublicKey,
) ([]byte, error) {
	if len(authSecret) != AUTH_SECRET_LEN {
		return nil, fmt.Errorf("auth_secret must be %d bytes", AUTH_SECRET_LEN)
	}
	if len(salt) != SALT_LEN {
		return nil, fmt.Errorf("salt must be %d bytes", SALT_LEN)
	}
	if scheme == nil {
		return nil, fmt.Errorf("encoding scheme must not be nil")
	}
	// ECDH and PublicKey dereference their receivers/arguments and would
	// panic on nil.
	if receiverPrivateKey == nil || senderPublicKey == nil {
		return nil, fmt.Errorf("receiver private key and sender public key must not be nil")
	}

	ecdhSecret, err := receiverPrivateKey.ECDH(senderPublicKey)
	if err != nil {
		return nil, fmt.Errorf("calculate ecdh_secret failed: %w", err)
	}

	receiverPublicKey := receiverPrivateKey.PublicKey()
	ikm, err := scheme.DeriveIKM(c.hash, authSecret, ecdhSecret, receiverPublicKey, senderPublicKey)
	if err != nil {
		return nil, err
	}
	cek, nonce, err := scheme.DeriveCEKAndNonce(c.hash, ikm, salt, receiverPublicKey, senderPublicKey)
	if err != nil {
		return nil, err
	}

	aead, err := c.gcm(cek)
	if err != nil {
		return nil, err
	}
	// Open panics on a nonce of the wrong length; schemes are pluggable, so
	// turn that into an error.
	if len(nonce) != aead.NonceSize() {
		return nil, fmt.Errorf("nonce must be %d bytes, got %d", aead.NonceSize(), len(nonce))
	}

	// Open verifies the GCM tag before releasing any plaintext, so tampered
	// or mis-keyed input fails here rather than in Unpad.
	plaintext, err := aead.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return nil, err
	}

	return scheme.Unpad(plaintext)
}
|
|
|
|
// gcm builds an AES-GCM AEAD keyed with cek.
//
// cek must be a valid AES key length (16, 24 or 32 bytes); a 16-byte key
// selects AES-128. Construction errors are wrapped with %w, so a caller can
// still detect, for example, an aes.KeySizeError with errors.As.
func (c *RFC8291) gcm(cek []byte) (cipher.AEAD, error) {
	block, err := aes.NewCipher(cek)
	if err != nil {
		return nil, fmt.Errorf("create cipher block failed: %w", err)
	}

	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("create GCM failed: %w", err)
	}

	return aead, nil
}
|
|
|
|
// Decrypt decrypts a push message body. The encoding scheme is chosen by the
// encoding argument.
//
// EncodingAes128gcm: data is unmarshalled into salt, sender public key
// (key id) and ciphertext; encryptionHeader and cryptoKeyHeader are ignored.
//
// EncodingAesgcm: the salt and sender public key are parsed from
// encryptionHeader and cryptoKeyHeader, and data is passed through to
// DecryptAesgcm as the message body.
//
// The sender public key is interpreted on receiverPrivateKey's curve. A nil
// receiverPrivateKey is reported as an error instead of panicking. Parse
// failures are wrapped with %w so their cause remains inspectable.
func (c *RFC8291) Decrypt(
	data []byte,
	encoding Encoding,
	encryptionHeader string,
	cryptoKeyHeader string,
	authSecret []byte,
	receiverPrivateKey *ecdh.PrivateKey,
) ([]byte, error) {
	if receiverPrivateKey == nil {
		return nil, fmt.Errorf("receiver private key must not be nil")
	}
	curve := receiverPrivateKey.Curve()

	switch encoding {
	case EncodingAes128gcm:
		payload, err := Unmarshal(data)
		if err != nil {
			return nil, fmt.Errorf("unmarshal aes128gcm payload: %w", err)
		}

		senderPublicKey, err := curve.NewPublicKey(payload.KeyId)
		if err != nil {
			return nil, fmt.Errorf("parse sender public key: %w", err)
		}

		return c.DecryptAes128gcm(payload.CipherText, payload.Salt, authSecret, receiverPrivateKey, senderPublicKey)

	case EncodingAesgcm:
		params, err := ParseAesgcmHeaders(encryptionHeader, cryptoKeyHeader, curve)
		if err != nil {
			return nil, fmt.Errorf("parse aesgcm headers: %w", err)
		}

		return c.DecryptAesgcm(data, params.Salt, authSecret, receiverPrivateKey, params.SenderPublicKey)

	default:
		return nil, fmt.Errorf("unsupported encoding: %s", encoding)
	}
}
|