webpush-client-go/rfc8291/rfc8291.go

210 lines
4.9 KiB
Go

package rfc8291
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/rand"
"crypto/sha256"
"fmt"
"hash"
"io"
"log"
"golang.org/x/crypto/hkdf"
)
// Sizes, in bytes, of the secrets, keys and nonces handled by this package.
const (
// AUTH_SECRET_LEN is the auth secret length required by Encrypt and Decrypt
// and the length of the auth secret produced by NewSecrets.
AUTH_SECRET_LEN = 16
// SALT_LEN is the salt length required by Encrypt and Decrypt and the
// length of the salt produced by NewSecrets.
SALT_LEN = 16
// AES_GCM_OVERHEAD is added to the record size (RS) computed in Encrypt;
// presumably the length of the AES-GCM authentication tag.
AES_GCM_OVERHEAD = 16
// HKDF_IKM_LEN is the number of bytes of input keying material read from
// the HKDF expansion in ikm.
HKDF_IKM_LEN = 32
// HKDF_CEK_LEN is the length of the content-encryption key derived in
// cekAndNonce and handed to aes.NewCipher.
HKDF_CEK_LEN = 16
// HKDF_NONCE_LEN is the length of the nonce derived in cekAndNonce.
HKDF_NONCE_LEN = 12
)
// RFC8291 carries the hash constructor that the methods of this type pass
// to HKDF when deriving keys.
type RFC8291 struct {
	hash func() hash.Hash
}

