diff --git a/internal/auth/bearer.go b/internal/auth/bearer.go index 923ebd4..31188b4 100644 --- a/internal/auth/bearer.go +++ b/internal/auth/bearer.go @@ -1,86 +1,34 @@ package auth import ( - "fmt" - "net/http" + "github.com/cerberauth/vulnapi/jwt" ) -type BearerSecurityScheme struct { - Type Type `json:"type" yaml:"type"` - Scheme SchemeName `json:"scheme" yaml:"scheme"` - In SchemeIn `json:"in" yaml:"in"` - Name string `json:"name" yaml:"name"` - ValidValue *string `json:"-" yaml:"-"` - AttackValue string `json:"-" yaml:"-"` -} - -var _ SecurityScheme = (*BearerSecurityScheme)(nil) - -func NewAuthorizationBearerSecurityScheme(name string, value *string) *BearerSecurityScheme { - return &BearerSecurityScheme{ - Type: HttpType, - Scheme: BearerScheme, - In: InHeader, - Name: name, - ValidValue: value, - AttackValue: "", +func NewAuthorizationBearerSecurityScheme(name string, value *string) (*SecurityScheme, error) { + in := InHeader + tokenFormat := NoneTokenFormat + if value != nil && *value != "" && jwt.IsJWT(*value) { + tokenFormat = JWTTokenFormat } -} - -func (ss *BearerSecurityScheme) GetType() Type { - return ss.Type -} - -func (ss *BearerSecurityScheme) GetScheme() SchemeName { - return ss.Scheme -} - -func (ss *BearerSecurityScheme) GetIn() *SchemeIn { - return &ss.In -} - -func (ss *BearerSecurityScheme) GetName() string { - return ss.Name -} -func (ss *BearerSecurityScheme) GetHeaders() http.Header { - header := http.Header{} - attackValue := ss.GetAttackValue().(string) - if attackValue == "" && ss.HasValidValue() { - attackValue = ss.GetValidValue().(string) + securityScheme, err := NewSecurityScheme(name, nil, HttpType, BearerScheme, &in, &tokenFormat) + if err != nil { + return nil, err } - if attackValue != "" { - header.Set(AuthorizationHeader, fmt.Sprintf("%s %s", BearerPrefix, attackValue)) + err = securityScheme.SetValidValue(value) + if err != nil { + return nil, err } - return header -} - -func (ss *BearerSecurityScheme) GetCookies() []*http.Cookie { - return []*http.Cookie{} + return securityScheme, nil } -func (ss *BearerSecurityScheme) HasValidValue() bool { - return ss.ValidValue != nil && *ss.ValidValue != "" -} - -func (ss *BearerSecurityScheme) GetValidValue() interface{} { - if !ss.HasValidValue() { - return nil +func MustNewAuthorizationBearerSecurityScheme(name string, value *string) *SecurityScheme { + securityScheme, err := NewAuthorizationBearerSecurityScheme(name, value) + if err != nil { + panic(err) } - return *ss.ValidValue -} - -func (ss *BearerSecurityScheme) GetValidValueWriter() interface{} { - return nil -} - -func (ss *BearerSecurityScheme) SetAttackValue(v interface{}) { - ss.AttackValue = v.(string) -} - -func (ss *BearerSecurityScheme) GetAttackValue() interface{} { - return ss.AttackValue + return securityScheme } diff --git a/internal/auth/bearer_test.go b/internal/auth/bearer_test.go index d664a7b..09c3469 100644 --- a/internal/auth/bearer_test.go +++ b/internal/auth/bearer_test.go @@ -1,7 +1,6 @@ package auth_test import ( - "net/http" "testing" "github.com/cerberauth/vulnapi/internal/auth" @@ -12,180 +11,51 @@ func TestNewAuthorizationBearerSecurityScheme(t *testing.T) { name := "token" value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) + securityScheme, err := auth.NewAuthorizationBearerSecurityScheme(name, &value) - assert.Equal(t, auth.HttpType, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) + assert.NoError(t, err) + assert.Equal(t, auth.HttpType, securityScheme.GetType()) + assert.Equal(t, auth.BearerScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, &value, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue().(string)) } -func TestBearerSecurityScheme_GetScheme(t *testing.T) { +func TestNewAuthorizationBearerSecurityScheme_WhenNilValue(t *testing.T) { name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - scheme := ss.GetScheme() - - assert.Equal(t, auth.BearerScheme, scheme) -} - -func TestBearerSecurityScheme_GetType(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - scheme := ss.GetType() - - assert.Equal(t, auth.HttpType, scheme) -} - -func TestBearerSecurityScheme_GetIn(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - scheme := ss.GetIn() - - assert.Equal(t, auth.InHeader, *scheme) -} - -func TestBearerSecurityScheme_GetName(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - scheme := ss.GetName() - assert.Equal(t, name, scheme) -} - -func TestBearerSecurityScheme_GetHeaders(t *testing.T) { - name := "token" - value := "abc123" - attackValue := "xyz789" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - ss.SetAttackValue(attackValue) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer xyz789"}, - }, headers) -} - -func TestBearerSecurityScheme_GetHeaders_WhenNoAttackValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) + securityScheme, err := auth.NewAuthorizationBearerSecurityScheme(name, nil) - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer abc123"}, - }, headers) -} - -func TestBearerSecurityScheme_GetHeaders_WhenNoAttackAndValidValue(t *testing.T) { - name := "token" - ss := auth.NewAuthorizationBearerSecurityScheme(name, nil) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{}, headers) + assert.NoError(t, err) + assert.Equal(t, auth.HttpType, securityScheme.GetType()) + assert.Equal(t, auth.BearerScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, nil, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue().(string)) } -func TestBearerSecurityScheme_GetCookies(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - cookies := ss.GetCookies() - - assert.Empty(t, cookies) -} - -func TestBearerSecurityScheme_HasValidValue_WhenValueIsNil(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - result := ss.HasValidValue() - - assert.True(t, result) -} - -func TestBearerSecurityScheme_HasValidValueFalse_WhenValueIsEmptyString(t *testing.T) { +func TestNewAuthorizationBearerSecurityScheme_WhenEmptyValue(t *testing.T) { name := "token" value := "" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - result := ss.HasValidValue() - assert.False(t, result) -} - -func TestBearerSecurityScheme_GetValidValueNil(t *testing.T) { - name := "token" - ss := auth.NewAuthorizationBearerSecurityScheme(name, nil) - - validValue := ss.GetValidValue() - - assert.Equal(t, nil, validValue) -} - -func TestBearerSecurityScheme_HasValidValueFalse(t *testing.T) { - name := "token" - ss := auth.NewAuthorizationBearerSecurityScheme(name, nil) + _, err := auth.NewAuthorizationBearerSecurityScheme(name, &value) - result := ss.HasValidValue() - - assert.False(t, result) + assert.Error(t, err) } -func TestBearerSecurityScheme_GetValidValue(t *testing.T) { +func TestNewAuthorizationBearerSecurityScheme_WhenJWTFormatValue(t *testing.T) { name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - validValue := ss.GetValidValue() - - assert.Equal(t, value, validValue) -} - -func TestBearerSecurityScheme_GetValidValueWriter(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - writer := ss.GetValidValueWriter() - - assert.Equal(t, nil, writer) -} - -func TestBearerSecurityScheme_SetAttackValue(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - assert.Equal(t, attackValue, ss.AttackValue) -} - -func TestBearerSecurityScheme_GetAttackValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - attackValue := "xyz789" - ss.SetAttackValue(attackValue) + value := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.ufhxDTmrs4T5MSsvT6lsb3OpdWi5q8O31VX7TgrVamA" - result := ss.GetAttackValue() + securityScheme, err := auth.NewAuthorizationBearerSecurityScheme(name, &value) - assert.Equal(t, attackValue, result) + assert.NoError(t, err) + assert.Equal(t, auth.HttpType, securityScheme.GetType()) + assert.Equal(t, auth.BearerScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, value, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue().(string)) } diff --git a/internal/auth/jwt_bearer.go b/internal/auth/jwt_bearer.go deleted file mode 100644 index 82fe6c2..0000000 --- a/internal/auth/jwt_bearer.go +++ /dev/null @@ -1,108 +0,0 @@ -package auth - -import ( - "fmt" - "net/http" - - "github.com/cerberauth/vulnapi/jwt" -) - -type JWTBearerSecurityScheme struct { - Type Type `json:"type" yaml:"type"` - Scheme SchemeName `json:"scheme" yaml:"scheme"` - In SchemeIn `json:"in" yaml:"in"` - Name string `json:"name" yaml:"name"` - ValidValue *string `json:"-" yaml:"-"` - AttackValue string `json:"-" yaml:"-"` - - JWTWriter *jwt.JWTWriter `json:"-" yaml:"-"` -} - -var _ SecurityScheme = (*JWTBearerSecurityScheme)(nil) - -func NewAuthorizationJWTBearerSecurityScheme(name string, value *string) (*JWTBearerSecurityScheme, error) { - var jwtWriter *jwt.JWTWriter - if value != nil { - var err error - if jwtWriter, err = jwt.NewJWTWriter(*value); err != nil { - return nil, err - } - } - - return &JWTBearerSecurityScheme{ - Type: HttpType, - Scheme: BearerScheme, - In: InHeader, - Name: name, - ValidValue: value, - AttackValue: "", - - JWTWriter: jwtWriter, - }, nil -} - -func MustNewAuthorizationJWTBearerSecurityScheme(name string, value *string) *JWTBearerSecurityScheme { - scheme, err := NewAuthorizationJWTBearerSecurityScheme(name, value) - if err != nil { - panic(err) - } - return scheme -} - -func (ss *JWTBearerSecurityScheme) GetType() Type { - return ss.Type -} - -func (ss *JWTBearerSecurityScheme) GetScheme() SchemeName { - return ss.Scheme -} - -func (ss *JWTBearerSecurityScheme) GetIn() *SchemeIn { - return &ss.In -} - -func (ss *JWTBearerSecurityScheme) GetName() string { - return ss.Name -} - -func (ss *JWTBearerSecurityScheme) GetHeaders() http.Header { - header := http.Header{} - attackValue := ss.GetAttackValue().(string) - if attackValue == "" && ss.HasValidValue() { - attackValue = ss.GetValidValue().(string) - } - - if attackValue != "" { - header.Set(AuthorizationHeader, fmt.Sprintf("%s %s", BearerPrefix, attackValue)) - } - - return header -} - -func (ss *JWTBearerSecurityScheme) GetCookies() []*http.Cookie { - return []*http.Cookie{} -} - -func (ss *JWTBearerSecurityScheme) HasValidValue() bool { - return ss.ValidValue != nil -} - -func (ss *JWTBearerSecurityScheme) GetValidValue() interface{} { - if !ss.HasValidValue() { - return nil - } - - return *ss.ValidValue -} - -func (ss *JWTBearerSecurityScheme) GetValidValueWriter() interface{} { - return ss.JWTWriter -} - -func (ss *JWTBearerSecurityScheme) SetAttackValue(v interface{}) { - ss.AttackValue = v.(string) -} - -func (ss *JWTBearerSecurityScheme) GetAttackValue() interface{} { - return ss.AttackValue -} diff --git a/internal/auth/jwt_bearer_test.go b/internal/auth/jwt_bearer_test.go deleted file mode 100644 index 12be661..0000000 --- a/internal/auth/jwt_bearer_test.go +++ /dev/null @@ -1,233 +0,0 @@ -package auth_test - -import ( - "net/http" - "testing" - - "github.com/cerberauth/vulnapi/internal/auth" - "github.com/cerberauth/vulnapi/jwt" - "github.com/stretchr/testify/assert" -) - -func TestNewAuthorizationJWTBearerSecurityScheme(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - assert.NoError(t, err) - assert.Equal(t, auth.HttpType, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) -} - -func TestNewAuthorizationJWTBearerSecuritySchemeWithInvalidJWT(t *testing.T) { - name := "token" - value := "abc123" - _, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - assert.Error(t, err) -} - -func TestMustNewAuthorizationJWTBearerSecurityScheme(t *testing.T) { - t.Run("ValidJWT", func(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss := auth.MustNewAuthorizationJWTBearerSecurityScheme(name, &value) - - assert.NotNil(t, ss) - assert.Equal(t, auth.HttpType, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) - }) - - t.Run("InvalidJWT", func(t *testing.T) { - name := "token" - value := "abc123" - assert.Panics(t, func() { - auth.MustNewAuthorizationJWTBearerSecurityScheme(name, &value) - }) - }) - - t.Run("NilValue", func(t *testing.T) { - name := "token" - ss := auth.MustNewAuthorizationJWTBearerSecurityScheme(name, nil) - - assert.NotNil(t, ss) - assert.Equal(t, auth.HttpType, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Nil(t, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) - }) -} - -func TestAuthorizationJWTBearerSecurityScheme_GetScheme(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - scheme := ss.GetScheme() - - assert.NoError(t, err) - assert.Equal(t, auth.BearerScheme, scheme) -} - -func TestAuthorizationJWTBearerSecurityScheme_GetType(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - scheme := ss.GetType() - - assert.NoError(t, err) - assert.Equal(t, auth.HttpType, scheme) -} - -func TestAuthorizationJWTBearerSecurityScheme_GetIn(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - scheme := ss.GetIn() - - assert.NoError(t, err) - assert.Equal(t, auth.InHeader, *scheme) -} - -func TestAuthorizationJWTBearerSecurityScheme_GetName(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - scheme := ss.GetName() - - assert.NoError(t, err) - assert.Equal(t, name, scheme) -} - -func TestJWTBearerSecurityScheme_GetHeaders(t *testing.T) { - name := "token" - value := jwt.FakeJWT - attackValue := "xyz789" - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - ss.SetAttackValue(attackValue) - - headers := ss.GetHeaders() - - assert.NoError(t, err) - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer xyz789"}, - }, headers) -} - -func TestJWTBearerSecurityScheme_GetHeaders_WhenNoAttackValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - headers := ss.GetHeaders() - - assert.NoError(t, err) - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer " + jwt.FakeJWT}, - }, headers) -} - -func TestJWTBearerSecurityScheme_GetHeaders_WhenNoAttackAndValidValue(t *testing.T) { - name := "token" - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, nil) - - headers := ss.GetHeaders() - - assert.NoError(t, err) - assert.Equal(t, http.Header{}, headers) -} - -func TestJWTBearerSecurityScheme_GetCookies(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - cookies := ss.GetCookies() - - assert.NoError(t, err) - assert.Empty(t, cookies) -} - -func TestJWTBearerSecurityScheme_HasValidValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - hasValidValue := ss.HasValidValue() - - assert.NoError(t, err) - assert.True(t, hasValidValue) -} - -func TestJWTBearerSecurityScheme_HasValidValue_WhenNoValue(t *testing.T) { - name := "token" - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, nil) - hasValidValue := ss.HasValidValue() - - assert.NoError(t, err) - assert.False(t, hasValidValue) -} - -func TestJWTBearerSecurityScheme_GetValidValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - validValue := ss.GetValidValue() - - assert.NoError(t, err) - assert.Equal(t, value, validValue) -} - -func TestJWTBearerSecurityScheme_GetValidValue_WhenNoValue(t *testing.T) { - name := "token" - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, nil) - validValue := ss.GetValidValue() - - assert.NoError(t, err) - assert.Nil(t, validValue) -} - -func TestJWTBearerSecurityScheme_GetValidValueWriter(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - writer := ss.GetValidValueWriter() - - assert.NoError(t, err) - assert.Equal(t, ss.JWTWriter, writer) -} - -func TestJWTBearerSecurityScheme_SetAttackValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - assert.NoError(t, err) - assert.Equal(t, attackValue, ss.AttackValue) -} - -func TestJWTBearerSecurityScheme_GetAttackValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - result := ss.GetAttackValue() - - assert.NoError(t, err) - assert.Equal(t, attackValue, result) -} diff --git a/internal/auth/no_auth.go b/internal/auth/no_auth.go index 364d594..42c9634 100644 --- a/internal/auth/no_auth.go +++ b/internal/auth/no_auth.go @@ -1,61 +1,16 @@ package auth -import "net/http" +var defaultName = "no_auth" -type NoAuthSecurityScheme struct { - Name string `json:"name" yaml:"name"` - Type Type `json:"type" yaml:"type"` - Scheme SchemeName `json:"scheme" yaml:"scheme"` +func NewNoAuthSecurityScheme() (*SecurityScheme, error) { + return NewSecurityScheme(defaultName, nil, None, NoneScheme, nil, nil) } -var _ SecurityScheme = (*NoAuthSecurityScheme)(nil) - -func NewNoAuthSecurityScheme() *NoAuthSecurityScheme { - return &NoAuthSecurityScheme{ - Name: "", - Type: None, - Scheme: NoneScheme, +func MustNewNoAuthSecurityScheme() *SecurityScheme { + scheme, err := NewNoAuthSecurityScheme() + if err != nil { + panic(err) } -} - -func (ss *NoAuthSecurityScheme) GetType() Type { - return ss.Type -} - -func (ss *NoAuthSecurityScheme) GetScheme() SchemeName { - return ss.Scheme -} - -func (ss *NoAuthSecurityScheme) GetIn() *SchemeIn { - return nil -} - -func (ss *NoAuthSecurityScheme) GetName() string { - return ss.Name -} - -func (ss *NoAuthSecurityScheme) GetHeaders() http.Header { - return http.Header{} -} - -func (ss *NoAuthSecurityScheme) GetCookies() []*http.Cookie { - return []*http.Cookie{} -} - -func (ss *NoAuthSecurityScheme) HasValidValue() bool { - return false -} - -func (ss *NoAuthSecurityScheme) GetValidValue() interface{} { - return "" -} - -func (ss *NoAuthSecurityScheme) GetValidValueWriter() interface{} { - return "" -} - -func (ss *NoAuthSecurityScheme) SetAttackValue(v interface{}) {} -func (ss *NoAuthSecurityScheme) GetAttackValue() interface{} { - return nil + return scheme } diff --git a/internal/auth/no_auth_test.go b/internal/auth/no_auth_test.go index ed0abda..ffcafba 100644 --- a/internal/auth/no_auth_test.go +++ b/internal/auth/no_auth_test.go @@ -8,68 +8,22 @@ import ( ) func TestNewNoAuthSecurityScheme(t *testing.T) { - ss := auth.NewNoAuthSecurityScheme() - assert.NotNil(t, ss) -} - -func TestNoAuthSecurityScheme_GetScheme(t *testing.T) { - ss := auth.NewNoAuthSecurityScheme() - - scheme := ss.GetScheme() - - assert.Equal(t, auth.NoneScheme, scheme) -} - -func TestNoAuthSecurityScheme_GetType(t *testing.T) { - ss := auth.NewNoAuthSecurityScheme() - - scheme := ss.GetType() - - assert.Equal(t, auth.None, scheme) -} - -func TestNoAuthSecurityScheme_GetName(t *testing.T) { - ss := auth.NewNoAuthSecurityScheme() + securityScheme, err := auth.NewNoAuthSecurityScheme() - scheme := ss.GetName() - - assert.Equal(t, "", scheme) -} - -func TestNoAuthSecurityScheme_GetHeaders(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - headers := ss.GetHeaders() - assert.NotNil(t, headers) - assert.Empty(t, headers) -} - -func TestNoAuthSecurityScheme_GetCookies(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - cookies := ss.GetCookies() - assert.NotNil(t, cookies) - assert.Empty(t, cookies) + assert.NoError(t, err) + assert.NotNil(t, securityScheme) + assert.Equal(t, "no_auth", securityScheme.GetName()) + assert.Equal(t, auth.None, securityScheme.GetType()) + assert.Equal(t, auth.NoneScheme, securityScheme.GetScheme()) + assert.Nil(t, securityScheme.GetIn()) } -func TestNoAuthSecurityScheme_GetValidValue(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - validValue := ss.GetValidValue() - assert.Equal(t, "", validValue) -} - -func TestNoAuthSecurityScheme_GetValidValueWriter(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - validValueWriter := ss.GetValidValueWriter() - assert.Equal(t, "", validValueWriter) -} - -func TestNoAuthSecurityScheme_SetAttackValue(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - ss.SetAttackValue("attack value") - // No assertion as this method does not return anything -} +func TestMustNewNoAuthSecurityScheme(t *testing.T) { + securityScheme := auth.MustNewNoAuthSecurityScheme() -func TestNoAuthSecurityScheme_GetAttackValue(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - attackValue := ss.GetAttackValue() - assert.Nil(t, attackValue) + assert.NotNil(t, securityScheme) + assert.Equal(t, "no_auth", securityScheme.GetName()) + assert.Equal(t, auth.None, securityScheme.GetType()) + assert.Equal(t, auth.NoneScheme, securityScheme.GetScheme()) + assert.Nil(t, securityScheme.GetIn()) } diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index 90798ad..92f0165 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -1,9 +1,6 @@ package auth import ( - "fmt" - "net/http" - "github.com/cerberauth/vulnapi/jwt" ) @@ -23,98 +20,30 @@ type OAuthConfig struct { RefreshURL string } -type OAuthSecurityScheme struct { - Type Type `json:"type" yaml:"type"` - Scheme SchemeName `json:"scheme" yaml:"scheme"` - In SchemeIn `json:"in" yaml:"in"` - Name string `json:"name" yaml:"name"` - ValidValue *string `json:"-" yaml:"-"` - AttackValue string `json:"-" yaml:"-"` - - Config *OAuthConfig `json:"config" yaml:"config"` - JWTWriter *jwt.JWTWriter `json:"-" yaml:"-"` -} - -var _ SecurityScheme = (*OAuthSecurityScheme)(nil) - -func NewOAuthSecurityScheme(name string, value *string, cfg *OAuthConfig) *OAuthSecurityScheme { - var jwtWriter *jwt.JWTWriter - if value != nil { - jwtWriter, _ = jwt.NewJWTWriter(*value) +func NewOAuthSecurityScheme(name string, value *OAuthValue, config *OAuthConfig) (*SecurityScheme, error) { + tokenFormat := NoneTokenFormat + if value != nil && value.AccessToken != "" && jwt.IsJWT(value.AccessToken) { + tokenFormat = JWTTokenFormat } - return &OAuthSecurityScheme{ - Type: OAuth2, - Scheme: BearerScheme, - In: InHeader, - Name: name, - ValidValue: value, - JWTWriter: jwtWriter, - AttackValue: "", - - Config: cfg, + securityScheme, err := NewSecurityScheme(name, config, OAuth2, NoneScheme, nil, &tokenFormat) + if err != nil { + return nil, err } -} - -func (ss *OAuthSecurityScheme) GetType() Type { - return ss.Type -} - -func (ss *OAuthSecurityScheme) GetScheme() SchemeName { - return ss.Scheme -} - -func (ss *OAuthSecurityScheme) GetIn() *SchemeIn { - return &ss.In -} - -func (ss *OAuthSecurityScheme) GetName() string { - return ss.Name -} -func (ss *OAuthSecurityScheme) GetHeaders() http.Header { - header := http.Header{} - attackValue := ss.GetAttackValue().(string) - if attackValue == "" && ss.HasValidValue() { - attackValue = ss.GetValidValue().(string) + err = securityScheme.SetValidValue(value) + if err != nil { + return nil, err } - if attackValue != "" { - header.Set(AuthorizationHeader, fmt.Sprintf("%s %s", BearerPrefix, attackValue)) - } - - return header -} - -func (ss *OAuthSecurityScheme) GetCookies() []*http.Cookie { - return []*http.Cookie{} -} - -func (ss *OAuthSecurityScheme) HasValidValue() bool { - return ss.ValidValue != nil && *ss.ValidValue != "" -} - -func (ss *OAuthSecurityScheme) GetValidValue() interface{} { - if !ss.HasValidValue() { - return nil - } - - return *ss.ValidValue + return securityScheme, nil } -func (ss *OAuthSecurityScheme) GetValidValueWriter() interface{} { - return ss.JWTWriter -} - -func (ss *OAuthSecurityScheme) SetAttackValue(v interface{}) { - if v == nil { - ss.AttackValue = "" - return +func MustNewOAuthSecurityScheme(name string, value *OAuthValue, config *OAuthConfig) *SecurityScheme { + securityScheme, err := NewOAuthSecurityScheme(name, value, config) + if err != nil { + panic(err) } - ss.AttackValue = v.(string) -} - -func (ss *OAuthSecurityScheme) GetAttackValue() interface{} { - return ss.AttackValue + return securityScheme } diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go index 851dcee..d130cb6 100644 --- a/internal/auth/oauth_test.go +++ b/internal/auth/oauth_test.go @@ -1,217 +1,25 @@ package auth_test import ( - "net/http" "testing" "github.com/cerberauth/vulnapi/internal/auth" - "github.com/cerberauth/vulnapi/jwt" "github.com/stretchr/testify/assert" ) func TestNewOAuthSecurityScheme(t *testing.T) { name := "token" - value := "abc123" + value := auth.OAuthValue{ + AccessToken: "abc123", + } - ss := auth.NewOAuthSecurityScheme(name, &value, nil) + securityScheme, err := auth.NewOAuthSecurityScheme(name, &value, nil) - assert.Equal(t, auth.OAuth2, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) - assert.Nil(t, ss.JWTWriter) -} - -func TestNewOAuthSecurityScheme_WithJWT(t *testing.T) { - name := "token" - value := jwt.FakeJWT - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) - assert.NotNil(t, ss.JWTWriter) -} - -func TestOAuthSecurityScheme_GetScheme(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - scheme := ss.GetScheme() - - assert.Equal(t, auth.BearerScheme, scheme) -} - -func TestOAuthSecurityScheme_GetType(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - scheme := ss.GetType() - - assert.Equal(t, auth.OAuth2, scheme) -} - -func TestOAuthSecurityScheme_GetIn(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - scheme := ss.GetIn() - - assert.Equal(t, auth.InHeader, *scheme) -} - -func TestOAuthSecurityScheme_GetName(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - scheme := ss.GetName() - - assert.Equal(t, name, scheme) -} - -func TestNewOAuthSecurityScheme_GetHeaders(t *testing.T) { - name := "token" - value := "abc123" - attackValue := "xyz789" - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - ss.SetAttackValue(attackValue) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer xyz789"}, - }, headers) -} - -func TestNewOAuthSecurityScheme_GetHeaders_WhenNoAttackValue(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer abc123"}, - }, headers) -} - -func TestNewOAuthSecurityScheme_GetHeaders_WhenNoAttackAndValidValue(t *testing.T) { - name := "token" - ss := auth.NewOAuthSecurityScheme(name, nil, nil) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{}, headers) -} - -func TestNewOAuthSecurityScheme_GetCookies(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - cookies := ss.GetCookies() - - assert.Empty(t, cookies) -} - -func TestNewOAuthSecurityScheme_HasValidValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - result := ss.HasValidValue() - - assert.True(t, result) -} - -func TestNewOAuthSecurityScheme_HasValidValueFalse_WhenValueIsNil(t *testing.T) { - name := "token" - ss := auth.NewOAuthSecurityScheme(name, nil, nil) - - result := ss.HasValidValue() - - assert.False(t, result) -} - -func TestNewOAuthSecurityScheme_HasValidValueFalse_WhenValueIsEmptyString(t *testing.T) { - name := "token" - value := "" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - result := ss.HasValidValue() - - assert.False(t, result) -} - -func TestNewOAuthSecurityScheme_GetValidValueNil(t *testing.T) { - name := "token" - ss := auth.NewOAuthSecurityScheme(name, nil, nil) - - validValue := ss.GetValidValue() - - assert.Equal(t, nil, validValue) -} - -func TestNewOAuthSecurityScheme_GetValidValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - validValue := ss.GetValidValue() - - assert.Equal(t, value, validValue) -} - -func TestNewOAuthSecurityScheme_GetValidValueWriter(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - writer := ss.GetValidValueWriter() - - assert.Nil(t, writer) -} - -func TestNewOAuthSecurityScheme_GetValidValueWriter_WithJWT(t *testing.T) { - name := "token" - value := jwt.FakeJWT - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - writer := ss.GetValidValueWriter() - - assert.IsType(t, &jwt.JWTWriter{}, writer) -} - -func TestNewOAuthSecurityScheme_SetAttackValue(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - assert.Equal(t, attackValue, ss.AttackValue) -} - -func TestNewOAuthSecurityScheme_GetAttackValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - result := ss.GetAttackValue() - - assert.Equal(t, attackValue, result) + assert.NoError(t, err) + assert.Equal(t, auth.OAuth2, securityScheme.GetType()) + assert.Equal(t, auth.BearerScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, &value, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue().(string)) } diff --git a/internal/auth/scheme.go b/internal/auth/scheme.go index 1e6222e..62e6630 100644 --- a/internal/auth/scheme.go +++ b/internal/auth/scheme.go @@ -24,6 +24,14 @@ func (e *SchemeName) Type() string { type SchemeIn string const ( + InQuery SchemeIn = "query" InHeader SchemeIn = "header" InCookie SchemeIn = "cookie" ) + +type TokenFormat string + +const ( + JWTTokenFormat TokenFormat = "jwt" + NoneTokenFormat TokenFormat = "none" +) diff --git a/internal/auth/security_scheme.go b/internal/auth/security_scheme.go index 5d75223..3a4ee96 100644 --- a/internal/auth/security_scheme.go +++ b/internal/auth/security_scheme.go @@ -1,34 +1,203 @@ package auth import ( + "fmt" "net/http" + "time" + + "github.com/cerberauth/vulnapi/jwt" ) -type SecurityScheme interface { - GetType() Type - GetScheme() SchemeName - GetIn() *SchemeIn - GetName() string +type SecurityScheme struct { + Type Type `json:"type" yaml:"type"` + Scheme SchemeName `json:"scheme" yaml:"scheme"` + In *SchemeIn `json:"in" yaml:"in"` + TokenFormat *TokenFormat `json:"token_format" yaml:"token_format"` + + Name string `json:"name" yaml:"name"` + Config interface{} `json:"config" yaml:"config"` - GetHeaders() http.Header - GetCookies() []*http.Cookie - GetValidValue() interface{} - HasValidValue() bool - GetValidValueWriter() interface{} - SetAttackValue(v interface{}) - GetAttackValue() interface{} + ValidValue interface{} `json:"-" yaml:"-"` + AttackValue interface{} `json:"-" yaml:"-"` } type SecuritySchemesMap map[string]SecurityScheme -func GetSecuritySchemeUniqueName(securityScheme SecurityScheme) string { - if securityScheme == nil { - return "" +type InQueryValue = string +type InHeaderValue = string +type InCookieValue = http.Cookie + +type OAuthValue struct { + AccessToken string `json:"access_token" yaml:"access_token"` + RefreshToken *string `json:"refresh_token" yaml:"refresh_token"` + ExpiresIn *time.Time `json:"expires_in" yaml:"expires_in"` + Scope *string `json:"scope" yaml:"scope"` +} + +func NewSecurityScheme(name string, config interface{}, t Type, scheme SchemeName, in *SchemeIn, tokenFormat *TokenFormat) (*SecurityScheme, error) { + if in != nil && name == "" { + return nil, fmt.Errorf("name is required for security scheme with in %s", *in) + } + + if t == ApiKey && in == nil { + return nil, fmt.Errorf("in is required for security scheme with type %s", t) + } + + return &SecurityScheme{ + Name: name, + Config: config, + + Type: t, + Scheme: scheme, + In: in, + TokenFormat: tokenFormat, + }, nil +} + +func (securityScheme *SecurityScheme) GetType() Type { + return securityScheme.Type +} + +func (securityScheme *SecurityScheme) GetScheme() SchemeName { + return securityScheme.Scheme +} + +func (securityScheme *SecurityScheme) GetIn() *SchemeIn { + return securityScheme.In +} + +func (securityScheme *SecurityScheme) GetTokenFormat() *TokenFormat { + return securityScheme.TokenFormat +} + +func (securityScheme *SecurityScheme) GetName() string { + return securityScheme.Name +} + +func (securityScheme *SecurityScheme) GetConfig() interface{} { + return securityScheme.Config +} + +func (securityScheme *SecurityScheme) validateValue(value interface{}) error { + if value == nil { + return fmt.Errorf("value is required") + } + + switch securityScheme.GetType() { + case ApiKey: + if securityScheme.GetIn() == nil { + return fmt.Errorf("in is required for api key security scheme") + } + + var ok bool + switch *securityScheme.GetIn() { + case InQuery: + _, ok = value.(InQueryValue) + case InHeader: + _, ok = value.(InHeaderValue) + case InCookie: + _, ok = value.(InCookieValue) + } + + if !ok { + return fmt.Errorf("invalid value for api key security scheme") + } + return nil + + case HttpType: + val, ok := value.(string) + if !ok { + return fmt.Errorf("invalid value for http security scheme") + } + + if securityScheme.GetTokenFormat() != nil && *securityScheme.GetTokenFormat() == JWTTokenFormat { + if _, err := jwt.NewJWTWriter(val); err != nil { + return err + } + } + return nil + + case OAuth2: + _, ok := value.(*OAuthValue) + if !ok { + return fmt.Errorf("invalid value for oauth2 security scheme") + } + return nil + } + + return nil +} + +func (securityScheme *SecurityScheme) SetValidValue(value interface{}) error { + if value == nil { + securityScheme.ValidValue = nil + return nil + } + + if err := securityScheme.validateValue(value); err != nil { + return err + } + + securityScheme.ValidValue = &value + return nil +} + +func (securityScheme *SecurityScheme) GetValidValue() interface{} { + return securityScheme.ValidValue +} + +func (securityScheme *SecurityScheme) HasValidValue() bool { + return securityScheme.GetValidValue() != nil +} + +func (securityScheme *SecurityScheme) SetAttackValue(value interface{}) error { + if value == nil { + securityScheme.AttackValue = nil + return nil } - uniqueName := string(securityScheme.GetType()) + "-" + string(securityScheme.GetScheme()) - if securityScheme.GetIn() != nil { - uniqueName += "-" + string(*securityScheme.GetIn()) + if err := securityScheme.validateValue(value); err != nil { + return err + } + + securityScheme.AttackValue = &value + return nil +} + +func (securityScheme *SecurityScheme) GetAttackValue() interface{} { + return securityScheme.AttackValue +} + +func (securityScheme *SecurityScheme) GetHeaders() http.Header { + header := http.Header{} + + attackValue := securityScheme.GetAttackValue() + if attackValue == nil && securityScheme.HasValidValue() { + attackValue = securityScheme.GetValidValue().(string) + } + + if attackValue == nil { + return header + } + + if (securityScheme.GetType() == ApiKey || securityScheme.GetType() == HttpType) && *securityScheme.GetIn() == InHeader { + var val string + if securityScheme.GetType() == HttpType && securityScheme.GetScheme() == BearerScheme { + val = fmt.Sprintf("%s %s", BearerPrefix, attackValue) + } else { + val = fmt.Sprintf("%s", attackValue) + } + header.Set(securityScheme.GetName(), val) + } + + return header +} + +func (securityScheme *SecurityScheme) GetCookies() []*http.Cookie { + if securityScheme.GetIn() == nil || *securityScheme.GetIn() != InCookie { + return []*http.Cookie{} } - return uniqueName + cookies := []*http.Cookie{} + // TODO + return cookies } diff --git a/internal/auth/type.go b/internal/auth/type.go index 5c41fa3..20c52db 100644 --- a/internal/auth/type.go +++ b/internal/auth/type.go @@ -7,5 +7,6 @@ const ( OAuth2 Type = "oauth2" OpenIdConnect Type = "openIdConnect" ApiKey Type = "apiKey" + MutualTLS Type = "mutualTLS" None Type = "none" ) diff --git a/internal/auth/uniq_name.go b/internal/auth/uniq_name.go new file mode 100644 index 0000000..fee611a --- /dev/null +++ b/internal/auth/uniq_name.go @@ -0,0 +1,14 @@ +package auth + +func GetSecuritySchemeUniqueName(securityScheme *SecurityScheme) string { + if securityScheme == nil { + return "" + } + + uniqueName := string(securityScheme.GetType()) + "-" + string(securityScheme.GetScheme()) + if securityScheme.GetIn() != nil { + uniqueName += "-" + string(*securityScheme.GetIn()) + } + + return uniqueName +} diff --git a/internal/auth/security_scheme_test.go b/internal/auth/uniq_name_test.go similarity index 56% rename from internal/auth/security_scheme_test.go rename to internal/auth/uniq_name_test.go index 0861e66..b00c123 100644 --- a/internal/auth/security_scheme_test.go +++ b/internal/auth/uniq_name_test.go @@ -8,34 +8,24 @@ import ( ) func TestGetSecuritySchemeUniqueName(t *testing.T) { - noAuthSecurityScheme := auth.NewNoAuthSecurityScheme() - bearerSecurityScheme := auth.NewAuthorizationBearerSecurityScheme("name", nil) - jwtBearerSecurityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("name", nil) - oauthSecurityScheme := auth.NewOAuthSecurityScheme("name", nil, nil) - tests := []struct { name string - securityScheme auth.SecurityScheme + securityScheme *auth.SecurityScheme expected string }{ { name: "no auth security scheme", - securityScheme: noAuthSecurityScheme, + securityScheme: auth.MustNewNoAuthSecurityScheme(), expected: "none-None", }, { name: "bearer security scheme", - securityScheme: bearerSecurityScheme, - expected: "http-Bearer-header", - }, - { - name: "jwt bearer security scheme", - securityScheme: jwtBearerSecurityScheme, + securityScheme: auth.MustNewAuthorizationBearerSecurityScheme("name", nil), expected: "http-Bearer-header", }, { name: "oauth security scheme", - securityScheme: oauthSecurityScheme, + securityScheme: auth.MustNewOAuthSecurityScheme("name", &auth.OAuthValue{}, nil), expected: "oauth2-Bearer-header", }, } diff --git a/internal/operation/operation.go b/internal/operation/operation.go index 8bf9e8b..0bdd8e0 100644 --- a/internal/operation/operation.go +++ b/internal/operation/operation.go @@ -42,12 +42,12 @@ type Operation struct { OpenAPIDocPath *string `json:"-" yaml:"-"` ID string `json:"id" yaml:"id"` - Method string `json:"method" yaml:"method"` - URL url.URL `json:"url" yaml:"url"` - Body []byte `json:"body,omitempty" yaml:"body,omitempty"` - Cookies []*http.Cookie `json:"cookies,omitempty" yaml:"cookies,omitempty"` - Header http.Header `json:"header,omitempty" yaml:"header,omitempty"` - SecuritySchemes []auth.SecurityScheme `json:"securitySchemes" yaml:"securitySchemes"` + Method string `json:"method" yaml:"method"` + URL url.URL `json:"url" yaml:"url"` + Body []byte `json:"body,omitempty" yaml:"body,omitempty"` + Cookies []*http.Cookie `json:"cookies,omitempty" yaml:"cookies,omitempty"` + Header http.Header `json:"header,omitempty" yaml:"header,omitempty"` + SecuritySchemes []*auth.SecurityScheme `json:"securitySchemes" yaml:"securitySchemes"` } func getBody(body io.Reader) ([]byte, error) { @@ -88,7 +88,7 @@ func NewOperation(method string, operationUrl string, body io.Reader, client *re Body: bodyBuffer, Cookies: []*http.Cookie{}, Header: http.Header{}, - SecuritySchemes: []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()}, + SecuritySchemes: []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()}, }, nil } @@ -126,7 +126,7 @@ func NewOperationFromRequest(r *request.Request) (*Operation, error) { Cookies: r.GetCookies(), Body: r.GetBody(), - SecuritySchemes: []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()}, + SecuritySchemes: []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()}, }, nil } @@ -162,21 +162,21 @@ func (operation *Operation) NewRequest() (*request.Request, error) { return req, nil } -func (operation *Operation) GetSecuritySchemes() []auth.SecurityScheme { +func (operation *Operation) GetSecuritySchemes() []*auth.SecurityScheme { if operation.SecuritySchemes == nil { - return []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()} + return []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()} } return operation.SecuritySchemes } -func (operation *Operation) GetSecurityScheme() auth.SecurityScheme { +func (operation *Operation) GetSecurityScheme() *auth.SecurityScheme { if operation.SecuritySchemes == nil { - return auth.NewNoAuthSecurityScheme() + return auth.MustNewNoAuthSecurityScheme() } return operation.SecuritySchemes[0] } -func (operation *Operation) SetSecuritySchemes(securitySchemes []auth.SecurityScheme) *Operation { +func (operation *Operation) SetSecuritySchemes(securitySchemes []*auth.SecurityScheme) *Operation { operation.SecuritySchemes = securitySchemes return operation } @@ -204,9 +204,9 @@ func (operation *Operation) GetID() string { } func (o *Operation) Clone() (*Operation, error) { - var clonedSecuritySchemes []auth.SecurityScheme + var clonedSecuritySchemes []*auth.SecurityScheme if o.SecuritySchemes != nil { - clonedSecuritySchemes = make([]auth.SecurityScheme, len(o.SecuritySchemes)) + clonedSecuritySchemes = make([]*auth.SecurityScheme, len(o.SecuritySchemes)) copy(clonedSecuritySchemes, o.SecuritySchemes) } diff --git a/internal/operation/operation_test.go b/internal/operation/operation_test.go index 3431301..a8bf3d2 100644 --- a/internal/operation/operation_test.go +++ b/internal/operation/operation_test.go @@ -140,7 +140,7 @@ func TestNewOperationFromRequest_WithBody(t *testing.T) { func TestOperation_GetSecurityScheme(t *testing.T) { t.Run("NoSecuritySchemes", func(t *testing.T) { operation := &operation.Operation{} - expectedScheme := auth.NewNoAuthSecurityScheme() + expectedScheme := auth.MustNewNoAuthSecurityScheme() scheme := operation.GetSecurityScheme() @@ -148,9 +148,9 @@ func TestOperation_GetSecurityScheme(t *testing.T) { }) t.Run("WithSecuritySchemes", func(t *testing.T) { - expectedScheme := auth.NewNoAuthSecurityScheme() + expectedScheme := auth.MustNewNoAuthSecurityScheme() operation := &operation.Operation{ - SecuritySchemes: []auth.SecurityScheme{expectedScheme}, + SecuritySchemes: []*auth.SecurityScheme{expectedScheme}, } scheme := operation.GetSecurityScheme() @@ -160,7 +160,7 @@ func TestOperation_GetSecurityScheme(t *testing.T) { } func TestOperationCloneWithSecuritySchemes(t *testing.T) { - securitySchemes := []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()} + securitySchemes := []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()} operation := operation.MustNewOperation(http.MethodGet, "http://example.com", nil, nil) operation.SetSecuritySchemes(securitySchemes) diff --git a/internal/request/request.go b/internal/request/request.go index a9f77f3..902ae81 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -74,7 +74,7 @@ func (r *Request) WithCookies(cookies []*http.Cookie) *Request { return r } -func (r *Request) WithSecurityScheme(securityScheme auth.SecurityScheme) *Request { +func (r *Request) WithSecurityScheme(securityScheme *auth.SecurityScheme) *Request { if securityScheme.GetCookies() != nil { r.WithCookies(securityScheme.GetCookies()) } diff --git a/internal/request/request_test.go b/internal/request/request_test.go index fa1eb14..647943e 100644 --- a/internal/request/request_test.go +++ b/internal/request/request_test.go @@ -59,7 +59,7 @@ func TestWithSecurityScheme(t *testing.T) { method := http.MethodGet url := "http://localhost:8080/" token := "token" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) request, err := request.NewRequest(method, url, nil, nil) request = request.WithSecurityScheme(securityScheme) @@ -346,7 +346,7 @@ func TestDoWithSecuritySchemeHeaders(t *testing.T) { method := http.MethodGet url := "http://localhost:8080/" token := "token" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) request, _ := request.NewRequest(method, url, nil, client) request.WithSecurityScheme(securityScheme) @@ -376,7 +376,7 @@ func TestDoWithHeadersSecuritySchemeHeaders(t *testing.T) { "Authorization": []string{"Bearer othertoken"}, } token := "token" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) request, _ := request.NewRequest(method, url, nil, client) request = request.WithHeader(header) request = request.WithSecurityScheme(securityScheme) @@ -408,7 +408,7 @@ func TestDoWithCookiesSecuritySchemeHeaders(t *testing.T) { Value: "value1", }} token := "token" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) request, _ := request.NewRequest(method, url, nil, client) request = request.WithCookies(cookies) request = request.WithSecurityScheme(securityScheme) diff --git a/jwt/jwt.go b/jwt/jwt.go new file mode 100644 index 0000000..41d0b30 --- /dev/null +++ b/jwt/jwt.go @@ -0,0 +1,19 @@ +package jwt + +import ( + "regexp" + + "github.com/golang-jwt/jwt/v5" +) + +var jwtRegexp = `^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*$` + +func IsJWT(token string) bool { + matched, err := regexp.MatchString(jwtRegexp, token) + if err != nil || !matched { + return false + } + + _, _, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) + return err == nil +} diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go new file mode 100644 index 0000000..2a33b89 --- /dev/null +++ b/jwt/jwt_test.go @@ -0,0 +1,26 @@ +package jwt_test + +import ( + "testing" + + "github.com/cerberauth/vulnapi/jwt" +) + +func TestIsJWT(t *testing.T) { + tests := []struct { + token string + expected bool + }{ + {"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.Gfx6VO9tcxwk6xqx9yYzSfebfeakZp5JYIgP_edcw_A", true}, + {"invalid.jwt.token", false}, + {"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.", true}, + {"", false}, + } + + for _, test := range tests { + result := jwt.IsJWT(test.token) + if result != test.expected { + t.Errorf("IsJWT(%q) = %v; want %v", test.token, result, test.expected) + } + } +} diff --git a/openapi/openapi.go b/openapi/openapi.go index 55a3228..95fb9bf 100644 --- a/openapi/openapi.go +++ b/openapi/openapi.go @@ -4,8 +4,11 @@ import ( "net/url" "github.com/getkin/kin-openapi/openapi3" + "go.opentelemetry.io/otel" ) +var tracer = otel.Tracer("openapi") + type OpenAPI struct { baseUrl *url.URL diff --git a/openapi/security_scheme.go b/openapi/security_scheme.go index 2dcc29b..4fa067f 100644 --- a/openapi/security_scheme.go +++ b/openapi/security_scheme.go @@ -1,11 +1,13 @@ package openapi import ( + "context" "fmt" "strings" "github.com/cerberauth/vulnapi/internal/auth" "github.com/getkin/kin-openapi/openapi3" + "go.opentelemetry.io/otel/codes" ) const ( @@ -85,7 +87,14 @@ func mapOAuth2SchemeType(name string, scheme *openapi3.SecuritySchemeRef, securi return auth.NewOAuthSecurityScheme(name, securitySchemeValue, cfg), nil } +func mapAPIKeySchemeType(name string, scheme *openapi3.SecuritySchemeRef, securitySchemeValue *string) (auth.SecurityScheme, error) { + return auth.NewApiKeySecurityScheme(name, securitySchemeValue), nil +} + func (openapi *OpenAPI) SecuritySchemeMap(values *auth.SecuritySchemeValues) (auth.SecuritySchemesMap, error) { + _, span := tracer.Start(context.Background(), "SecuritySchemeMap") + defer span.End() + var err error var securitySchemeValue interface{} @@ -108,11 +117,15 @@ func (openapi *OpenAPI) SecuritySchemeMap(values *auth.SecuritySchemeValues) (au securitySchemes[name], err = mapHTTPSchemeType(name, scheme, value) case OAuth2SchemeType, OpenIdConnectSchemeType: securitySchemes[name], err = mapOAuth2SchemeType(name, scheme, value) + case ApiKeySchemeType: + securitySchemes[name], err = mapAPIKeySchemeType(name, scheme, value) default: err = NewErrUnsupportedSecuritySchemeType(schemeType) } if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, err } } diff --git a/report/curl_report.go b/report/curl_report.go index c7f0068..72a4eed 100644 --- a/report/curl_report.go +++ b/report/curl_report.go @@ -18,10 +18,10 @@ type CurlReport struct { Issues []*IssueReport `json:"issues" yaml:"issues"` } -func NewCurlReport(method string, url string, data interface{}, header http.Header, cookies []*http.Cookie, securitySchemes []auth.SecurityScheme) *CurlReport { +func NewCurlReport(method string, url string, data interface{}, header http.Header, cookies []*http.Cookie, securitySchemes []*auth.SecurityScheme) *CurlReport { reportSecuritySchemes := []OperationSecurityScheme{} - for _, ss := range securitySchemes { - reportSecuritySchemes = append(reportSecuritySchemes, NewOperationSecurityScheme(ss)) + for _, securityScheme := range securitySchemes { + reportSecuritySchemes = append(reportSecuritySchemes, NewOperationSecurityScheme(securityScheme)) } return &CurlReport{ diff --git a/report/issue_report.go b/report/issue_report.go index 1499c12..8ef51a1 100644 --- a/report/issue_report.go +++ b/report/issue_report.go @@ -35,7 +35,7 @@ type IssueReport struct { Status IssueReportStatus `json:"status" yaml:"status"` Operation *operation.Operation `json:"-" yaml:"-"` - SecurityScheme auth.SecurityScheme `json:"-" yaml:"-"` + SecurityScheme *auth.SecurityScheme `json:"-" yaml:"-"` } func NewIssueReport(issue Issue) *IssueReport { @@ -50,8 +50,8 @@ func (vr *IssueReport) WithOperation(operation *operation.Operation) *IssueRepor return vr } -func (vr *IssueReport) WithSecurityScheme(ss auth.SecurityScheme) *IssueReport { - vr.SecurityScheme = ss +func (vr *IssueReport) WithSecurityScheme(securityScheme *auth.SecurityScheme) *IssueReport { + vr.SecurityScheme = securityScheme return vr } diff --git a/report/report.go b/report/report.go index 871f732..4c39c80 100644 --- a/report/report.go +++ b/report/report.go @@ -11,18 +11,22 @@ import ( ) type OperationSecurityScheme struct { - Type auth.Type `json:"type" yaml:"type"` - Scheme auth.SchemeName `json:"scheme" yaml:"scheme"` - In *auth.SchemeIn `json:"in,omitempty" yaml:"in,omitempty"` - Name string `json:"name" yaml:"name"` + Type auth.Type `json:"type" yaml:"type"` + Scheme auth.SchemeName `json:"scheme" yaml:"scheme"` + In *auth.SchemeIn `json:"in" yaml:"in"` + TokenFormat *auth.TokenFormat `json:"token_format" yaml:"token_format"` + + Name string `json:"name" yaml:"name"` } -func NewOperationSecurityScheme(ss auth.SecurityScheme) OperationSecurityScheme { +func NewOperationSecurityScheme(securityScheme *auth.SecurityScheme) OperationSecurityScheme { return OperationSecurityScheme{ - Type: ss.GetType(), - Scheme: ss.GetScheme(), - In: ss.GetIn(), - Name: ss.GetName(), + Type: securityScheme.GetType(), + Scheme: securityScheme.GetScheme(), + In: securityScheme.GetIn(), + TokenFormat: securityScheme.GetTokenFormat(), + + Name: securityScheme.GetName(), } } diff --git a/report/reporter.go b/report/reporter.go index 7290521..617959e 100644 --- a/report/reporter.go +++ b/report/reporter.go @@ -29,7 +29,7 @@ func NewReporter() *Reporter { } } -func NewReporterWithCurl(method string, url string, data interface{}, header http.Header, cookies []*http.Cookie, securitySchemes []auth.SecurityScheme) *Reporter { +func NewReporterWithCurl(method string, url string, data interface{}, header http.Header, cookies []*http.Cookie, securitySchemes []*auth.SecurityScheme) *Reporter { return &Reporter{ Schema: reporterSchema, diff --git a/report/reporter_test.go b/report/reporter_test.go index 75c6c91..baf3807 100644 --- a/report/reporter_test.go +++ b/report/reporter_test.go @@ -19,8 +19,8 @@ func TestNewReporterWithCurl(t *testing.T) { header := http.Header{"Content-Type": []string{"application/json"}} cookies := []*http.Cookie{{Name: "session_id", Value: "abc123"}} token := "abc123" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) - securitySchemes := []auth.SecurityScheme{securityScheme} + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) + securitySchemes := []*auth.SecurityScheme{securityScheme} reportSecuritySchemes := []report.OperationSecurityScheme{ { @@ -51,8 +51,8 @@ func TestNewReporterWithCurl_AddReport(t *testing.T) { header := http.Header{"Content-Type": []string{"application/json"}} cookies := []*http.Cookie{{Name: "session_id", Value: "abc123"}} token := "abc123" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) - securitySchemes := []auth.SecurityScheme{securityScheme} + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) + securitySchemes := []*auth.SecurityScheme{securityScheme} reporter := report.NewReporterWithCurl(method, url, data, header, cookies, securitySchemes) diff --git a/scan/graphql/introspection_enabled/introspection_enabled.go b/scan/graphql/introspection_enabled/introspection_enabled.go index 42a82f7..3f0e62a 100644 --- a/scan/graphql/introspection_enabled/introspection_enabled.go +++ b/scan/graphql/introspection_enabled/introspection_enabled.go @@ -60,8 +60,8 @@ func newGetGraphqlIntrospectionRequest(client *request.Client, endpoint url.URL) return req, nil } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { - securitySchemes := []auth.SecurityScheme{securityScheme} +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { + securitySchemes := []*auth.SecurityScheme{securityScheme} vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(GraphqlIntrospectionScanID, GraphqlIntrospectionScanName, op)