Skip to content

Commit

Permalink
Merge pull request #674 from tranchitella/men-6775
Browse files Browse the repository at this point in the history
feat: support for Ed25519 server keys for signing the JWT tokens
  • Loading branch information
tranchitella authored Oct 25, 2023
2 parents 592f6b5 + 91cc9a1 commit 45655b9
Show file tree
Hide file tree
Showing 17 changed files with 1,010 additions and 799 deletions.
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2022 Northern.tech AS
// Copyright 2023 Northern.tech AS
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down
31 changes: 24 additions & 7 deletions devauth/devauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ type DevAuth struct {
cOrch orchestrator.ClientRunner
cTenant tenant.ClientRunner
jwt jwt.Handler
jwtFallback jwt.Handler
verifyTenant bool
config Config
cache cache.Cache
Expand Down Expand Up @@ -1212,6 +1213,21 @@ func verifyTenantClaim(ctx context.Context, verifyTenant bool, tenant string) er
return nil
}

func (d *DevAuth) validateJWTToken(ctx context.Context, jti oid.ObjectID, raw string) error {
err := d.jwt.Validate(raw)
if err != nil && d.jwtFallback != nil {
err = d.jwtFallback.Validate(raw)
}
if err == jwt.ErrTokenExpired && jti.String() != "" {
log.FromContext(ctx).Errorf("Token %s expired: %v", jti.String(), err)
return d.handleExpiredToken(ctx, jti)
} else if err != nil {
log.FromContext(ctx).Errorf("Token %s invalid: %v", jti.String(), err)
return jwt.ErrTokenInvalid
}
return nil
}

