diff --git a/internal/publickey/ecdsa.go b/internal/publickey/ecdsa.go index 595a1e0..29e10a2 100644 --- a/internal/publickey/ecdsa.go +++ b/internal/publickey/ecdsa.go @@ -1,6 +1,7 @@ package publickey import ( + "bytes" "crypto/ecdsa" "crypto/elliptic" "encoding/base64" @@ -22,10 +23,13 @@ var ( ) // Rejects on invalid curve points with a branch -func ParseECDSAPublicKey(data []byte) (*ecdsa.PublicKey, error) { +func ParseECDSAPublicKey(data json.RawMessage) (*ecdsa.PublicKey, error) { var header ECDSAPublicKeyHeader - err := json.Unmarshal(data, &header) + r := bytes.NewReader(data) + dec := json.NewDecoder(r) + + err := dec.Decode(&header) if err != nil { return nil, err } @@ -57,6 +61,7 @@ func ParseECDSAPublicKey(data []byte) (*ecdsa.PublicKey, error) { } // Invalid curve attacks don't exactly apply here? + // We are only verifying signatures but shut up and check points. if !curve.IsOnCurve(x, y) { return nil, ErrInvalidCurvePoint } diff --git a/internal/publickey/eddsa.go b/internal/publickey/eddsa.go index d64c679..87c3530 100644 --- a/internal/publickey/eddsa.go +++ b/internal/publickey/eddsa.go @@ -1,6 +1,7 @@ package publickey import ( + "bytes" "encoding/base64" "encoding/json" "fmt" @@ -12,10 +13,13 @@ type EdDSAPublicKeyHeader struct { X string `json:"x"` } -func ParseEdDSAPublicKey(data []byte) (*ed25519.PublicKey, error) { +func ParseEdDSAPublicKey(data json.RawMessage) (*ed25519.PublicKey, error) { var header EdDSAPublicKeyHeader - err := json.Unmarshal(data, &header) + r := bytes.NewReader(data) + dec := json.NewDecoder(r) + + err := dec.Decode(&header) if err != nil { return nil, err } diff --git a/internal/publickey/rsa.go b/internal/publickey/rsa.go index 748e1d9..b85620a 100644 --- a/internal/publickey/rsa.go +++ b/internal/publickey/rsa.go @@ -1,6 +1,7 @@ package publickey import ( + "bytes" "crypto/rsa" "encoding/base64" "encoding/json" @@ -18,11 +19,13 @@ var ( ErrUnsupportedPublicExponent = errors.New("Public exponent is not 65537") ) -func ParseRSAPublicKey(data []byte) (*rsa.PublicKey, error) { - +func ParseRSAPublicKey(data json.RawMessage) (*rsa.PublicKey, error) { var header RSAPublicKeyHeader - err := json.Unmarshal(data, &header) + r := bytes.NewReader(data) + dec := json.NewDecoder(r) + + err := dec.Decode(&header) if err != nil { return nil, err } diff --git a/jwk/jwk.go b/jwk/jwk.go index ee32f55..41fca6c 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -4,6 +4,7 @@ import ( "crypto" "encoding/json" "fmt" + "io" "mercan.dev/dumb-jose/internal/publickey" ) @@ -21,10 +22,12 @@ type jwkHeader struct { Keys []json.RawMessage `json:"keys"` } -func ParseJWKKeysFromSet(data []byte) ([]JWK, error) { +func ParseKeysFromSet(r io.Reader) ([]JWK, error) { + dec := json.NewDecoder(r) + keys := jwkHeader{} - err := json.Unmarshal(data, &keys) + err := dec.Decode(&keys) if err != nil { return nil, err } diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index 36745c9..847f0ef 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -1,6 +1,7 @@ package jwk_test import ( + "bytes" "mercan.dev/dumb-jose/jwk" "testing" ) @@ -33,7 +34,8 @@ const ( ) func TestCorrectJwkKeySet(t *testing.T) { - set, err := jwk.ParseJWKKeysFromSet([]byte(GoogleJwkKeySet)) + r := bytes.NewReader([]byte(GoogleJwkKeySet)) + set, err := jwk.ParseKeysFromSet(r) if err != nil { t.Errorf("Error while parsing JWK Keyset: %v", err) } @@ -45,7 +47,8 @@ func TestCorrectJwkKeySet(t *testing.T) { } } - set, err = jwk.ParseJWKKeysFromSet([]byte(ValidEd25519KeySet)) + r = bytes.NewReader([]byte(ValidEd25519KeySet)) + set, err = jwk.ParseKeysFromSet(r) if err != nil { t.Errorf("%v", err) } @@ -53,7 +56,8 @@ func TestCorrectJwkKeySet(t *testing.T) { } func TestInvalidRSAExponent(t *testing.T) { - _, err := jwk.ParseJWKKeysFromSet([]byte(InvalidExponent)) + r := bytes.NewReader([]byte(InvalidExponent)) + _, err := jwk.ParseKeysFromSet(r) if err == nil { t.Errorf("Expected error not returned for unsupported public exponent, found \"%v\"", err) }