webpush-client-go/rfc8291/aes128gcm.go

155 lines
3.9 KiB
Go

package rfc8291
import (
"bytes"
"crypto/ecdh"
"encoding/binary"
"errors"
"fmt"
"hash"
"io"
"golang.org/x/crypto/hkdf"
)
// Aes128gcmScheme implements EncodingScheme for the aes128gcm content
// encoding (RFC 8188) as used by Web Push (RFC 8291).
//
// The type has no fields, so the zero value is ready to use.
type Aes128gcmScheme struct{}
// DeriveIKM derives the input keying material used for content-key
// derivation, following RFC 8291 section 3.3:
//
//	PRK_key  = HKDF-Extract(salt=auth_secret, IKM=ecdh_secret)
//	key_info = "WebPush: info" || 0x00 || ua_public || as_public
//	IKM      = HKDF-Expand(PRK_key, key_info, HKDF_IKM_LEN)
//
// The arguments are as follows:
//   - hash selects the HKDF hash function.
//   - authSecret is the user agent's authentication secret.
//   - ecdhSecret is the shared ECDH secret.
//   - uaKey is the receiver (user agent) public key.
//   - asKey is the sender (application server) public key.
//
// Both public keys are required. A nil key is reported as an error rather
// than causing a panic. HKDF failures are wrapped with %w so callers can
// inspect them.
func (s Aes128gcmScheme) DeriveIKM(hash func() hash.Hash, authSecret, ecdhSecret []byte, uaKey, asKey *ecdh.PublicKey) ([]byte, error) {
	if uaKey == nil || asKey == nil {
		return nil, errors.New("derive IKM failed: public key is nil")
	}

	// hkdf.Extract takes (secret, salt): the ECDH secret is the IKM and the
	// auth secret is the salt.
	prkKey := hkdf.Extract(hash, ecdhSecret, authSecret)

	// aes128gcm: "WebPush: info\0" + receiver key + sender key
	keyInfo := bytes.Join([][]byte{
		[]byte("WebPush: info\000"),
		uaKey.Bytes(),
		asKey.Bytes(),
	}, nil)

	ikm := make([]byte, HKDF_IKM_LEN)
	if _, err := io.ReadFull(hkdf.Expand(hash, prkKey, keyInfo), ikm); err != nil {
		return nil, fmt.Errorf("derive IKM failed: %w", err)
	}
	return ikm, nil
}
// DeriveCEKAndNonce derives the content encryption key and the nonce from
// ikm and salt, following RFC 8188 sections 2.2 and 2.3:
//
//	PRK   = HKDF-Extract(salt, IKM)
//	CEK   = HKDF-Expand(PRK, "Content-Encoding: aes128gcm" || 0x00, HKDF_CEK_LEN)
//	NONCE = HKDF-Expand(PRK, "Content-Encoding: nonce" || 0x00, HKDF_NONCE_LEN)
//
// The aes128gcm info strings do not include the public keys, so uaKey and
// asKey are unused here. They are presumably part of the signature only so
// this type matches the shared EncodingScheme method set.
//
// HKDF failures are wrapped with %w so callers can inspect them.
func (s Aes128gcmScheme) DeriveCEKAndNonce(hash func() hash.Hash, ikm, salt []byte, uaKey, asKey *ecdh.PublicKey) (cek, nonce []byte, err error) {
	// hkdf.Extract takes (secret, salt).
	prk := hkdf.Extract(hash, ikm, salt)

	cekInfo := []byte("Content-Encoding: aes128gcm\000")
	nonceInfo := []byte("Content-Encoding: nonce\000")

	// Assign to the named result instead of shadowing it with :=.
	cek = make([]byte, HKDF_CEK_LEN)
	if _, err = io.ReadFull(hkdf.Expand(hash, prk, cekInfo), cek); err != nil {
		return nil, nil, fmt.Errorf("derive CEK failed: %w", err)
	}

	nonce = make([]byte, HKDF_NONCE_LEN)
	if _, err = io.ReadFull(hkdf.Expand(hash, prk, nonceInfo), nonce); err != nil {
		return nil, nil, fmt.Errorf("derive nonce failed: %w", err)
	}
	return cek, nonce, nil
}
// Pad returns plaintext followed by the final-record padding delimiter 0x02
// (RFC 8188 section 2). It adds no zero-octet padding.
//
// The result is always a freshly allocated slice. plaintext is never
// modified, even when it has spare capacity that append would otherwise
// write into.
func (s Aes128gcmScheme) Pad(plaintext []byte) []byte {
	padded := make([]byte, len(plaintext)+1)
	copy(padded, plaintext)
	padded[len(plaintext)] = 0x02
	return padded
}
// Unpad removes aes128gcm record padding from decrypted data and returns
// the plaintext.
//
// Per RFC 8188 section 2, a record's plaintext is followed by one padding
// delimiter octet (0x02 for the final record, 0x01 otherwise) and then zero
// or more 0x00 octets. Unpad first strips the trailing zero octets. It then
// requires the last remaining octet to be a delimiter and strips that too.
//
// It returns an error if data is empty, contains only zero octets, or does
// not end in a valid delimiter. The returned slice aliases data.
func (s Aes128gcmScheme) Unpad(data []byte) ([]byte, error) {
	if len(data) == 0 {
		return nil, fmt.Errorf("data is empty")
	}

	end := len(data)
	for end > 0 && data[end-1] == 0x00 {
		end--
	}
	if end == 0 {
		return nil, fmt.Errorf("padding delimiter not found")
	}

	switch delim := data[end-1]; delim {
	case 0x01, 0x02:
		return data[:end-1], nil
	default:
		return nil, fmt.Errorf("invalid padding delimiter 0x%02x", delim)
	}
}
// Payload represents the aes128gcm message format (RFC 8188 section 2.1)
// with its crypto parameters embedded in the header.
type Payload struct {
	// RS is the record size, serialized as a big-endian uint32.
	RS         uint32
	// Salt is the per-message salt. Unmarshal always reads it as the first
	// 16 bytes of the message.
	Salt       []byte
	// KeyId is the key identifier. Its length is serialized as a single
	// octet, so it must be at most 255 bytes. EncryptAes128gcm stores the
	// sender's public key here.
	KeyId      []byte
	// CipherText is everything after the header (the encrypted record).
	CipherText []byte
}
// baseHeaderLen is the length of the fixed part of the aes128gcm header
// that precedes the key ID: salt (16 bytes) + rs (4 bytes) + idlen (1 byte).
const baseHeaderLen = 16 + 4 + 1
// Marshal serializes p into the aes128gcm binary layout:
//
//	salt || rs (uint32, big-endian) || idlen (1 byte) || keyid || ciphertext
//
// Marshal does not validate p. Callers must supply a 16-byte Salt, which is
// the size Unmarshal reads back, and a KeyId of at most 255 bytes. A longer
// KeyId has its length truncated to a single octet, which produces a
// malformed header.
func Marshal(p Payload) (data []byte) {
	var buf bytes.Buffer
	buf.Grow(len(p.Salt) + 4 + 1 + len(p.KeyId) + len(p.CipherText))

	buf.Write(p.Salt)

	var rs [4]byte
	binary.BigEndian.PutUint32(rs[:], p.RS)
	buf.Write(rs[:])

	// idlen is a single octet; the KeyId length is not range-checked here.
	buf.WriteByte(uint8(len(p.KeyId)))
	buf.Write(p.KeyId)
	buf.Write(p.CipherText)

	return buf.Bytes()
}
// Unmarshal parses the aes128gcm binary layout produced by Marshal:
//
//	salt (16) || rs (uint32, big-endian) || idlen (1) || keyid (idlen) || ciphertext
//
// It returns an error in either of these cases:
//   - data is shorter than the header it declares.
//   - rs is below 18, which RFC 8188 section 2.1 defines as invalid.
//
// On error the returned Payload is the zero value; it is never partially
// filled. On success the byte-slice fields alias data rather than copying
// it, so data must not be modified while p is in use. KeyId is nil when
// idlen is 0.
func Unmarshal(data []byte) (p Payload, err error) {
	const (
		saltLen       = 16
		rsLen         = 4
		minRecordSize = 18 // RFC 8188 section 2.1
	)

	if len(data) < baseHeaderLen {
		return Payload{}, errors.New("data is too short")
	}

	rs := binary.BigEndian.Uint32(data[saltLen : saltLen+rsLen])
	if rs < minRecordSize {
		return Payload{}, errors.New("invalid record size")
	}

	idLen := int(data[saltLen+rsLen])
	keyEnd := baseHeaderLen + idLen
	if len(data) < keyEnd {
		return Payload{}, errors.New("data is too short")
	}

	p = Payload{
		RS:         rs,
		Salt:       data[:saltLen],
		CipherText: data[keyEnd:],
	}
	if idLen > 0 {
		p.KeyId = data[baseHeaderLen:keyEnd]
	}
	return p, nil
}
// EncryptAes128gcm encrypts plaintext using the aes128gcm encoding scheme.
// It returns the complete message body: the aes128gcm header (salt, rs, and
// the sender public key as key ID) followed by the encrypted record.
//
// salt must be exactly 16 bytes, because the aes128gcm header reserves 16
// bytes for it and Unmarshal reads it back with that size. senderPrivateKey
// must be non-nil.
//
// The record size is set one octet larger than the encrypted record
// (plaintext + 1-byte delimiter + AES_GCM_OVERHEAD). This matters for two
// reasons:
//   - RFC 8291 section 4 requires rs to be greater than that sum.
//   - RFC 8188 section 2.1 treats rs values below 18. An rs equal to the
//     record length would be 17 for an empty plaintext.
func (c *RFC8291) EncryptAes128gcm(
	plaintext []byte,
	salt []byte,
	authSecret []byte,
	receiverPublicKey *ecdh.PublicKey,
	senderPrivateKey *ecdh.PrivateKey,
) ([]byte, error) {
	if len(salt) != 16 {
		return nil, fmt.Errorf("salt must be 16 bytes, got %d", len(salt))
	}
	if senderPrivateKey == nil {
		return nil, errors.New("sender private key is nil")
	}

	ciphertext, err := c.encrypt(Aes128gcmScheme{}, plaintext, salt, authSecret, receiverPublicKey, senderPrivateKey)
	if err != nil {
		return nil, err
	}

	recordLen := len(plaintext) + 1 + AES_GCM_OVERHEAD
	return Marshal(Payload{
		RS:         uint32(recordLen + 1),
		Salt:       salt,
		KeyId:      senderPrivateKey.PublicKey().Bytes(),
		CipherText: ciphertext,
	}), nil
}
// DecryptAes128gcm decrypts ciphertext that was encrypted with the aes128gcm
// encoding scheme. It delegates to the shared decrypt routine with an
// Aes128gcmScheme.
//
// The crypto parameters (salt and sender public key) are passed in as
// separate arguments; this method does not parse an aes128gcm header
// itself. Callers holding a complete message body can obtain Salt, KeyId
// and CipherText from Unmarshal.
func (c *RFC8291) DecryptAes128gcm(
	ciphertext []byte,
	salt []byte,
	authSecret []byte,
	receiverPrivateKey *ecdh.PrivateKey,
	senderPublicKey *ecdh.PublicKey,
) ([]byte, error) {
	var scheme Aes128gcmScheme
	plaintext, err := c.decrypt(scheme, ciphertext, salt, authSecret, receiverPrivateKey, senderPublicKey)
	return plaintext, err
}