func (d *DevAuth) VerifyToken(ctx context.Context, raw string) error {
l := log.FromContext(ctx)

Expand Down Expand Up @@ -1263,13 +1279,9 @@ func (d *DevAuth) VerifyToken(ctx context.Context, raw string) error {
}

// perform JWT signature and claims validation
if err := d.jwt.Validate(raw); err != nil {
if err == jwt.ErrTokenExpired && jti.String() != "" {
l.Errorf("Token %s expired: %v", jti.String(), err)
return d.handleExpiredToken(ctx, jti)
}
l.Errorf("Token %s invalid: %v", jti.String(), err)
return jwt.ErrTokenInvalid
err = d.validateJWTToken(ctx, jti, raw)
if err != nil {
return err
}

// cache check was a MISS, hit the db for verification
Expand Down Expand Up @@ -1500,6 +1512,11 @@ func (d *DevAuth) GetTenantLimit(
return d.GetLimit(ctx, name)
}

func (d *DevAuth) WithJWTFallbackHandler(handler jwt.Handler) *DevAuth {
d.jwtFallback = handler
return d
}

// WithTenantVerification will force verification of tenant token with tenant
// administrator when processing device authentication requests. Returns an
// updated devauth.
Expand Down
65 changes: 59 additions & 6 deletions devauth/devauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1809,8 +1809,9 @@ func TestDevAuthVerifyToken(t *testing.T) {
tokenValidateErr error
tokenOtherError error

jwToken *jwt.Token
validateErr error
jwToken *jwt.Token
validateErr error
fallbackValidateErr error

getToken bool
getTokenErr error
Expand All @@ -1823,8 +1824,9 @@ func TestDevAuthVerifyToken(t *testing.T) {

updateDeviceErr error

tenantVerify bool
willUpdateDevice bool
tenantVerify bool
willUpdateDevice bool
jwtHandlerFallback bool
}{
{
tokenString: "expired",
Expand Down Expand Up @@ -1951,6 +1953,51 @@ func TestDevAuthVerifyToken(t *testing.T) {

tenantVerify: true,
},
{
tokenString: "with fallback",
jwToken: &jwt.Token{
Claims: jwt.Claims{
ID: oid.NewUUIDv5("good"),
Subject: oid.NewUUIDv5("bar"),
ExpiresAt: jwt.Time{
Time: time.Now().Add(time.Hour),
},
Issuer: "Tester",
Device: true,
},
},
validateErr: jwt.ErrTokenInvalid,
getToken: true,
auth: &model.AuthSet{
Id: oid.NewUUIDv5("good").String(),
Status: model.DevStatusAccepted,
DeviceId: oid.NewUUIDv5("bar").String(),
},
dev: &model.Device{
Id: oid.NewUUIDv5("bar").String(),
Decommissioning: false,
},
willUpdateDevice: true,
jwtHandlerFallback: true,
},
{
tokenString: "failed-validation-with-fallback",
tokenValidateErr: jwt.ErrTokenInvalid,
jwToken: &jwt.Token{
Claims: jwt.Claims{
ID: oid.NewUUIDv5("good"),
Subject: oid.NewUUIDv5("bar"),
ExpiresAt: jwt.Time{
Time: time.Now().Add(time.Hour),
},
Issuer: "Tester",
Device: true,
},
},
validateErr: jwt.ErrTokenInvalid,
fallbackValidateErr: jwt.ErrTokenInvalid,
jwtHandlerFallback: true,
},
}

for i := range testCases {
Expand All @@ -1962,6 +2009,12 @@ func TestDevAuthVerifyToken(t *testing.T) {
ja := &mjwt.Handler{}

devauth := NewDevAuth(db, nil, ja, Config{})
if tc.jwtHandlerFallback {
jaFallback := &mjwt.Handler{}
jaFallback.On("Validate", tc.tokenString).Return(tc.fallbackValidateErr)

devauth = devauth.WithJWTFallbackHandler(jaFallback)
}
if tc.tenantVerify {
// ok to pass nil tenantadm client here
devauth = devauth.WithTenantVerification(nil)
Expand All @@ -1973,8 +2026,8 @@ func TestDevAuthVerifyToken(t *testing.T) {
return tc.jwToken
}, tc.tokenParseErr)

if tc.tokenParseErr == nil && tc.jwToken != nil &&
tc.tokenOtherError == nil && tc.tokenString != "missing-tenant-claim" {
if tc.tokenParseErr == nil && tc.jwToken != nil && tc.tokenOtherError == nil &&
tc.tokenString != "missing-tenant-claim" {
ja.On("Validate", tc.tokenString).Return(tc.validateErr)
}

Expand Down
98 changes: 30 additions & 68 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
package jwt

import (
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"os"

jwtgo "github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
)

Expand All @@ -25,6 +28,11 @@ var (
ErrTokenInvalid = errors.New("jwt: token invalid")
)

const (
pemHeaderPKCS1 = "RSA PRIVATE KEY"
pemHeaderPKCS8 = "PRIVATE KEY"
)

// Handler jwt generator/verifier
//
//go:generate ../utils/mockgen.sh
Expand All @@ -41,76 +49,30 @@ type Handler interface {
Validate(string) error
}

// JWTHandlerRS256 is an RS256-specific JWTHandler
type JWTHandlerRS256 struct {
privKey *rsa.PrivateKey
fallbackPrivKey *rsa.PrivateKey
}

func NewJWTHandlerRS256(privKey *rsa.PrivateKey, fallbackPrivKey *rsa.PrivateKey) *JWTHandlerRS256 {
return &JWTHandlerRS256{
privKey: privKey,
fallbackPrivKey: fallbackPrivKey,
func NewJWTHandler(privateKeyPath string) (Handler, error) {
priv, err := os.ReadFile(privateKeyPath)
block, _ := pem.Decode(priv)
if block == nil {
return nil, errors.Wrap(err, "failed to read private key")
}
}

func (j *JWTHandlerRS256) ToJWT(token *Token) (string, error) {
//generate
jt := jwtgo.NewWithClaims(jwtgo.SigningMethodRS256, &token.Claims)

//sign
data, err := jt.SignedString(j.privKey)
return data, err
}

func (j *JWTHandlerRS256) FromJWT(tokstr string) (*Token, error) {
parser := jwtgo.NewParser(jwtgo.WithoutClaimsValidation())
jwttoken, _, err := parser.ParseUnverified(tokstr, &Claims{})
if err == nil {
token := Token{}
if claims, ok := jwttoken.Claims.(*Claims); ok {
token.Claims = *claims
return &token, nil
switch block.Type {
case pemHeaderPKCS1:
privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "failed to read rsa private key")
}
}

return nil, ErrTokenInvalid
}

func (j *JWTHandlerRS256) Validate(tokstr string) error {
var err error
var jwttoken *jwtgo.Token
for _, privKey := range []*rsa.PrivateKey{
j.privKey,
j.fallbackPrivKey,
} {
if privKey != nil {
jwttoken, err = jwtgo.ParseWithClaims(tokstr, &Claims{},
func(token *jwtgo.Token) (interface{}, error) {
if _, ok := token.Method.(*jwtgo.SigningMethodRSA); !ok {
return nil, errors.New("unexpected signing method: " + token.Method.Alg())
}
return &privKey.PublicKey, nil
},
)
if jwttoken != nil && err == nil {
break
}
return NewJWTHandlerRS256(privKey), nil
case pemHeaderPKCS8:
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "failed to read private key")
}
}

// our Claims return Mender-specific validation errors
// go-jwt will wrap them in a generic ValidationError - unwrap and return directly
if jwttoken != nil && !jwttoken.Valid {
return ErrTokenInvalid
} else if err != nil {
err, ok := err.(*jwtgo.ValidationError)
if ok && err.Inner != nil {
return err.Inner
} else {
return err
switch v := key.(type) {
case *rsa.PrivateKey:
return NewJWTHandlerRS256(v), nil
case ed25519.PrivateKey:
return NewJWTHandlerEd25519(&v), nil
}
}

return nil
return nil, errors.Errorf("unsupported server private key type")
}
81 changes: 81 additions & 0 deletions jwt/jwt_ed25519.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright 2023 Northern.tech AS
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package jwt

import (
"crypto/ed25519"

"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
)

// JWTHandlerEd25519 is an Ed25519-specific JWTHandler
type JWTHandlerEd25519 struct {
privKey *ed25519.PrivateKey
}

func NewJWTHandlerEd25519(privKey *ed25519.PrivateKey) *JWTHandlerEd25519 {
return &JWTHandlerEd25519{
privKey: privKey,
}
}

func (j *JWTHandlerEd25519) ToJWT(token *Token) (string, error) {
//generate
jt := jwt.NewWithClaims(jwt.SigningMethodEdDSA, &token.Claims)

//sign
data, err := jt.SignedString(j.privKey)
return data, err
}

func (j *JWTHandlerEd25519) FromJWT(tokstr string) (*Token, error) {
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
jwttoken, _, err := parser.ParseUnverified(tokstr, &Claims{})
if err == nil {
token := Token{}
if claims, ok := jwttoken.Claims.(*Claims); ok {
token.Claims = *claims
return &token, nil
}
}

return nil, ErrTokenInvalid
}

func (j *JWTHandlerEd25519) Validate(tokstr string) error {
jwttoken, err := jwt.ParseWithClaims(tokstr, &Claims{},
func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, errors.New("unexpected signing method: " + token.Method.Alg())
}
return j.privKey.Public(), nil
},
)

// our Claims return Mender-specific validation errors
// go-jwt will wrap them in a generic ValidationError - unwrap and return directly
if jwttoken != nil && !jwttoken.Valid {
return ErrTokenInvalid
} else if err != nil {
err, ok := err.(*jwt.ValidationError)
if ok && err.Inner != nil {
return err.Inner
} else {
return err
}
}

return nil
}
Loading

0 comments on commit 45655b9

Please sign in to comment.