// NewRFC8291 returns an RFC8291 that derives keys with the given hash
// constructor. Passing nil selects SHA-256.
func NewRFC8291(hash func() hash.Hash) *RFC8291 {
	if hash != nil {
		return &RFC8291{hash: hash}
	}
	return &RFC8291{hash: sha256.New}
}
// NewSecrets creates fresh key material: a random auth secret of
// AUTH_SECRET_LEN bytes, a random salt of SALT_LEN bytes, and a new ECDH
// private key on curve. All randomness comes from crypto/rand.
//
// It does not return an error; any failure to read random bytes or to
// generate the key is reported through log.Panicln.
func NewSecrets(curve ecdh.Curve) (auth, salt []byte, key *ecdh.PrivateKey) {
	// randomBytes returns n bytes read from crypto/rand, panicking on failure.
	randomBytes := func(n int) []byte {
		buf := make([]byte, n)
		if _, readErr := io.ReadFull(rand.Reader, buf); readErr != nil {
			log.Panicln("failed to generate random secret", readErr)
		}
		return buf
	}

	auth = randomBytes(AUTH_SECRET_LEN)
	salt = randomBytes(SALT_LEN)

	var keyErr error
	if key, keyErr = curve.GenerateKey(rand.Reader); keyErr != nil {
		log.Panicln("failed to generate ecdh key", keyErr)
	}
	return auth, salt, key
}
// Encrypt encrypts plaintext as a single aes128gcm record and returns the
// result of Marshal applied to a Payload holding the record size, salt,
// sender public key and ciphertext.
//
// Parameters:
//   - plaintext: the message to protect; the slice is not modified.
//   - salt: must be exactly SALT_LEN bytes.
//   - authSecret: must be exactly AUTH_SECRET_LEN bytes.
//   - useragentPublicKey: ECDH public key of the receiver.
//   - appserverPrivateKey: ECDH private key of the sender; its public half
//     is stored in the payload's KeyId field.
//
// An error is returned when a length check fails, when either key is nil,
// when the ECDH exchange fails, or when key derivation or cipher setup fails.
func (c *RFC8291) Encrypt(
	plaintext []byte,
	salt []byte,
	authSecret []byte,
	useragentPublicKey *ecdh.PublicKey,
	appserverPrivateKey *ecdh.PrivateKey,
) ([]byte, error) {
	if len(authSecret) != AUTH_SECRET_LEN {
		return nil, fmt.Errorf("auth_secret must be %d bytes", AUTH_SECRET_LEN)
	}
	if len(salt) != SALT_LEN {
		return nil, fmt.Errorf("salt must be %d bytes", SALT_LEN)
	}
	// Report missing keys as an error instead of a nil-pointer panic.
	if useragentPublicKey == nil || appserverPrivateKey == nil {
		return nil, fmt.Errorf("useragent public key and appserver private key must not be nil")
	}

	ecdhSecret, err := appserverPrivateKey.ECDH(useragentPublicKey)
	if err != nil {
		return nil, fmt.Errorf("calculate ecdh_secret failed: %w", err)
	}
	appserverPublicKey := appserverPrivateKey.PublicKey()
	ikm, err := c.ikm(authSecret, ecdhSecret, useragentPublicKey, appserverPublicKey)
	if err != nil {
		return nil, err
	}
	cek, nonce, err := c.cekAndNonce(ikm, salt)
	if err != nil {
		return nil, err
	}
	gcm, err := c.gcm(cek)
	if err != nil {
		return nil, err
	}

	// RFC 8188 requires every record to end with a padding delimiter, and
	// RFC 8291 sends the message as one final record whose delimiter is
	// 0x02. The delimiter is part of the encrypted data, so it is appended
	// to a copy of the plaintext (never to the caller's backing array)
	// before sealing.
	record := make([]byte, len(plaintext)+1)
	copy(record, plaintext)
	record[len(plaintext)] = 0x02

	ciphertext := gcm.Seal(nil, nonce, record, nil)
	// Record size: plaintext, one delimiter octet, and the GCM overhead.
	rs := uint32(len(record) + AES_GCM_OVERHEAD)

	return Marshal(Payload{
		RS:         rs,
		Salt:       salt,
		KeyId:      appserverPublicKey.Bytes(),
		CipherText: ciphertext,
	}), nil
}
// Decrypt authenticates and decrypts a single aes128gcm record and returns
// the message with its RFC 8188 padding removed.
//
// Parameters:
//   - ciphertext: passed unchanged to AES-GCM Open, so it must be the sealed
//     record only (presumably the body that follows the payload header —
//     confirm against the caller).
//   - salt: must be exactly SALT_LEN bytes.
//   - authSecret: must be exactly AUTH_SECRET_LEN bytes.
//   - useragentPrivateKey: ECDH private key of the receiver.
//   - appserverPublicKey: ECDH public key of the sender.
//
// An error is returned when a length check fails, when either key is nil,
// when the ECDH exchange or key derivation fails, or when authentication of
// the ciphertext fails.
//
// Padding handling: trailing zero octets followed by a delimiter octet
// (0x01 or 0x02) are stripped. If no delimiter is found the decrypted data
// is returned unchanged; an empty record yields an empty result.
func (c *RFC8291) Decrypt(
	ciphertext []byte,
	salt []byte,
	authSecret []byte,
	useragentPrivateKey *ecdh.PrivateKey,
	appserverPublicKey *ecdh.PublicKey,
) ([]byte, error) {
	if len(authSecret) != AUTH_SECRET_LEN {
		return nil, fmt.Errorf("auth_secret must be %d bytes", AUTH_SECRET_LEN)
	}
	if len(salt) != SALT_LEN {
		return nil, fmt.Errorf("salt must be %d bytes", SALT_LEN)
	}
	// Report missing keys as an error instead of a nil-pointer panic.
	if useragentPrivateKey == nil || appserverPublicKey == nil {
		return nil, fmt.Errorf("useragent private key and appserver public key must not be nil")
	}

	ecdhSecret, err := useragentPrivateKey.ECDH(appserverPublicKey)
	if err != nil {
		return nil, fmt.Errorf("calculate ecdh_secret failed: %w", err)
	}
	ikm, err := c.ikm(authSecret, ecdhSecret, useragentPrivateKey.PublicKey(), appserverPublicKey)
	if err != nil {
		return nil, err
	}
	cek, nonce, err := c.cekAndNonce(ikm, salt)
	if err != nil {
		return nil, err
	}
	gcm, err := c.gcm(cek)
	if err != nil {
		return nil, err
	}
	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return nil, err
	}

	// RFC 8188 record layout: data || delimiter (0x01 or 0x02) || zero
	// padding. Trim the zero padding first, then drop the delimiter. The
	// length guard also covers an empty record, which would otherwise
	// index out of range.
	unpadded := bytes.TrimRight(plaintext, "\x00")
	if n := len(unpadded); n > 0 && (unpadded[n-1] == 0x01 || unpadded[n-1] == 0x02) {
		return unpadded[:n-1], nil
	}
	// No delimiter present: stay lenient and hand back the data untouched.
	return plaintext, nil
}
// gcm returns an AES-GCM AEAD keyed with cek. It fails when cek is not a
// valid AES key or when the GCM wrapper cannot be constructed.
func (c *RFC8291) gcm(cek []byte) (cipher.AEAD, error) {
	blockCipher, blockErr := aes.NewCipher(cek)
	if blockErr != nil {
		return nil, fmt.Errorf("create cipher block failed: %v", blockErr)
	}

	aead, aeadErr := cipher.NewGCM(blockCipher)
	if aeadErr != nil {
		return nil, fmt.Errorf("create GCM failed: %v", aeadErr)
	}

	return aead, nil
}
// ikm derives HKDF_IKM_LEN bytes of input keying material from the ECDH
// shared secret. HKDF-Extract is run with ecdhSecret as the secret and
// authSecret as the salt; HKDF-Expand then uses an info string made of the
// label "WebPush: info\0" followed by the user agent's and the application
// server's public keys, in that order.
func (c *RFC8291) ikm(
	authSecret []byte,
	ecdhSecret []byte,
	useragentPublicKey *ecdh.PublicKey,
	appserverPublicKey *ecdh.PublicKey,
) (ikm []byte, err error) {
	pseudoRandomKey := hkdf.Extract(c.hash, ecdhSecret, authSecret)

	// Assemble label || ua_public || as_public.
	var info bytes.Buffer
	info.WriteString("WebPush: info\000")
	info.Write(useragentPublicKey.Bytes())
	info.Write(appserverPublicKey.Bytes())

	expander := hkdf.Expand(c.hash, pseudoRandomKey, info.Bytes())
	material := make([]byte, HKDF_IKM_LEN)
	if _, readErr := io.ReadFull(expander, material); readErr != nil {
		return nil, fmt.Errorf("read IKM failed: %v", readErr)
	}
	return material, nil
}
// cekAndNonce derives the content-encryption key (HKDF_CEK_LEN bytes) and
// the nonce (HKDF_NONCE_LEN bytes) from ikm and salt. Both values are
// expanded from one HKDF-Extract output, each with its own
// "Content-Encoding: ...\0" info string.
func (c *RFC8291) cekAndNonce(ikm []byte, salt []byte) (cek, nonce []byte, err error) {
	pseudoRandomKey := hkdf.Extract(c.hash, ikm, salt)

	// expand reads size bytes of HKDF output for the given info label.
	expand := func(info string, size int) ([]byte, error) {
		out := make([]byte, size)
		_, readErr := io.ReadFull(hkdf.Expand(c.hash, pseudoRandomKey, []byte(info)), out)
		return out, readErr
	}

	if cek, err = expand("Content-Encoding: aes128gcm\000", HKDF_CEK_LEN); err != nil {
		return nil, nil, fmt.Errorf("read CEK failed: %v", err)
	}
	if nonce, err = expand("Content-Encoding: nonce\000", HKDF_NONCE_LEN); err != nil {
		return nil, nil, fmt.Errorf("read Nonce failed: %v", err)
	}
	return cek, nonce, nil
}