jwk, internal/publickey: unmarshal -> decode

Currently encoding/json accepts duplicate fields in the json.
Since https://github.com/golang/go/issues/48298 got accepted,
we should use the decoder interface with Decoder.DisallowDuplicateFields
turned on when available. Its exact behavior will determine whether
json.RawMessage's will be re-unmarshaled or will follow the byte reader
path.
main
Aydin Mercan 2021-12-04 10:55:00 +03:00
parent 8bc5579fad
commit 67699f4a7c
No known key found for this signature in database
5 changed files with 31 additions and 12 deletions

View File

@ -1,6 +1,7 @@
package publickey package publickey
import ( import (
"bytes"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"encoding/base64" "encoding/base64"
@ -22,10 +23,13 @@ var (
) )
// Rejects on invalid curve points with a branch // Rejects on invalid curve points with a branch
func ParseECDSAPublicKey(data []byte) (*ecdsa.PublicKey, error) { func ParseECDSAPublicKey(data json.RawMessage) (*ecdsa.PublicKey, error) {
var header ECDSAPublicKeyHeader var header ECDSAPublicKeyHeader
err := json.Unmarshal(data, &header) r := bytes.NewReader(data)
dec := json.NewDecoder(r)
err := dec.Decode(&header)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -57,6 +61,7 @@ func ParseECDSAPublicKey(data []byte) (*ecdsa.PublicKey, error) {
} }
// Invalid curve attacks don't exactly apply here? // Invalid curve attacks don't exactly apply here?
// We are only verifying signatures but shut up and check points.
if !curve.IsOnCurve(x, y) { if !curve.IsOnCurve(x, y) {
return nil, ErrInvalidCurvePoint return nil, ErrInvalidCurvePoint
} }

View File

@ -1,6 +1,7 @@
package publickey package publickey
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -12,10 +13,13 @@ type EdDSAPublicKeyHeader struct {
X string `json:"x"` X string `json:"x"`
} }
func ParseEdDSAPublicKey(data []byte) (*ed25519.PublicKey, error) { func ParseEdDSAPublicKey(data json.RawMessage) (*ed25519.PublicKey, error) {
var header EdDSAPublicKeyHeader var header EdDSAPublicKeyHeader
err := json.Unmarshal(data, &header) r := bytes.NewReader(data)
dec := json.NewDecoder(r)
err := dec.Decode(&header)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package publickey package publickey
import ( import (
"bytes"
"crypto/rsa" "crypto/rsa"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -18,11 +19,13 @@ var (
ErrUnsupportedPublicExponent = errors.New("Public exponent is not 65537") ErrUnsupportedPublicExponent = errors.New("Public exponent is not 65537")
) )
func ParseRSAPublicKey(data []byte) (*rsa.PublicKey, error) { func ParseRSAPublicKey(data json.RawMessage) (*rsa.PublicKey, error) {
var header RSAPublicKeyHeader var header RSAPublicKeyHeader
err := json.Unmarshal(data, &header) r := bytes.NewReader(data)
dec := json.NewDecoder(r)
err := dec.Decode(&header)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -4,6 +4,7 @@ import (
"crypto" "crypto"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"mercan.dev/dumb-jose/internal/publickey" "mercan.dev/dumb-jose/internal/publickey"
) )
@ -21,10 +22,12 @@ type jwkHeader struct {
Keys []json.RawMessage `json:"keys"` Keys []json.RawMessage `json:"keys"`
} }
func ParseJWKKeysFromSet(data []byte) ([]JWK, error) { func ParseKeysFromSet(r io.Reader) ([]JWK, error) {
dec := json.NewDecoder(r)
keys := jwkHeader{} keys := jwkHeader{}
err := json.Unmarshal(data, &keys) err := dec.Decode(&keys)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package jwk_test package jwk_test
import ( import (
"bytes"
"mercan.dev/dumb-jose/jwk" "mercan.dev/dumb-jose/jwk"
"testing" "testing"
) )
@ -33,7 +34,8 @@ const (
) )
func TestCorrectJwkKeySet(t *testing.T) { func TestCorrectJwkKeySet(t *testing.T) {
set, err := jwk.ParseJWKKeysFromSet([]byte(GoogleJwkKeySet)) r := bytes.NewReader([]byte(GoogleJwkKeySet))
set, err := jwk.ParseKeysFromSet(r)
if err != nil { if err != nil {
t.Errorf("Error while parsing JWK Keyset: %v", err) t.Errorf("Error while parsing JWK Keyset: %v", err)
} }
@ -45,7 +47,8 @@ func TestCorrectJwkKeySet(t *testing.T) {
} }
} }
set, err = jwk.ParseJWKKeysFromSet([]byte(ValidEd25519KeySet)) r = bytes.NewReader([]byte(ValidEd25519KeySet))
set, err = jwk.ParseKeysFromSet(r)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
@ -53,7 +56,8 @@ func TestCorrectJwkKeySet(t *testing.T) {
} }
func TestInvalidRSAExponent(t *testing.T) { func TestInvalidRSAExponent(t *testing.T) {
_, err := jwk.ParseJWKKeysFromSet([]byte(InvalidExponent)) r := bytes.NewReader([]byte(InvalidExponent))
_, err := jwk.ParseKeysFromSet(r)
if err == nil { if err == nil {
t.Errorf("Expected error not returned for unsupported public exponent, found \"%v\"", err) t.Errorf("Expected error not returned for unsupported public exponent, found \"%v\"", err)
} }