From 209e04b4d00113e1f7f7f4c39aa35d9bd0bb4e17 Mon Sep 17 00:00:00 2001 From: Emmanuel Gautier Date: Mon, 11 Nov 2024 19:46:18 +0100 Subject: [PATCH] feat: improve and add support for more openapi security schemes --- api/openapi.go | 9 +- cmd/scan/openapi.go | 9 +- internal/auth/api_key.go | 27 ++ internal/auth/api_key_test.go | 61 +++ internal/auth/bearer.go | 99 ++-- internal/auth/bearer_test.go | 196 ++------ internal/auth/jwt_bearer.go | 108 ----- internal/auth/jwt_bearer_test.go | 233 --------- internal/auth/no_auth.go | 61 +-- internal/auth/no_auth_test.go | 74 +-- internal/auth/oauth.go | 136 +++--- internal/auth/oauth_test.go | 238 +++------ internal/auth/scheme.go | 8 + internal/auth/security_scheme.go | 225 ++++++++- internal/auth/security_scheme_test.go | 454 +++++++++++++++++- internal/auth/type.go | 1 + internal/auth/uniq_name.go | 14 + internal/auth/uniq_name_test.go | 40 ++ internal/operation/operation.go | 30 +- internal/operation/operation_test.go | 8 +- internal/request/request.go | 10 +- internal/request/request_test.go | 8 +- internal/scan/scan_url.go | 2 +- jwt/jwt.go | 23 + jwt/jwt_test.go | 26 + jwt/jwt_writer.go | 7 +- openapi/openapi.go | 3 + openapi/operation.go | 4 +- openapi/param_test.go | 91 ++-- openapi/security_scheme.go | 41 +- openapi/security_scheme_test.go | 101 ++-- .../security_scheme_values.go | 2 +- .../security_scheme_values_test.go | 18 +- report.json | 1 + report/curl_report.go | 6 +- report/curl_report_test.go | 8 +- report/graphql_report.go | 2 +- report/issue_report.go | 6 +- report/issue_report_test.go | 2 +- report/openapi_report_test.go | 9 +- report/report.go | 22 +- report/report_test.go | 29 +- report/reporter.go | 4 +- report/reporter_test.go | 34 +- .../authentication_bypass.go | 8 +- .../authentication_bypass_test.go | 6 +- .../jwt/alg_none/alg_none.go | 50 +- .../jwt/alg_none/alg_none_test.go | 26 +- .../jwt/blank_secret/blank_secret.go | 27 +- .../jwt/blank_secret/blank_secret_test.go | 10 +- .../jwt/not_verified/not_verified.go | 33 +- .../jwt/not_verified/not_verified_test.go | 12 +- .../jwt/null_signature/null_signature.go | 27 +- .../jwt/null_signature/null_signature_test.go | 8 +- .../jwt/weak_secret/weak_secret.go | 22 +- .../jwt/weak_secret/weak_secret_test.go | 14 +- .../accept_unauthenticated_operation.go | 5 +- .../accept_unauthenticated_operation_test.go | 4 +- .../discoverable_graphql.go | 2 +- .../discoverable_graphql_test.go | 4 +- .../discoverable_openapi.go | 2 +- .../discoverable_openapi_test.go | 4 +- scan/discover/fingerprint/fingerprint.go | 4 +- scan/discover/fingerprint/fingerprint_test.go | 18 +- scan/discover/utils.go | 10 +- scan/discover/utils_test.go | 6 +- .../introspection_enabled.go | 8 +- .../introspection_enabled_test.go | 8 +- .../http_cookies/http_cookies.go | 4 +- .../http_cookies/http_cookies_test.go | 14 +- .../http_headers/http_headers.go | 4 +- .../http_headers/http_headers_test.go | 18 +- .../http_method_override.go | 15 +- .../http_method_override_test.go | 20 +- .../http_trace/http_trace_method.go | 4 +- .../http_trace/http_trace_method_test.go | 4 +- .../http_track/http_track_method.go | 4 +- .../http_track/http_track_method_test.go | 4 +- scan/operation_scan.go | 2 +- scan/operation_scan_test.go | 2 +- scan/scan_test.go | 18 +- scenario/graphql.go | 6 +- scenario/graphql_test.go | 10 +- scenario/openapi.go | 3 +- scenario/openapi_test.go | 18 +- scenario/url.go | 6 +- scenario/url_test.go | 8 +- scenario/utils.go | 10 +- 88 files changed, 1542 insertions(+), 1440 deletions(-) create mode 100644 internal/auth/api_key.go create mode 100644 internal/auth/api_key_test.go delete mode 100644 internal/auth/jwt_bearer.go delete mode 100644 internal/auth/jwt_bearer_test.go create mode 100644 internal/auth/uniq_name.go create mode 100644 internal/auth/uniq_name_test.go create mode 100644 jwt/jwt.go create mode 100644 jwt/jwt_test.go rename {internal/auth => openapi}/security_scheme_values.go (98%) rename {internal/auth => openapi}/security_scheme_values_test.go (72%) create mode 100644 report.json diff --git a/api/openapi.go b/api/openapi.go index a23e9232..c500d3cc 100644 --- a/api/openapi.go +++ b/api/openapi.go @@ -5,7 +5,6 @@ import ( "net/http" "github.com/cerberauth/vulnapi/internal/analytics" - "github.com/cerberauth/vulnapi/internal/auth" "github.com/cerberauth/vulnapi/internal/request" "github.com/cerberauth/vulnapi/openapi" "github.com/cerberauth/vulnapi/scan" @@ -33,7 +32,7 @@ func (h *Handler) ScanOpenAPI(ctx *gin.Context) { traceCtx, span := tracer.Start(ctx, "Scan OpenAPI") defer span.End() - openapi, err := openapi.LoadFromData(traceCtx, []byte(form.Schema)) + doc, err := openapi.LoadFromData(traceCtx, []byte(form.Schema)) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) @@ -41,7 +40,7 @@ func (h *Handler) ScanOpenAPI(ctx *gin.Context) { return } - if err := openapi.Validate(ctx); err != nil { + if err := doc.Validate(ctx); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -59,8 +58,8 @@ func (h *Handler) ScanOpenAPI(ctx *gin.Context) { values[key] = &value.Value } } - securitySchemesValues := auth.NewSecuritySchemeValues(values) - s, err := scenario.NewOpenAPIScan(openapi, securitySchemesValues, client, &scan.ScanOptions{ + securitySchemesValues := openapi.NewSecuritySchemeValues(values) + s, err := scenario.NewOpenAPIScan(doc, securitySchemesValues, client, &scan.ScanOptions{ IncludeScans: form.Opts.Scans, ExcludeScans: form.Opts.ExcludeScans, }) diff --git a/cmd/scan/openapi.go b/cmd/scan/openapi.go index 8c78fdd3..bfc15e7b 100644 --- a/cmd/scan/openapi.go +++ b/cmd/scan/openapi.go @@ -6,7 +6,6 @@ import ( "os" "github.com/cerberauth/vulnapi/internal/analytics" - "github.com/cerberauth/vulnapi/internal/auth" internalCmd "github.com/cerberauth/vulnapi/internal/cmd" "github.com/cerberauth/vulnapi/internal/request" "github.com/cerberauth/vulnapi/openapi" @@ -47,14 +46,14 @@ func NewOpenAPIScanCmd() (scanCmd *cobra.Command) { ctx, span := tracer.Start(cmd.Context(), "Scan OpenAPI") defer span.End() - openapi, err := openapi.LoadOpenAPI(ctx, openapiUrlOrPath) + doc, err := openapi.LoadOpenAPI(ctx, openapiUrlOrPath) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) log.Fatal(err) } - if err := openapi.Validate(ctx); err != nil { + if err := doc.Validate(ctx); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) log.Fatal(err) @@ -69,7 +68,7 @@ func NewOpenAPIScanCmd() (scanCmd *cobra.Command) { for key, value := range securitySchemesValueArg { values[key] = &value } - securitySchemesValues := auth.NewSecuritySchemeValues(values).WithDefault(validToken) + securitySchemesValues := openapi.NewSecuritySchemeValues(values).WithDefault(validToken) client, err := internalCmd.NewHTTPClientFromArgs(internalCmd.GetRateLimit(), internalCmd.GetProxy(), internalCmd.GetHeaders(), internalCmd.GetCookies()) if err != nil { @@ -79,7 +78,7 @@ func NewOpenAPIScanCmd() (scanCmd *cobra.Command) { } request.SetDefaultClient(client) - s, err := scenario.NewOpenAPIScan(openapi, securitySchemesValues, client, &scan.ScanOptions{ + s, err := scenario.NewOpenAPIScan(doc, securitySchemesValues, client, &scan.ScanOptions{ IncludeScans: internalCmd.GetIncludeScans(), ExcludeScans: internalCmd.GetExcludeScans(), }) diff --git a/internal/auth/api_key.go b/internal/auth/api_key.go new file mode 100644 index 00000000..21fdf734 --- /dev/null +++ b/internal/auth/api_key.go @@ -0,0 +1,27 @@ +package auth + +func NewAPIKeySecurityScheme(name string, in SchemeIn, value *string) (*SecurityScheme, error) { + tokenFormat := NoneTokenFormat + securityScheme, err := NewSecurityScheme(name, nil, ApiKey, NoneScheme, &in, &tokenFormat) + if err != nil { + return nil, err + } + + if value != nil && *value != "" { + err = securityScheme.SetValidValue(*value) + if err != nil { + return nil, err + } + } + + return securityScheme, nil +} + +func MustNewAPIKeySecurityScheme(name string, in SchemeIn, value *string) *SecurityScheme { + securityScheme, err := NewAPIKeySecurityScheme(name, in, value) + if err != nil { + panic(err) + } + + return securityScheme +} diff --git a/internal/auth/api_key_test.go b/internal/auth/api_key_test.go new file mode 100644 index 00000000..d8f270c6 --- /dev/null +++ b/internal/auth/api_key_test.go @@ -0,0 +1,61 @@ +package auth_test + +import ( + "testing" + + "github.com/cerberauth/vulnapi/internal/auth" + "github.com/stretchr/testify/assert" +) + +func TestNewAPIKeySecurityScheme(t *testing.T) { + name := "token" + value := "abc123" + tokenFormat := auth.NoneTokenFormat + + securityScheme, err := auth.NewAPIKeySecurityScheme(name, auth.InHeader, &value) + + assert.NoError(t, err) + assert.Equal(t, auth.ApiKey, securityScheme.GetType()) + assert.Equal(t, auth.NoneScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, value, securityScheme.GetValidValue().(string)) + assert.Equal(t, nil, securityScheme.GetAttackValue()) +} + +func TestTestNewAPIKeySecurityScheme_WhenNilValue(t *testing.T) { + name := "token" + + securityScheme, err := auth.NewAPIKeySecurityScheme(name, auth.InHeader, nil) + + assert.NoError(t, err) + assert.Equal(t, nil, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue()) +} + +func TestNewAuthorizationBearerSecurityScheme_WhenInCooke(t *testing.T) { + name := "token" + value := "abc123" + + securityScheme, err := auth.NewAPIKeySecurityScheme(name, auth.InQuery, &value) + + assert.NoError(t, err) + assert.Equal(t, auth.InQuery, *securityScheme.GetIn()) +} + +func TestMustNewAPIKeySecurityScheme(t *testing.T) { + name := "token" + value := "abc123" + tokenFormat := auth.NoneTokenFormat + + securityScheme := auth.MustNewAPIKeySecurityScheme(name, auth.InHeader, &value) + + assert.Equal(t, auth.ApiKey, securityScheme.GetType()) + assert.Equal(t, auth.NoneScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, value, securityScheme.GetValidValue().(string)) + assert.Equal(t, nil, securityScheme.GetAttackValue()) +} diff --git a/internal/auth/bearer.go b/internal/auth/bearer.go index 923ebd45..06206f37 100644 --- a/internal/auth/bearer.go +++ b/internal/auth/bearer.go @@ -1,86 +1,41 @@ package auth import ( - "fmt" - "net/http" + "github.com/cerberauth/vulnapi/jwt" ) -type BearerSecurityScheme struct { - Type Type `json:"type" yaml:"type"` - Scheme SchemeName `json:"scheme" yaml:"scheme"` - In SchemeIn `json:"in" yaml:"in"` - Name string `json:"name" yaml:"name"` - ValidValue *string `json:"-" yaml:"-"` - AttackValue string `json:"-" yaml:"-"` -} - -var _ SecurityScheme = (*BearerSecurityScheme)(nil) - -func NewAuthorizationBearerSecurityScheme(name string, value *string) *BearerSecurityScheme { - return &BearerSecurityScheme{ - Type: HttpType, - Scheme: BearerScheme, - In: InHeader, - Name: name, - ValidValue: value, - AttackValue: "", - } -} - -func (ss *BearerSecurityScheme) GetType() Type { - return ss.Type -} - -func (ss *BearerSecurityScheme) GetScheme() SchemeName { - return ss.Scheme -} - -func (ss *BearerSecurityScheme) GetIn() *SchemeIn { - return &ss.In -} - -func (ss *BearerSecurityScheme) GetName() string { - return ss.Name -} - -func (ss *BearerSecurityScheme) GetHeaders() http.Header { - header := http.Header{} - attackValue := ss.GetAttackValue().(string) - if attackValue == "" && ss.HasValidValue() { - attackValue = ss.GetValidValue().(string) +func NewAuthorizationBearerSecurityScheme(name string, value *string) (*SecurityScheme, error) { + in := InHeader + securityScheme, err := NewSecurityScheme(name, nil, HttpType, BearerScheme, &in, nil) + if err != nil { + return nil, err } - if attackValue != "" { - header.Set(AuthorizationHeader, fmt.Sprintf("%s %s", BearerPrefix, attackValue)) + if value != nil && *value != "" { + err = securityScheme.SetValidValue(*value) + if err != nil { + return nil, err + } + + var tokenFormat TokenFormat + if jwt.IsJWT(*value) { + tokenFormat = JWTTokenFormat + } else { + tokenFormat = NoneTokenFormat + } + if err = securityScheme.SetTokenFormat(tokenFormat); err != nil { + return nil, err + } } - return header -} - -func (ss *BearerSecurityScheme) GetCookies() []*http.Cookie { - return []*http.Cookie{} -} - -func (ss *BearerSecurityScheme) HasValidValue() bool { - return ss.ValidValue != nil && *ss.ValidValue != "" + return securityScheme, nil } -func (ss *BearerSecurityScheme) GetValidValue() interface{} { - if !ss.HasValidValue() { - return nil +func MustNewAuthorizationBearerSecurityScheme(name string, value *string) *SecurityScheme { + securityScheme, err := NewAuthorizationBearerSecurityScheme(name, value) + if err != nil { + panic(err) } - return *ss.ValidValue -} - -func (ss *BearerSecurityScheme) GetValidValueWriter() interface{} { - return nil -} - -func (ss *BearerSecurityScheme) SetAttackValue(v interface{}) { - ss.AttackValue = v.(string) -} - -func (ss *BearerSecurityScheme) GetAttackValue() interface{} { - return ss.AttackValue + return securityScheme } diff --git a/internal/auth/bearer_test.go b/internal/auth/bearer_test.go index d664a7b5..32cd8e35 100644 --- a/internal/auth/bearer_test.go +++ b/internal/auth/bearer_test.go @@ -1,7 +1,6 @@ package auth_test import ( - "net/http" "testing" "github.com/cerberauth/vulnapi/internal/auth" @@ -11,181 +10,58 @@ import ( func TestNewAuthorizationBearerSecurityScheme(t *testing.T) { name := "token" value := "abc123" + tokenFormat := auth.NoneTokenFormat - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) + securityScheme, err := auth.NewAuthorizationBearerSecurityScheme(name, &value) - assert.Equal(t, auth.HttpType, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) + assert.NoError(t, err) + assert.Equal(t, auth.HttpType, securityScheme.GetType()) + assert.Equal(t, auth.BearerScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, value, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue()) } -func TestBearerSecurityScheme_GetScheme(t *testing.T) { +func TestNewAuthorizationBearerSecurityScheme_WhenNilValue(t *testing.T) { name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - scheme := ss.GetScheme() - - assert.Equal(t, auth.BearerScheme, scheme) -} - -func TestBearerSecurityScheme_GetType(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - scheme := ss.GetType() - - assert.Equal(t, auth.HttpType, scheme) -} - -func TestBearerSecurityScheme_GetIn(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - scheme := ss.GetIn() - - assert.Equal(t, auth.InHeader, *scheme) -} - -func TestBearerSecurityScheme_GetName(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - scheme := ss.GetName() - - assert.Equal(t, name, scheme) -} - -func TestBearerSecurityScheme_GetHeaders(t *testing.T) { - name := "token" - value := "abc123" - attackValue := "xyz789" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - ss.SetAttackValue(attackValue) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer xyz789"}, - }, headers) -} - -func TestBearerSecurityScheme_GetHeaders_WhenNoAttackValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer abc123"}, - }, headers) -} - -func TestBearerSecurityScheme_GetHeaders_WhenNoAttackAndValidValue(t *testing.T) { - name := "token" - ss := auth.NewAuthorizationBearerSecurityScheme(name, nil) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{}, headers) -} -func TestBearerSecurityScheme_GetCookies(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - cookies := ss.GetCookies() - - assert.Empty(t, cookies) -} - -func TestBearerSecurityScheme_HasValidValue_WhenValueIsNil(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - result := ss.HasValidValue() - - assert.True(t, result) -} - -func TestBearerSecurityScheme_HasValidValueFalse_WhenValueIsEmptyString(t *testing.T) { - name := "token" - value := "" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - result := ss.HasValidValue() - - assert.False(t, result) -} - -func TestBearerSecurityScheme_GetValidValueNil(t *testing.T) { - name := "token" - ss := auth.NewAuthorizationBearerSecurityScheme(name, nil) - - validValue := ss.GetValidValue() + securityScheme, err := auth.NewAuthorizationBearerSecurityScheme(name, nil) - assert.Equal(t, nil, validValue) + assert.NoError(t, err) + assert.Nil(t, securityScheme.GetTokenFormat()) + assert.Equal(t, nil, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue()) } -func TestBearerSecurityScheme_HasValidValueFalse(t *testing.T) { +func TestNewAuthorizationBearerSecurityScheme_WhenJWTFormatValue(t *testing.T) { name := "token" - ss := auth.NewAuthorizationBearerSecurityScheme(name, nil) + value := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.ufhxDTmrs4T5MSsvT6lsb3OpdWi5q8O31VX7TgrVamA" + tokenFormat := auth.JWTTokenFormat - result := ss.HasValidValue() + securityScheme, err := auth.NewAuthorizationBearerSecurityScheme(name, &value) - assert.False(t, result) + assert.NoError(t, err) + assert.Equal(t, auth.HttpType, securityScheme.GetType()) + assert.Equal(t, auth.BearerScheme, securityScheme.GetScheme()) + assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat()) + assert.Equal(t, value, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue()) } -func TestBearerSecurityScheme_GetValidValue(t *testing.T) { +func TestMustNewAuthorizationBearerSecurityScheme(t *testing.T) { name := "token" value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - validValue := ss.GetValidValue() - - assert.Equal(t, value, validValue) -} - -func TestBearerSecurityScheme_GetValidValueWriter(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - writer := ss.GetValidValueWriter() - - assert.Equal(t, nil, writer) -} - -func TestBearerSecurityScheme_SetAttackValue(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - assert.Equal(t, attackValue, ss.AttackValue) -} - -func TestBearerSecurityScheme_GetAttackValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewAuthorizationBearerSecurityScheme(name, &value) - - attackValue := "xyz789" - ss.SetAttackValue(attackValue) + tokenFormat := auth.NoneTokenFormat - result := ss.GetAttackValue() + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme(name, &value) - assert.Equal(t, attackValue, result) + assert.Equal(t, auth.HttpType, securityScheme.GetType()) + assert.Equal(t, auth.BearerScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, value, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue()) } diff --git a/internal/auth/jwt_bearer.go b/internal/auth/jwt_bearer.go deleted file mode 100644 index 82fe6c28..00000000 --- a/internal/auth/jwt_bearer.go +++ /dev/null @@ -1,108 +0,0 @@ -package auth - -import ( - "fmt" - "net/http" - - "github.com/cerberauth/vulnapi/jwt" -) - -type JWTBearerSecurityScheme struct { - Type Type `json:"type" yaml:"type"` - Scheme SchemeName `json:"scheme" yaml:"scheme"` - In SchemeIn `json:"in" yaml:"in"` - Name string `json:"name" yaml:"name"` - ValidValue *string `json:"-" yaml:"-"` - AttackValue string `json:"-" yaml:"-"` - - JWTWriter *jwt.JWTWriter `json:"-" yaml:"-"` -} - -var _ SecurityScheme = (*JWTBearerSecurityScheme)(nil) - -func NewAuthorizationJWTBearerSecurityScheme(name string, value *string) (*JWTBearerSecurityScheme, error) { - var jwtWriter *jwt.JWTWriter - if value != nil { - var err error - if jwtWriter, err = jwt.NewJWTWriter(*value); err != nil { - return nil, err - } - } - - return &JWTBearerSecurityScheme{ - Type: HttpType, - Scheme: BearerScheme, - In: InHeader, - Name: name, - ValidValue: value, - AttackValue: "", - - JWTWriter: jwtWriter, - }, nil -} - -func MustNewAuthorizationJWTBearerSecurityScheme(name string, value *string) *JWTBearerSecurityScheme { - scheme, err := NewAuthorizationJWTBearerSecurityScheme(name, value) - if err != nil { - panic(err) - } - return scheme -} - -func (ss *JWTBearerSecurityScheme) GetType() Type { - return ss.Type -} - -func (ss *JWTBearerSecurityScheme) GetScheme() SchemeName { - return ss.Scheme -} - -func (ss *JWTBearerSecurityScheme) GetIn() *SchemeIn { - return &ss.In -} - -func (ss *JWTBearerSecurityScheme) GetName() string { - return ss.Name -} - -func (ss *JWTBearerSecurityScheme) GetHeaders() http.Header { - header := http.Header{} - attackValue := ss.GetAttackValue().(string) - if attackValue == "" && ss.HasValidValue() { - attackValue = ss.GetValidValue().(string) - } - - if attackValue != "" { - header.Set(AuthorizationHeader, fmt.Sprintf("%s %s", BearerPrefix, attackValue)) - } - - return header -} - -func (ss *JWTBearerSecurityScheme) GetCookies() []*http.Cookie { - return []*http.Cookie{} -} - -func (ss *JWTBearerSecurityScheme) HasValidValue() bool { - return ss.ValidValue != nil -} - -func (ss *JWTBearerSecurityScheme) GetValidValue() interface{} { - if !ss.HasValidValue() { - return nil - } - - return *ss.ValidValue -} - -func (ss *JWTBearerSecurityScheme) GetValidValueWriter() interface{} { - return ss.JWTWriter -} - -func (ss *JWTBearerSecurityScheme) SetAttackValue(v interface{}) { - ss.AttackValue = v.(string) -} - -func (ss *JWTBearerSecurityScheme) GetAttackValue() interface{} { - return ss.AttackValue -} diff --git a/internal/auth/jwt_bearer_test.go b/internal/auth/jwt_bearer_test.go deleted file mode 100644 index 12be6610..00000000 --- a/internal/auth/jwt_bearer_test.go +++ /dev/null @@ -1,233 +0,0 @@ -package auth_test - -import ( - "net/http" - "testing" - - "github.com/cerberauth/vulnapi/internal/auth" - "github.com/cerberauth/vulnapi/jwt" - "github.com/stretchr/testify/assert" -) - -func TestNewAuthorizationJWTBearerSecurityScheme(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - assert.NoError(t, err) - assert.Equal(t, auth.HttpType, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) -} - -func TestNewAuthorizationJWTBearerSecuritySchemeWithInvalidJWT(t *testing.T) { - name := "token" - value := "abc123" - _, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - assert.Error(t, err) -} - -func TestMustNewAuthorizationJWTBearerSecurityScheme(t *testing.T) { - t.Run("ValidJWT", func(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss := auth.MustNewAuthorizationJWTBearerSecurityScheme(name, &value) - - assert.NotNil(t, ss) - assert.Equal(t, auth.HttpType, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) - }) - - t.Run("InvalidJWT", func(t *testing.T) { - name := "token" - value := "abc123" - assert.Panics(t, func() { - auth.MustNewAuthorizationJWTBearerSecurityScheme(name, &value) - }) - }) - - t.Run("NilValue", func(t *testing.T) { - name := "token" - ss := auth.MustNewAuthorizationJWTBearerSecurityScheme(name, nil) - - assert.NotNil(t, ss) - assert.Equal(t, auth.HttpType, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Nil(t, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) - }) -} - -func TestAuthorizationJWTBearerSecurityScheme_GetScheme(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - scheme := ss.GetScheme() - - assert.NoError(t, err) - assert.Equal(t, auth.BearerScheme, scheme) -} - -func TestAuthorizationJWTBearerSecurityScheme_GetType(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - scheme := ss.GetType() - - assert.NoError(t, err) - assert.Equal(t, auth.HttpType, scheme) -} - -func TestAuthorizationJWTBearerSecurityScheme_GetIn(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - scheme := ss.GetIn() - - assert.NoError(t, err) - assert.Equal(t, auth.InHeader, *scheme) -} - -func TestAuthorizationJWTBearerSecurityScheme_GetName(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - scheme := ss.GetName() - - assert.NoError(t, err) - assert.Equal(t, name, scheme) -} - -func TestJWTBearerSecurityScheme_GetHeaders(t *testing.T) { - name := "token" - value := jwt.FakeJWT - attackValue := "xyz789" - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - ss.SetAttackValue(attackValue) - - headers := ss.GetHeaders() - - assert.NoError(t, err) - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer xyz789"}, - }, headers) -} - -func TestJWTBearerSecurityScheme_GetHeaders_WhenNoAttackValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - - headers := ss.GetHeaders() - - assert.NoError(t, err) - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer " + jwt.FakeJWT}, - }, headers) -} - -func TestJWTBearerSecurityScheme_GetHeaders_WhenNoAttackAndValidValue(t *testing.T) { - name := "token" - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, nil) - - headers := ss.GetHeaders() - - assert.NoError(t, err) - assert.Equal(t, http.Header{}, headers) -} - -func TestJWTBearerSecurityScheme_GetCookies(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - cookies := ss.GetCookies() - - assert.NoError(t, err) - assert.Empty(t, cookies) -} - -func TestJWTBearerSecurityScheme_HasValidValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - hasValidValue := ss.HasValidValue() - - assert.NoError(t, err) - assert.True(t, hasValidValue) -} - -func TestJWTBearerSecurityScheme_HasValidValue_WhenNoValue(t *testing.T) { - name := "token" - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, nil) - hasValidValue := ss.HasValidValue() - - assert.NoError(t, err) - assert.False(t, hasValidValue) -} - -func TestJWTBearerSecurityScheme_GetValidValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - validValue := ss.GetValidValue() - - assert.NoError(t, err) - assert.Equal(t, value, validValue) -} - -func TestJWTBearerSecurityScheme_GetValidValue_WhenNoValue(t *testing.T) { - name := "token" - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, nil) - validValue := ss.GetValidValue() - - assert.NoError(t, err) - assert.Nil(t, validValue) -} - -func TestJWTBearerSecurityScheme_GetValidValueWriter(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - writer := ss.GetValidValueWriter() - - assert.NoError(t, err) - assert.Equal(t, ss.JWTWriter, writer) -} - -func TestJWTBearerSecurityScheme_SetAttackValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - assert.NoError(t, err) - assert.Equal(t, attackValue, ss.AttackValue) -} - -func TestJWTBearerSecurityScheme_GetAttackValue(t *testing.T) { - name := "token" - value := jwt.FakeJWT - ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value) - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - result := ss.GetAttackValue() - - assert.NoError(t, err) - assert.Equal(t, attackValue, result) -} diff --git a/internal/auth/no_auth.go b/internal/auth/no_auth.go index 364d594d..42c9634d 100644 --- a/internal/auth/no_auth.go +++ b/internal/auth/no_auth.go @@ -1,61 +1,16 @@ package auth -import "net/http" +var defaultName = "no_auth" -type NoAuthSecurityScheme struct { - Name string `json:"name" yaml:"name"` - Type Type `json:"type" yaml:"type"` - Scheme SchemeName `json:"scheme" yaml:"scheme"` +func NewNoAuthSecurityScheme() (*SecurityScheme, error) { + return NewSecurityScheme(defaultName, nil, None, NoneScheme, nil, nil) } -var _ SecurityScheme = (*NoAuthSecurityScheme)(nil) - -func NewNoAuthSecurityScheme() *NoAuthSecurityScheme { - return &NoAuthSecurityScheme{ - Name: "", - Type: None, - Scheme: NoneScheme, +func MustNewNoAuthSecurityScheme() *SecurityScheme { + scheme, err := NewNoAuthSecurityScheme() + if err != nil { + panic(err) } -} - -func (ss *NoAuthSecurityScheme) GetType() Type { - return ss.Type -} - -func (ss *NoAuthSecurityScheme) GetScheme() SchemeName { - return ss.Scheme -} - -func (ss *NoAuthSecurityScheme) GetIn() *SchemeIn { - return nil -} - -func (ss *NoAuthSecurityScheme) GetName() string { - return ss.Name -} - -func (ss *NoAuthSecurityScheme) GetHeaders() http.Header { - return http.Header{} -} - -func (ss *NoAuthSecurityScheme) GetCookies() []*http.Cookie { - return []*http.Cookie{} -} - -func (ss *NoAuthSecurityScheme) HasValidValue() bool { - return false -} - -func (ss *NoAuthSecurityScheme) GetValidValue() interface{} { - return "" -} - -func (ss *NoAuthSecurityScheme) GetValidValueWriter() interface{} { - return "" -} - -func (ss *NoAuthSecurityScheme) SetAttackValue(v interface{}) {} -func (ss *NoAuthSecurityScheme) GetAttackValue() interface{} { - return nil + return scheme } diff --git a/internal/auth/no_auth_test.go b/internal/auth/no_auth_test.go index ed0abdad..ffcafba0 100644 --- a/internal/auth/no_auth_test.go +++ b/internal/auth/no_auth_test.go @@ -8,68 +8,22 @@ import ( ) func TestNewNoAuthSecurityScheme(t *testing.T) { - ss := auth.NewNoAuthSecurityScheme() - assert.NotNil(t, ss) -} - -func TestNoAuthSecurityScheme_GetScheme(t *testing.T) { - ss := auth.NewNoAuthSecurityScheme() - - scheme := ss.GetScheme() - - assert.Equal(t, auth.NoneScheme, scheme) -} - -func TestNoAuthSecurityScheme_GetType(t *testing.T) { - ss := auth.NewNoAuthSecurityScheme() - - scheme := ss.GetType() - - assert.Equal(t, auth.None, scheme) -} - -func TestNoAuthSecurityScheme_GetName(t *testing.T) { - ss := auth.NewNoAuthSecurityScheme() + securityScheme, err := auth.NewNoAuthSecurityScheme() - scheme := ss.GetName() - - assert.Equal(t, "", scheme) -} - -func TestNoAuthSecurityScheme_GetHeaders(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - headers := ss.GetHeaders() - assert.NotNil(t, headers) - assert.Empty(t, headers) -} - -func TestNoAuthSecurityScheme_GetCookies(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - cookies := ss.GetCookies() - assert.NotNil(t, cookies) - assert.Empty(t, cookies) + assert.NoError(t, err) + assert.NotNil(t, securityScheme) + assert.Equal(t, "no_auth", securityScheme.GetName()) + assert.Equal(t, auth.None, securityScheme.GetType()) + assert.Equal(t, auth.NoneScheme, securityScheme.GetScheme()) + assert.Nil(t, securityScheme.GetIn()) } -func TestNoAuthSecurityScheme_GetValidValue(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - validValue := ss.GetValidValue() - assert.Equal(t, "", validValue) -} - -func TestNoAuthSecurityScheme_GetValidValueWriter(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - validValueWriter := ss.GetValidValueWriter() - assert.Equal(t, "", validValueWriter) -} - -func TestNoAuthSecurityScheme_SetAttackValue(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - ss.SetAttackValue("attack value") - // No assertion as this method does not return anything -} +func TestMustNewNoAuthSecurityScheme(t *testing.T) { + securityScheme := auth.MustNewNoAuthSecurityScheme() -func TestNoAuthSecurityScheme_GetAttackValue(t *testing.T) { - ss := &auth.NoAuthSecurityScheme{} - attackValue := ss.GetAttackValue() - assert.Nil(t, attackValue) + assert.NotNil(t, securityScheme) + assert.Equal(t, "no_auth", securityScheme.GetName()) + assert.Equal(t, auth.None, securityScheme.GetType()) + assert.Equal(t, auth.NoneScheme, securityScheme.GetScheme()) + assert.Nil(t, securityScheme.GetIn()) } diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index 90798add..f443e86f 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -1,8 +1,7 @@ package auth import ( - "fmt" - "net/http" + "time" "github.com/cerberauth/vulnapi/jwt" ) @@ -15,106 +14,87 @@ const ( ClientCredentials OAuthFlow = "client_credentials" ) -type OAuthConfig struct { - ClientID string - ClientSecret string - - TokenURL string - RefreshURL string -} - -type OAuthSecurityScheme struct { - Type Type `json:"type" yaml:"type"` - Scheme SchemeName `json:"scheme" yaml:"scheme"` - In SchemeIn `json:"in" yaml:"in"` - Name string `json:"name" yaml:"name"` - ValidValue *string `json:"-" yaml:"-"` - AttackValue string `json:"-" yaml:"-"` - - Config *OAuthConfig `json:"config" yaml:"config"` - JWTWriter *jwt.JWTWriter `json:"-" yaml:"-"` +type OAuthValue struct { + AccessToken string `json:"access_token" yaml:"access_token"` + RefreshToken *string `json:"refresh_token" yaml:"refresh_token"` + ExpiresIn *time.Time `json:"expires_in" yaml:"expires_in"` + Scope *string `json:"scope" yaml:"scope"` } -var _ SecurityScheme = (*OAuthSecurityScheme)(nil) - -func NewOAuthSecurityScheme(name string, value *string, cfg *OAuthConfig) *OAuthSecurityScheme { - var jwtWriter *jwt.JWTWriter - if value != nil { - jwtWriter, _ = jwt.NewJWTWriter(*value) - } - - return &OAuthSecurityScheme{ - Type: OAuth2, - Scheme: BearerScheme, - In: InHeader, - Name: name, - ValidValue: value, - JWTWriter: jwtWriter, - AttackValue: "", - - Config: cfg, +func NewOAuthValue(accessToken string, refreshToken *string, expiresIn *time.Time, scope *string) *OAuthValue { + return &OAuthValue{ + AccessToken: accessToken, + RefreshToken: refreshToken, + ExpiresIn: expiresIn, + Scope: scope, } } -func (ss *OAuthSecurityScheme) GetType() Type { - return ss.Type +func (value *OAuthValue) SetAccessToken(accessToken string) { + value.AccessToken = accessToken } -func (ss *OAuthSecurityScheme) GetScheme() SchemeName { - return ss.Scheme +func (value *OAuthValue) GetAccessToken() string { + return value.AccessToken } -func (ss *OAuthSecurityScheme) GetIn() *SchemeIn { - return &ss.In +func (value *OAuthValue) SetRefreshToken(refreshToken *string) { + value.RefreshToken = refreshToken } -func (ss *OAuthSecurityScheme) GetName() string { - return ss.Name +func (value *OAuthValue) GetRefreshToken() *string { + return value.RefreshToken } -func (ss *OAuthSecurityScheme) GetHeaders() http.Header { - header := http.Header{} - attackValue := ss.GetAttackValue().(string) - if attackValue == "" && ss.HasValidValue() { - attackValue = ss.GetValidValue().(string) - } +func (value *OAuthValue) IsValid() bool { + return value.AccessToken != "" +} - if attackValue != "" { - header.Set(AuthorizationHeader, fmt.Sprintf("%s %s", BearerPrefix, attackValue)) - } +type OAuthConfig struct { + ClientID string + ClientSecret string - return header + TokenURL string + RefreshURL string } -func (ss *OAuthSecurityScheme) GetCookies() []*http.Cookie { - return []*http.Cookie{} -} +var defaultIn = InHeader -func (ss *OAuthSecurityScheme) HasValidValue() bool { - return ss.ValidValue != nil && *ss.ValidValue != "" -} +func NewOAuthSecurityScheme(name string, in *SchemeIn, value *OAuthValue, config *OAuthConfig) (*SecurityScheme, error) { + if in == nil { + in = &defaultIn + } -func (ss *OAuthSecurityScheme) GetValidValue() interface{} { - if !ss.HasValidValue() { - return nil + securityScheme, err := NewSecurityScheme(name, config, OAuth2, OAuthScheme, in, nil) + if err != nil { + return nil, err } - return *ss.ValidValue -} + if value != nil && value.AccessToken != "" { + err = securityScheme.SetValidValue(value) + if err != nil { + return nil, err + } + + var tokenFormat TokenFormat + if jwt.IsJWT(value.AccessToken) { + tokenFormat = JWTTokenFormat + } else { + tokenFormat = NoneTokenFormat + } + if err = securityScheme.SetTokenFormat(tokenFormat); err != nil { + return nil, err + } + } -func (ss *OAuthSecurityScheme) GetValidValueWriter() interface{} { - return ss.JWTWriter + return securityScheme, nil } -func (ss *OAuthSecurityScheme) SetAttackValue(v interface{}) { - if v == nil { - ss.AttackValue = "" - return +func MustNewOAuthSecurityScheme(name string, in *SchemeIn, value *OAuthValue, config *OAuthConfig) *SecurityScheme { + securityScheme, err := NewOAuthSecurityScheme(name, in, value, config) + if err != nil { + panic(err) } - ss.AttackValue = v.(string) -} - -func (ss *OAuthSecurityScheme) GetAttackValue() interface{} { - return ss.AttackValue + return securityScheme } diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go index 851dceec..b2b58adf 100644 --- a/internal/auth/oauth_test.go +++ b/internal/auth/oauth_test.go @@ -1,217 +1,99 @@ package auth_test import ( - "net/http" "testing" "github.com/cerberauth/vulnapi/internal/auth" - "github.com/cerberauth/vulnapi/jwt" "github.com/stretchr/testify/assert" ) func TestNewOAuthSecurityScheme(t *testing.T) { name := "token" - value := "abc123" + accessToken := "abc123" + in := auth.InHeader + tokenFormat := auth.NoneTokenFormat + oauthValue := auth.NewOAuthValue(accessToken, nil, nil, nil) + oauthConfig := auth.OAuthConfig{} - ss := auth.NewOAuthSecurityScheme(name, &value, nil) + securityScheme, err := auth.NewOAuthSecurityScheme(name, &in, oauthValue, &oauthConfig) - assert.Equal(t, auth.OAuth2, ss.Type) - assert.Equal(t, auth.BearerScheme, ss.Scheme) - assert.Equal(t, auth.InHeader, ss.In) - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) - assert.Nil(t, ss.JWTWriter) + assert.NoError(t, err) + assert.Equal(t, auth.OAuth2, securityScheme.GetType()) + assert.Equal(t, auth.OAuthScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, oauthValue, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue()) } -func TestNewOAuthSecurityScheme_WithJWT(t *testing.T) { +func TestNewOAuthSecurityScheme_WhenNilIn(t *testing.T) { name := "token" - value := jwt.FakeJWT + accessToken := "abc123" + oauthValue := auth.NewOAuthValue(accessToken, nil, nil, nil) + oauthConfig := auth.OAuthConfig{} - ss := auth.NewOAuthSecurityScheme(name, &value, nil) + securityScheme, err := auth.NewOAuthSecurityScheme(name, nil, oauthValue, &oauthConfig) - assert.Equal(t, name, ss.Name) - assert.Equal(t, &value, ss.ValidValue) - assert.Equal(t, "", ss.AttackValue) - assert.NotNil(t, ss.JWTWriter) + assert.NoError(t, err) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) } -func TestOAuthSecurityScheme_GetScheme(t *testing.T) { +func TestNewOAuthSecurityScheme_WhenQueryIn(t *testing.T) { name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) + accessToken := "abc123" + in := auth.InQuery + oauthValue := auth.NewOAuthValue(accessToken, nil, nil, nil) + oauthConfig := auth.OAuthConfig{} - scheme := ss.GetScheme() + securityScheme, err := auth.NewOAuthSecurityScheme(name, &in, oauthValue, &oauthConfig) - assert.Equal(t, auth.BearerScheme, scheme) + assert.NoError(t, err) + assert.Equal(t, auth.InQuery, *securityScheme.GetIn()) } -func TestOAuthSecurityScheme_GetType(t *testing.T) { +func TestNewOAuthSecurityScheme_WhenNilValue(t *testing.T) { name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) + oauthConfig := auth.OAuthConfig{} - scheme := ss.GetType() + securityScheme, err := auth.NewOAuthSecurityScheme(name, nil, nil, &oauthConfig) - assert.Equal(t, auth.OAuth2, scheme) + assert.NoError(t, err) + assert.Equal(t, nil, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue()) } -func TestOAuthSecurityScheme_GetIn(t *testing.T) { +func TestNewOAuthSecurityScheme_WhenJWTFormatValue(t *testing.T) { name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) + accessToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.ufhxDTmrs4T5MSsvT6lsb3OpdWi5q8O31VX7TgrVamA" + in := auth.InHeader + tokenFormat := auth.JWTTokenFormat + oauthValue := auth.NewOAuthValue(accessToken, nil, nil, nil) + oauthConfig := auth.OAuthConfig{} - scheme := ss.GetIn() + securityScheme, err := auth.NewOAuthSecurityScheme(name, &in, oauthValue, &oauthConfig) - assert.Equal(t, auth.InHeader, *scheme) + assert.NoError(t, err) + assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat()) + assert.Equal(t, oauthValue, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue()) } -func TestOAuthSecurityScheme_GetName(t *testing.T) { +func TestMustNewOAuthSecurityScheme(t *testing.T) { name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) + accessToken := "abc123" + in := auth.InHeader + tokenFormat := auth.NoneTokenFormat + oauthValue := auth.NewOAuthValue(accessToken, nil, nil, nil) + oauthConfig := auth.OAuthConfig{} - scheme := ss.GetName() + securityScheme := auth.MustNewOAuthSecurityScheme(name, &in, oauthValue, &oauthConfig) - assert.Equal(t, name, scheme) -} - -func TestNewOAuthSecurityScheme_GetHeaders(t *testing.T) { - name := "token" - value := "abc123" - attackValue := "xyz789" - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - ss.SetAttackValue(attackValue) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer xyz789"}, - }, headers) -} - -func TestNewOAuthSecurityScheme_GetHeaders_WhenNoAttackValue(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{ - "Authorization": []string{"Bearer abc123"}, - }, headers) -} - -func TestNewOAuthSecurityScheme_GetHeaders_WhenNoAttackAndValidValue(t *testing.T) { - name := "token" - ss := auth.NewOAuthSecurityScheme(name, nil, nil) - - headers := ss.GetHeaders() - - assert.Equal(t, http.Header{}, headers) -} - -func TestNewOAuthSecurityScheme_GetCookies(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - cookies := ss.GetCookies() - - assert.Empty(t, cookies) -} - -func TestNewOAuthSecurityScheme_HasValidValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - result := ss.HasValidValue() - - assert.True(t, result) -} - -func TestNewOAuthSecurityScheme_HasValidValueFalse_WhenValueIsNil(t *testing.T) { - name := "token" - ss := auth.NewOAuthSecurityScheme(name, nil, nil) - - result := ss.HasValidValue() - - assert.False(t, result) -} - -func TestNewOAuthSecurityScheme_HasValidValueFalse_WhenValueIsEmptyString(t *testing.T) { - name := "token" - value := "" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - result := ss.HasValidValue() - - assert.False(t, result) -} - -func TestNewOAuthSecurityScheme_GetValidValueNil(t *testing.T) { - name := "token" - ss := auth.NewOAuthSecurityScheme(name, nil, nil) - - validValue := ss.GetValidValue() - - assert.Equal(t, nil, validValue) -} - -func TestNewOAuthSecurityScheme_GetValidValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - validValue := ss.GetValidValue() - - assert.Equal(t, value, validValue) -} - -func TestNewOAuthSecurityScheme_GetValidValueWriter(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - writer := ss.GetValidValueWriter() - - assert.Nil(t, writer) -} - -func TestNewOAuthSecurityScheme_GetValidValueWriter_WithJWT(t *testing.T) { - name := "token" - value := jwt.FakeJWT - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - writer := ss.GetValidValueWriter() - - assert.IsType(t, &jwt.JWTWriter{}, writer) -} - -func TestNewOAuthSecurityScheme_SetAttackValue(t *testing.T) { - name := "token" - value := "abc123" - - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - assert.Equal(t, attackValue, ss.AttackValue) -} - -func TestNewOAuthSecurityScheme_GetAttackValue(t *testing.T) { - name := "token" - value := "abc123" - ss := auth.NewOAuthSecurityScheme(name, &value, nil) - - attackValue := "xyz789" - ss.SetAttackValue(attackValue) - - result := ss.GetAttackValue() - - assert.Equal(t, attackValue, result) + assert.Equal(t, auth.OAuth2, securityScheme.GetType()) + assert.Equal(t, auth.OAuthScheme, securityScheme.GetScheme()) + assert.Equal(t, auth.InHeader, *securityScheme.GetIn()) + assert.Equal(t, &tokenFormat, securityScheme.GetTokenFormat()) + assert.Equal(t, name, securityScheme.GetName()) + assert.Equal(t, oauthValue, securityScheme.GetValidValue()) + assert.Equal(t, nil, securityScheme.GetAttackValue()) } diff --git a/internal/auth/scheme.go b/internal/auth/scheme.go index 1e6222e0..62e66305 100644 --- a/internal/auth/scheme.go +++ b/internal/auth/scheme.go @@ -24,6 +24,14 @@ func (e *SchemeName) Type() string { type SchemeIn string const ( + InQuery SchemeIn = "query" InHeader SchemeIn = "header" InCookie SchemeIn = "cookie" ) + +type TokenFormat string + +const ( + JWTTokenFormat TokenFormat = "jwt" + NoneTokenFormat TokenFormat = "none" +) diff --git a/internal/auth/security_scheme.go b/internal/auth/security_scheme.go index 5d752232..79e7c0bd 100644 --- a/internal/auth/security_scheme.go +++ b/internal/auth/security_scheme.go @@ -1,34 +1,221 @@ package auth import ( + "fmt" "net/http" + + "github.com/cerberauth/vulnapi/jwt" ) -type SecurityScheme interface { - GetType() Type - GetScheme() SchemeName - GetIn() *SchemeIn - GetName() string +func NewErrTokenFormatShouldBeJWT() error { + return fmt.Errorf("token format should be jwt") +} + +type SecurityScheme struct { + Type Type `json:"type" yaml:"type"` + Scheme SchemeName `json:"scheme" yaml:"scheme"` + In *SchemeIn `json:"in" yaml:"in"` + TokenFormat *TokenFormat `json:"token_format" yaml:"token_format"` + + Name string `json:"name" yaml:"name"` + Config interface{} `json:"config" yaml:"config"` + + ValidValue interface{} `json:"-" yaml:"-"` + AttackValue interface{} `json:"-" yaml:"-"` +} +type SecuritySchemesMap map[string]*SecurityScheme + +type InQueryValue = string +type InHeaderValue = string +type InCookieValue = http.Cookie + +func NewSecurityScheme(name string, config interface{}, t Type, scheme SchemeName, in *SchemeIn, tokenFormat *TokenFormat) (*SecurityScheme, error) { + if in != nil && name == "" { + return nil, fmt.Errorf("name is required for security scheme with in %s", *in) + } + + if t == ApiKey && in == nil { + return nil, fmt.Errorf("in is required for security scheme with type %s", t) + } + + return &SecurityScheme{ + Name: name, + Config: config, + + Type: t, + Scheme: scheme, + In: in, + TokenFormat: tokenFormat, + }, nil +} + +func (securityScheme *SecurityScheme) GetType() Type { + return securityScheme.Type +} + +func (securityScheme *SecurityScheme) GetScheme() SchemeName { + return securityScheme.Scheme +} - GetHeaders() http.Header - GetCookies() []*http.Cookie - GetValidValue() interface{} - HasValidValue() bool - GetValidValueWriter() interface{} - SetAttackValue(v interface{}) - GetAttackValue() interface{} +func (securityScheme *SecurityScheme) GetIn() *SchemeIn { + return securityScheme.In } -type SecuritySchemesMap map[string]SecurityScheme -func GetSecuritySchemeUniqueName(securityScheme SecurityScheme) string { - if securityScheme == nil { +func (securityScheme *SecurityScheme) GetToken() string { + if !securityScheme.HasValidValue() { return "" } - uniqueName := string(securityScheme.GetType()) + "-" + string(securityScheme.GetScheme()) - if securityScheme.GetIn() != nil { - uniqueName += "-" + string(*securityScheme.GetIn()) + switch securityScheme.GetType() { + case OAuth2: + return securityScheme.GetValidValue().(*OAuthValue).GetAccessToken() + default: + return securityScheme.GetValidValue().(string) + } +} + +func (securityScheme *SecurityScheme) SetTokenFormat(tokenFormat TokenFormat) error { + if tokenFormat == JWTTokenFormat && securityScheme.HasValidValue() && !jwt.IsJWT(securityScheme.GetToken()) { + return NewErrTokenFormatShouldBeJWT() + } + + securityScheme.TokenFormat = &tokenFormat + return nil +} + +func (securityScheme *SecurityScheme) GetTokenFormat() *TokenFormat { + return securityScheme.TokenFormat +} + +func (securityScheme *SecurityScheme) GetName() string { + return securityScheme.Name +} + +func (securityScheme *SecurityScheme) GetConfig() interface{} { + return securityScheme.Config +} + +func (securityScheme *SecurityScheme) validateValue(value interface{}) error { + if value == nil { + return fmt.Errorf("value is required") + } + + switch securityScheme.GetType() { + case ApiKey: + if securityScheme.GetIn() == nil { + return fmt.Errorf("in is required for api key security scheme") + } + + var ok bool + switch *securityScheme.GetIn() { + case InQuery: + _, ok = value.(InQueryValue) + case InHeader: + _, ok = value.(InHeaderValue) + case InCookie: + _, ok = value.(InCookieValue) + } + if !ok { + return fmt.Errorf("invalid value for api key security scheme") + } + return nil + + case HttpType: + val, ok := value.(string) + if !ok { + return fmt.Errorf("invalid value for http security scheme") + } + + if securityScheme.GetTokenFormat() != nil && *securityScheme.GetTokenFormat() == JWTTokenFormat { + if _, err := jwt.NewJWTWriter(val); err != nil { + return err + } + } + return nil + + case OAuth2: + _, ok := value.(*OAuthValue) + if !ok { + return fmt.Errorf("invalid value for oauth2 security scheme") + } + return nil + } + + return nil +} + +func (securityScheme *SecurityScheme) SetValidValue(value interface{}) error { + if value == nil { + securityScheme.ValidValue = nil + return nil + } + + if err := securityScheme.validateValue(value); err != nil { + return err + } + + securityScheme.ValidValue = value + return nil +} + +func (securityScheme *SecurityScheme) GetValidValue() interface{} { + return securityScheme.ValidValue +} + +func (securityScheme *SecurityScheme) HasValidValue() bool { + return securityScheme.GetValidValue() != nil +} + +func (securityScheme *SecurityScheme) SetAttackValue(value interface{}) error { + if value == nil { + securityScheme.AttackValue = nil + return nil + } + + if err := securityScheme.validateValue(value); err != nil { + return err + } + + securityScheme.AttackValue = value + return nil +} + +func (securityScheme *SecurityScheme) GetAttackValue() interface{} { + return securityScheme.AttackValue +} + +func (securityScheme *SecurityScheme) GetHeaders() http.Header { + header := http.Header{} + + attackValue := securityScheme.GetAttackValue() + if attackValue == nil && securityScheme.HasValidValue() { + attackValue = securityScheme.GetValidValue() + } + + if attackValue == nil { + return header + } + + if (securityScheme.GetType() == ApiKey || securityScheme.GetType() == HttpType) && *securityScheme.GetIn() == InHeader { + var val string + if securityScheme.GetType() == HttpType && securityScheme.GetScheme() == BearerScheme { + val = fmt.Sprintf("%s %s", BearerPrefix, attackValue) + header.Set(AuthorizationHeader, val) + } else { + val = fmt.Sprintf("%s", attackValue) + header.Set(securityScheme.GetName(), val) + } + } + + return header +} + +func (securityScheme *SecurityScheme) GetCookies() []*http.Cookie { + if securityScheme.GetIn() == nil || *securityScheme.GetIn() != InCookie { + return []*http.Cookie{} } - return uniqueName + cookies := []*http.Cookie{} + // TODO + return cookies } diff --git a/internal/auth/security_scheme_test.go b/internal/auth/security_scheme_test.go index 0861e660..f825633c 100644 --- a/internal/auth/security_scheme_test.go +++ b/internal/auth/security_scheme_test.go @@ -1,50 +1,458 @@ package auth_test import ( + "fmt" + "net/http" "testing" "github.com/cerberauth/vulnapi/internal/auth" + "github.com/cerberauth/vulnapi/jwt" "github.com/stretchr/testify/assert" ) -func TestGetSecuritySchemeUniqueName(t *testing.T) { - noAuthSecurityScheme := auth.NewNoAuthSecurityScheme() - bearerSecurityScheme := auth.NewAuthorizationBearerSecurityScheme("name", nil) - jwtBearerSecurityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("name", nil) - oauthSecurityScheme := auth.NewOAuthSecurityScheme("name", nil, nil) +func TestNewSecurityScheme(t *testing.T) { + inHeader := auth.InHeader + jwtTokenFormat := auth.JWTTokenFormat tests := []struct { - name string - securityScheme auth.SecurityScheme - expected string + name string + schemeName string + config interface{} + t auth.Type + scheme auth.SchemeName + in *auth.SchemeIn + tokenFormat *auth.TokenFormat + expectError bool }{ { - name: "no auth security scheme", - securityScheme: noAuthSecurityScheme, - expected: "none-None", + name: "Valid API Key in Header", + schemeName: "apiKey", + config: nil, + t: auth.ApiKey, + scheme: auth.PrivateToken, + in: &inHeader, + tokenFormat: nil, + expectError: false, }, { - name: "bearer security scheme", - securityScheme: bearerSecurityScheme, - expected: "http-Bearer-header", + name: "Missing name with in", + schemeName: "", + config: nil, + t: auth.ApiKey, + scheme: auth.PrivateToken, + in: &inHeader, + tokenFormat: nil, + expectError: true, }, { - name: "jwt bearer security scheme", - securityScheme: jwtBearerSecurityScheme, - expected: "http-Bearer-header", + name: "Missing in for API Key", + schemeName: "apiKey", + config: nil, + t: auth.ApiKey, + scheme: auth.PrivateToken, + in: nil, + tokenFormat: nil, + expectError: true, }, { - name: "oauth security scheme", - securityScheme: oauthSecurityScheme, - expected: "oauth2-Bearer-header", + name: "Valid HTTP Bearer", + schemeName: "bearer", + config: nil, + t: auth.HttpType, + scheme: auth.BearerScheme, + in: &inHeader, + tokenFormat: &jwtTokenFormat, + expectError: false, }, } - assert.Equal(t, "", auth.GetSecuritySchemeUniqueName(nil)) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := auth.GetSecuritySchemeUniqueName(tt.securityScheme) - assert.Equal(t, tt.expected, result) + securityScheme, err := auth.NewSecurityScheme(tt.schemeName, tt.config, tt.t, tt.scheme, tt.in, tt.tokenFormat) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.schemeName, securityScheme.GetName()) + assert.Equal(t, tt.config, securityScheme.GetConfig()) + assert.Equal(t, tt.t, securityScheme.GetType()) + assert.Equal(t, tt.scheme, securityScheme.GetScheme()) + assert.Equal(t, tt.in, securityScheme.GetIn()) + assert.Equal(t, tt.tokenFormat, securityScheme.GetTokenFormat()) + } + }) + } +} + +func TestSetValidValue(t *testing.T) { + inHeader := auth.InHeader + inQuery := auth.InQuery + inCookie := auth.InCookie + jwtTokenFormat := auth.JWTTokenFormat + + tests := []struct { + name string + schemeType auth.Type + schemeName auth.SchemeName + in *auth.SchemeIn + tokenFormat *auth.TokenFormat + value interface{} + expectError bool + expectedMessage string + }{ + { + name: "Valid API Key in Header", + schemeType: auth.ApiKey, + schemeName: auth.PrivateToken, + in: &inHeader, + tokenFormat: nil, + value: "valid-api-key", + expectError: false, + expectedMessage: "", + }, + { + name: "Invalid API Key in Header", + schemeType: auth.ApiKey, + schemeName: auth.PrivateToken, + in: &inHeader, + tokenFormat: nil, + value: &http.Header{"token": []string{"invalid-api-key"}}, + expectError: true, + expectedMessage: "invalid value for api key security scheme", + }, + { + name: "Valid API Key in Query", + schemeType: auth.ApiKey, + schemeName: auth.PrivateToken, + in: &inQuery, + tokenFormat: nil, + value: "valid-api-key", + expectError: false, + expectedMessage: "", + }, + { + name: "Valid API Key in Cookie", + schemeType: auth.ApiKey, + schemeName: auth.PrivateToken, + in: &inCookie, + tokenFormat: nil, + value: http.Cookie{Name: "token", Value: "valid-api-key"}, + expectError: false, + expectedMessage: "", + }, + { + name: "Valid HTTP Bearer with JWT", + schemeType: auth.HttpType, + schemeName: auth.BearerScheme, + in: &inHeader, + tokenFormat: &jwtTokenFormat, + value: jwt.FakeJWT, + expectError: false, + expectedMessage: "", + }, + { + name: "Invalid HTTP Bearer with JWT", + schemeType: auth.HttpType, + schemeName: auth.BearerScheme, + in: &inHeader, + tokenFormat: &jwtTokenFormat, + value: "invalid-jwt", + expectError: true, + expectedMessage: "token is malformed: token contains an invalid number of segments", + }, + { + name: "Valid OAuth2", + schemeType: auth.OAuth2, + schemeName: auth.PrivateToken, + in: nil, + tokenFormat: nil, + value: &auth.OAuthValue{}, + expectError: false, + expectedMessage: "", + }, + { + name: "Invalid OAuth2", + schemeType: auth.OAuth2, + schemeName: auth.PrivateToken, + in: nil, + tokenFormat: nil, + value: "invalid-oauth2", + expectError: true, + expectedMessage: "invalid value for oauth2 security scheme", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + securityScheme, err := auth.NewSecurityScheme("testScheme", nil, tt.schemeType, tt.schemeName, tt.in, tt.tokenFormat) + assert.NoError(t, err) + + err = securityScheme.SetValidValue(tt.value) + if tt.expectError { + assert.Error(t, err) + assert.Equal(t, tt.expectedMessage, err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.value, securityScheme.GetValidValue()) + } + }) + } +} + +func TestSetTokenFormat(t *testing.T) { + inHeader := auth.InHeader + jwtTokenFormat := auth.JWTTokenFormat + + tests := []struct { + name string + initialValue interface{} + tokenFormat auth.TokenFormat + expectError bool + expectedMessage string + }{ + { + name: "Valid JWT Token Format", + initialValue: jwt.FakeJWT, + tokenFormat: jwtTokenFormat, + expectError: false, + expectedMessage: "", + }, + { + name: "Invalid JWT Token Format", + initialValue: "invalid-token", + tokenFormat: jwtTokenFormat, + expectError: true, + expectedMessage: "token format should be jwt", + }, + { + name: "Non-JWT Token Format", + initialValue: "some-value", + tokenFormat: auth.TokenFormat("non-jwt"), + expectError: false, + expectedMessage: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + securityScheme, err := auth.NewSecurityScheme("testScheme", nil, auth.HttpType, auth.BearerScheme, &inHeader, nil) + assert.NoError(t, err) + + err = securityScheme.SetValidValue(tt.initialValue) + assert.NoError(t, err) + + err = securityScheme.SetTokenFormat(tt.tokenFormat) + if tt.expectError { + assert.Error(t, err) + assert.Equal(t, tt.expectedMessage, err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, &tt.tokenFormat, securityScheme.GetTokenFormat()) + } + }) + } +} + +func TestSetAttackValue(t *testing.T) { + inHeader := auth.InHeader + inQuery := auth.InQuery + inCookie := auth.InCookie + jwtTokenFormat := auth.JWTTokenFormat + + tests := []struct { + name string + schemeType auth.Type + schemeName auth.SchemeName + in *auth.SchemeIn + tokenFormat *auth.TokenFormat + value interface{} + expectError bool + expectedMessage string + }{ + { + name: "Valid API Key in Header", + schemeType: auth.ApiKey, + schemeName: auth.PrivateToken, + in: &inHeader, + tokenFormat: nil, + value: "valid-api-key", + expectError: false, + expectedMessage: "", + }, + { + name: "Invalid API Key in Header", + schemeType: auth.ApiKey, + schemeName: auth.PrivateToken, + in: &inHeader, + tokenFormat: nil, + value: &http.Header{"token": []string{"invalid-api-key"}}, + expectError: true, + expectedMessage: "invalid value for api key security scheme", + }, + { + name: "Valid API Key in Query", + schemeType: auth.ApiKey, + schemeName: auth.PrivateToken, + in: &inQuery, + tokenFormat: nil, + value: "valid-api-key", + expectError: false, + expectedMessage: "", + }, + { + name: "Valid API Key in Cookie", + schemeType: auth.ApiKey, + schemeName: auth.PrivateToken, + in: &inCookie, + tokenFormat: nil, + value: http.Cookie{Name: "token", Value: "valid-api-key"}, + expectError: false, + expectedMessage: "", + }, + { + name: "Valid HTTP Bearer with JWT", + schemeType: auth.HttpType, + schemeName: auth.BearerScheme, + in: &inHeader, + tokenFormat: &jwtTokenFormat, + value: jwt.FakeJWT, + expectError: false, + expectedMessage: "", + }, + { + name: "Invalid HTTP Bearer with JWT", + schemeType: auth.HttpType, + schemeName: auth.BearerScheme, + in: &inHeader, + tokenFormat: &jwtTokenFormat, + value: "invalid-jwt", + expectError: true, + expectedMessage: "token is malformed: token contains an invalid number of segments", + }, + { + name: "Valid OAuth2", + schemeType: auth.OAuth2, + schemeName: auth.PrivateToken, + in: nil, + tokenFormat: nil, + value: &auth.OAuthValue{}, + expectError: false, + expectedMessage: "", + }, + { + name: "Invalid OAuth2", + schemeType: auth.OAuth2, + schemeName: auth.PrivateToken, + in: nil, + tokenFormat: nil, + value: "invalid-oauth2", + expectError: true, + expectedMessage: "invalid value for oauth2 security scheme", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + securityScheme, err := auth.NewSecurityScheme("testScheme", nil, tt.schemeType, tt.schemeName, tt.in, tt.tokenFormat) + assert.NoError(t, err) + + err = securityScheme.SetAttackValue(tt.value) + if tt.expectError { + assert.Error(t, err) + assert.Equal(t, tt.expectedMessage, err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.value, securityScheme.GetAttackValue()) + } + }) + } +} + +func TestGetHeaders(t *testing.T) { + inHeader := auth.InHeader + jwtTokenFormat := auth.JWTTokenFormat + + tests := []struct { + name string + schemeName string + schemeType auth.Type + scheme auth.SchemeName + in *auth.SchemeIn + tokenFormat *auth.TokenFormat + validValue interface{} + attackValue interface{} + expectedHeaders http.Header + }{ + { + name: "API Key in Header with Valid Value", + schemeName: "X-Api-Key", + schemeType: auth.ApiKey, + scheme: auth.PrivateToken, + in: &inHeader, + tokenFormat: nil, + validValue: "valid-api-key", + attackValue: nil, + expectedHeaders: http.Header{"X-Api-Key": []string{"valid-api-key"}}, + }, + { + name: "API Key in Header with Attack Value", + schemeName: "X-Api-Key", + schemeType: auth.ApiKey, + scheme: auth.PrivateToken, + in: &inHeader, + tokenFormat: nil, + validValue: "valid-api-key", + attackValue: "attack-api-key", + expectedHeaders: http.Header{"X-Api-Key": []string{"attack-api-key"}}, + }, + { + name: "HTTP Bearer with JWT", + schemeName: "Bearer", + schemeType: auth.HttpType, + scheme: auth.BearerScheme, + in: &inHeader, + tokenFormat: &jwtTokenFormat, + validValue: jwt.FakeJWT, + attackValue: nil, + expectedHeaders: http.Header{"Authorization": []string{fmt.Sprintf("%s %s", auth.BearerPrefix, jwt.FakeJWT)}}, + }, + { + name: "HTTP Bearer with Attack JWT", + schemeName: "Bearer", + schemeType: auth.HttpType, + scheme: auth.BearerScheme, + in: &inHeader, + tokenFormat: &jwtTokenFormat, + validValue: jwt.FakeJWT, + attackValue: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.", + expectedHeaders: http.Header{"Authorization": []string{fmt.Sprintf("%s %s", auth.BearerPrefix, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.")}}, + }, + { + name: "No Valid or Attack Value", + schemeName: "X-Api-Key", + schemeType: auth.ApiKey, + scheme: auth.PrivateToken, + in: &inHeader, + tokenFormat: nil, + validValue: nil, + attackValue: nil, + expectedHeaders: http.Header{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + securityScheme, err := auth.NewSecurityScheme(tt.schemeName, nil, tt.schemeType, tt.scheme, tt.in, tt.tokenFormat) + assert.NoError(t, err) + + err = securityScheme.SetValidValue(tt.validValue) + assert.NoError(t, err) + + err = securityScheme.SetAttackValue(tt.attackValue) + assert.NoError(t, err) + + headers := securityScheme.GetHeaders() + assert.Equal(t, tt.expectedHeaders, headers) }) } } diff --git a/internal/auth/type.go b/internal/auth/type.go index 5c41fa3c..20c52dbd 100644 --- a/internal/auth/type.go +++ b/internal/auth/type.go @@ -7,5 +7,6 @@ const ( OAuth2 Type = "oauth2" OpenIdConnect Type = "openIdConnect" ApiKey Type = "apiKey" + MutualTLS Type = "mutualTLS" None Type = "none" ) diff --git a/internal/auth/uniq_name.go b/internal/auth/uniq_name.go new file mode 100644 index 00000000..fee611a3 --- /dev/null +++ b/internal/auth/uniq_name.go @@ -0,0 +1,14 @@ +package auth + +func GetSecuritySchemeUniqueName(securityScheme *SecurityScheme) string { + if securityScheme == nil { + return "" + } + + uniqueName := string(securityScheme.GetType()) + "-" + string(securityScheme.GetScheme()) + if securityScheme.GetIn() != nil { + uniqueName += "-" + string(*securityScheme.GetIn()) + } + + return uniqueName +} diff --git a/internal/auth/uniq_name_test.go b/internal/auth/uniq_name_test.go new file mode 100644 index 00000000..9c4e2037 --- /dev/null +++ b/internal/auth/uniq_name_test.go @@ -0,0 +1,40 @@ +package auth_test + +import ( + "testing" + + "github.com/cerberauth/vulnapi/internal/auth" + "github.com/stretchr/testify/assert" +) + +func TestGetSecuritySchemeUniqueName(t *testing.T) { + tests := []struct { + name string + securityScheme *auth.SecurityScheme + expected string + }{ + { + name: "no auth security scheme", + securityScheme: auth.MustNewNoAuthSecurityScheme(), + expected: "none-None", + }, + { + name: "bearer security scheme", + securityScheme: auth.MustNewAuthorizationBearerSecurityScheme("name", nil), + expected: "http-Bearer-header", + }, + { + name: "oauth security scheme", + securityScheme: auth.MustNewOAuthSecurityScheme("name", nil, &auth.OAuthValue{}, nil), + expected: "oauth2-OAuth-header", + }, + } + + assert.Equal(t, "", auth.GetSecuritySchemeUniqueName(nil)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := auth.GetSecuritySchemeUniqueName(tt.securityScheme) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/operation/operation.go b/internal/operation/operation.go index 8bf9e8ba..0bdd8e07 100644 --- a/internal/operation/operation.go +++ b/internal/operation/operation.go @@ -42,12 +42,12 @@ type Operation struct { OpenAPIDocPath *string `json:"-" yaml:"-"` ID string `json:"id" yaml:"id"` - Method string `json:"method" yaml:"method"` - URL url.URL `json:"url" yaml:"url"` - Body []byte `json:"body,omitempty" yaml:"body,omitempty"` - Cookies []*http.Cookie `json:"cookies,omitempty" yaml:"cookies,omitempty"` - Header http.Header `json:"header,omitempty" yaml:"header,omitempty"` - SecuritySchemes []auth.SecurityScheme `json:"securitySchemes" yaml:"securitySchemes"` + Method string `json:"method" yaml:"method"` + URL url.URL `json:"url" yaml:"url"` + Body []byte `json:"body,omitempty" yaml:"body,omitempty"` + Cookies []*http.Cookie `json:"cookies,omitempty" yaml:"cookies,omitempty"` + Header http.Header `json:"header,omitempty" yaml:"header,omitempty"` + SecuritySchemes []*auth.SecurityScheme `json:"securitySchemes" yaml:"securitySchemes"` } func getBody(body io.Reader) ([]byte, error) { @@ -88,7 +88,7 @@ func NewOperation(method string, operationUrl string, body io.Reader, client *re Body: bodyBuffer, Cookies: []*http.Cookie{}, Header: http.Header{}, - SecuritySchemes: []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()}, + SecuritySchemes: []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()}, }, nil } @@ -126,7 +126,7 @@ func NewOperationFromRequest(r *request.Request) (*Operation, error) { Cookies: r.GetCookies(), Body: r.GetBody(), - SecuritySchemes: []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()}, + SecuritySchemes: []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()}, }, nil } @@ -162,21 +162,21 @@ func (operation *Operation) NewRequest() (*request.Request, error) { return req, nil } -func (operation *Operation) GetSecuritySchemes() []auth.SecurityScheme { +func (operation *Operation) GetSecuritySchemes() []*auth.SecurityScheme { if operation.SecuritySchemes == nil { - return []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()} + return []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()} } return operation.SecuritySchemes } -func (operation *Operation) GetSecurityScheme() auth.SecurityScheme { +func (operation *Operation) GetSecurityScheme() *auth.SecurityScheme { if operation.SecuritySchemes == nil { - return auth.NewNoAuthSecurityScheme() + return auth.MustNewNoAuthSecurityScheme() } return operation.SecuritySchemes[0] } -func (operation *Operation) SetSecuritySchemes(securitySchemes []auth.SecurityScheme) *Operation { +func (operation *Operation) SetSecuritySchemes(securitySchemes []*auth.SecurityScheme) *Operation { operation.SecuritySchemes = securitySchemes return operation } @@ -204,9 +204,9 @@ func (operation *Operation) GetID() string { } func (o *Operation) Clone() (*Operation, error) { - var clonedSecuritySchemes []auth.SecurityScheme + var clonedSecuritySchemes []*auth.SecurityScheme if o.SecuritySchemes != nil { - clonedSecuritySchemes = make([]auth.SecurityScheme, len(o.SecuritySchemes)) + clonedSecuritySchemes = make([]*auth.SecurityScheme, len(o.SecuritySchemes)) copy(clonedSecuritySchemes, o.SecuritySchemes) } diff --git a/internal/operation/operation_test.go b/internal/operation/operation_test.go index 34313011..a8bf3d2f 100644 --- a/internal/operation/operation_test.go +++ b/internal/operation/operation_test.go @@ -140,7 +140,7 @@ func TestNewOperationFromRequest_WithBody(t *testing.T) { func TestOperation_GetSecurityScheme(t *testing.T) { t.Run("NoSecuritySchemes", func(t *testing.T) { operation := &operation.Operation{} - expectedScheme := auth.NewNoAuthSecurityScheme() + expectedScheme := auth.MustNewNoAuthSecurityScheme() scheme := operation.GetSecurityScheme() @@ -148,9 +148,9 @@ func TestOperation_GetSecurityScheme(t *testing.T) { }) t.Run("WithSecuritySchemes", func(t *testing.T) { - expectedScheme := auth.NewNoAuthSecurityScheme() + expectedScheme := auth.MustNewNoAuthSecurityScheme() operation := &operation.Operation{ - SecuritySchemes: []auth.SecurityScheme{expectedScheme}, + SecuritySchemes: []*auth.SecurityScheme{expectedScheme}, } scheme := operation.GetSecurityScheme() @@ -160,7 +160,7 @@ func TestOperation_GetSecurityScheme(t *testing.T) { } func TestOperationCloneWithSecuritySchemes(t *testing.T) { - securitySchemes := []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()} + securitySchemes := []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()} operation := operation.MustNewOperation(http.MethodGet, "http://example.com", nil, nil) operation.SetSecuritySchemes(securitySchemes) diff --git a/internal/request/request.go b/internal/request/request.go index a9f77f3a..4d272a5f 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -74,13 +74,13 @@ func (r *Request) WithCookies(cookies []*http.Cookie) *Request { return r } -func (r *Request) WithSecurityScheme(securityScheme auth.SecurityScheme) *Request { - if securityScheme.GetCookies() != nil { - r.WithCookies(securityScheme.GetCookies()) +func (r *Request) WithSecurityScheme(securityScheme *auth.SecurityScheme) *Request { + if cookies := securityScheme.GetCookies(); cookies != nil { + r.WithCookies(cookies) } - if securityScheme.GetHeaders() != nil { - r.WithHeader(securityScheme.GetHeaders()) + if headers := securityScheme.GetHeaders(); headers != nil { + r.WithHeader(headers) } return r diff --git a/internal/request/request_test.go b/internal/request/request_test.go index fa1eb141..647943e3 100644 --- a/internal/request/request_test.go +++ b/internal/request/request_test.go @@ -59,7 +59,7 @@ func TestWithSecurityScheme(t *testing.T) { method := http.MethodGet url := "http://localhost:8080/" token := "token" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) request, err := request.NewRequest(method, url, nil, nil) request = request.WithSecurityScheme(securityScheme) @@ -346,7 +346,7 @@ func TestDoWithSecuritySchemeHeaders(t *testing.T) { method := http.MethodGet url := "http://localhost:8080/" token := "token" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) request, _ := request.NewRequest(method, url, nil, client) request.WithSecurityScheme(securityScheme) @@ -376,7 +376,7 @@ func TestDoWithHeadersSecuritySchemeHeaders(t *testing.T) { "Authorization": []string{"Bearer othertoken"}, } token := "token" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) request, _ := request.NewRequest(method, url, nil, client) request = request.WithHeader(header) request = request.WithSecurityScheme(securityScheme) @@ -408,7 +408,7 @@ func TestDoWithCookiesSecuritySchemeHeaders(t *testing.T) { Value: "value1", }} token := "token" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) request, _ := request.NewRequest(method, url, nil, client) request = request.WithCookies(cookies) request = request.WithSecurityScheme(securityScheme) diff --git a/internal/scan/scan_url.go b/internal/scan/scan_url.go index 7409f447..fa764c71 100644 --- a/internal/scan/scan_url.go +++ b/internal/scan/scan_url.go @@ -21,7 +21,7 @@ func ScanURL(operation *operation.Operation, securityScheme *auth.SecurityScheme } if securityScheme != nil { - req.WithSecurityScheme(*securityScheme) + req.WithSecurityScheme(securityScheme) } else { req.WithSecurityScheme(operation.GetSecurityScheme()) } diff --git a/jwt/jwt.go b/jwt/jwt.go new file mode 100644 index 00000000..43f323a2 --- /dev/null +++ b/jwt/jwt.go @@ -0,0 +1,23 @@ +package jwt + +import ( + "errors" + "regexp" + + "github.com/golang-jwt/jwt/v5" + "go.opentelemetry.io/otel" +) + +var tracer = otel.Tracer("jwt") + +var jwtRegexp = `^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*$` + +func IsJWT(token string) bool { + matched, err := regexp.MatchString(jwtRegexp, token) + if err != nil || !matched { + return false + } + + _, _, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) + return err == nil && !errors.Is(err, jwt.ErrTokenUnverifiable) && !errors.Is(err, jwt.ErrTokenSignatureInvalid) +} diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go new file mode 100644 index 00000000..2a33b89a --- /dev/null +++ b/jwt/jwt_test.go @@ -0,0 +1,26 @@ +package jwt_test + +import ( + "testing" + + "github.com/cerberauth/vulnapi/jwt" +) + +func TestIsJWT(t *testing.T) { + tests := []struct { + token string + expected bool + }{ + {"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.Gfx6VO9tcxwk6xqx9yYzSfebfeakZp5JYIgP_edcw_A", true}, + {"invalid.jwt.token", false}, + {"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.", true}, + {"", false}, + } + + for _, test := range tests { + result := jwt.IsJWT(test.token) + if result != test.expected { + t.Errorf("IsJWT(%q) = %v; want %v", test.token, result, test.expected) + } + } +} diff --git a/jwt/jwt_writer.go b/jwt/jwt_writer.go index f1075860..cc180027 100644 --- a/jwt/jwt_writer.go +++ b/jwt/jwt_writer.go @@ -1,6 +1,7 @@ package jwt import ( + "context" "errors" "time" @@ -12,8 +13,12 @@ type JWTWriter struct { } func NewJWTWriter(token string) (*JWTWriter, error) { + _, span := tracer.Start(context.TODO(), "NewJWTWriter") + defer span.End() + tokenParsed, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) - if err != nil { + if err != nil && !errors.Is(err, jwt.ErrTokenUnverifiable) && !errors.Is(err, jwt.ErrTokenSignatureInvalid) { + span.RecordError(err) return nil, err } diff --git a/openapi/openapi.go b/openapi/openapi.go index 55a32281..95fb9bfb 100644 --- a/openapi/openapi.go +++ b/openapi/openapi.go @@ -4,8 +4,11 @@ import ( "net/url" "github.com/getkin/kin-openapi/openapi3" + "go.opentelemetry.io/otel" ) +var tracer = otel.Tracer("openapi") + type OpenAPI struct { baseUrl *url.URL diff --git a/openapi/operation.go b/openapi/operation.go index 1dd3b04f..b61675b7 100644 --- a/openapi/operation.go +++ b/openapi/operation.go @@ -12,8 +12,8 @@ import ( stduritemplate "github.com/std-uritemplate/std-uritemplate/go" ) -func getOperationSecuritySchemes(securityRequirements *openapi3.SecurityRequirements, securitySchemes map[string]auth.SecurityScheme) []auth.SecurityScheme { - operationsSecuritySchemes := []auth.SecurityScheme{} +func getOperationSecuritySchemes(securityRequirements *openapi3.SecurityRequirements, securitySchemes map[string]*auth.SecurityScheme) []*auth.SecurityScheme { + operationsSecuritySchemes := []*auth.SecurityScheme{} for _, security := range *securityRequirements { if len(security) == 0 { continue diff --git a/openapi/param_test.go b/openapi/param_test.go index c7109427..b0ef7a5c 100644 --- a/openapi/param_test.go +++ b/openapi/param_test.go @@ -4,19 +4,18 @@ import ( "context" "testing" - "github.com/cerberauth/vulnapi/internal/auth" "github.com/cerberauth/vulnapi/openapi" "github.com/stretchr/testify/assert" ) func TestGetSchemaValue_WhenNoParameters(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -26,13 +25,13 @@ func TestGetSchemaValue_WhenNoParameters(t *testing.T) { func TestGetSchemaValue_WhenHeaderParametersWithExample(t *testing.T) { expected := "example" - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [{name: param, in: header, required: true, schema: {type: string, example: example}}], responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -42,13 +41,13 @@ func TestGetSchemaValue_WhenHeaderParametersWithExample(t *testing.T) { } func TestGetSchemaValue_WhenHeaderParametersWithoutExample(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [{name: param, in: header, required: true, schema: {type: string}}], responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -59,13 +58,13 @@ func TestGetSchemaValue_WhenHeaderParametersWithoutExample(t *testing.T) { } func TestGetSchemaValue_WhenHeaderParametersNotRequired(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [{name: param, in: header, schema: {type: string, example: example}}], responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -75,13 +74,13 @@ func TestGetSchemaValue_WhenHeaderParametersNotRequired(t *testing.T) { func TestGetSchemaValue_WhenCookieParametersWithExample(t *testing.T) { expected := "example" - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [{name: param, in: cookie, required: true, schema: {type: string, example: example}}], responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -91,13 +90,13 @@ func TestGetSchemaValue_WhenCookieParametersWithExample(t *testing.T) { } func TestGetSchemaValue_WhenCookieParametersWithoutExample(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [{name: param, in: cookie, required: true, schema: {type: string}}], responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -108,13 +107,13 @@ func TestGetSchemaValue_WhenCookieParametersWithoutExample(t *testing.T) { } func TestGetSchemaValue_WhenCookieParametersNotRequired(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [{name: param, in: cookie, schema: {type: string, example: example}}], responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -124,13 +123,13 @@ func TestGetSchemaValue_WhenCookieParametersNotRequired(t *testing.T) { func TestGetSchemaValue_WhenPathParametersWithExample(t *testing.T) { expected := "/example" - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {'/{param}': {get: {parameters: [{name: param, in: path, required: true, schema: {type: string, example: example}}], responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -140,13 +139,13 @@ func TestGetSchemaValue_WhenPathParametersWithExample(t *testing.T) { } func TestGetSchemaValue_WhenPathParametersWithoutExample(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {'/{param}': {get: {parameters: [{name: param, in: path, required: true, schema: {type: string}}], responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -157,13 +156,13 @@ func TestGetSchemaValue_WhenPathParametersWithoutExample(t *testing.T) { func TestGetSchemaValue_WhenRequestBodyParametersWithExample(t *testing.T) { expected := []byte("\"example\"") - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {post: {requestBody: {content: {'application/json': {schema: {type: string, example: example}}}}, responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -175,13 +174,13 @@ func TestGetSchemaValue_WhenRequestBodyParametersWithExample(t *testing.T) { } func TestGetSchemaValue_WhenRequestBodyParametersWithoutExample(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {post: {requestBody: {content: {'application/json': {schema: {type: string}}}}, responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -192,13 +191,13 @@ func TestGetSchemaValue_WhenRequestBodyParametersWithoutExample(t *testing.T) { } func TestGetSchemaValue_WhenRequestBodyParametersNotRequired(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {post: {requestBody: {content: {'application/json': {schema: {type: string, example: example}}}}, responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -210,13 +209,13 @@ func TestGetSchemaValue_WhenRequestBodyParametersNotRequired(t *testing.T) { func TestGetSchemaValue_WhenRequestBodyParametersWithArrayExample(t *testing.T) { expected := []byte("[\"example\"]") - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {post: {requestBody: {content: {'application/json': {schema: {type: array, items: {type: string, example: example}}}}}, responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, err) assert.Len(t, operations, 1) @@ -229,13 +228,13 @@ func TestGetSchemaValue_WhenRequestBodyParametersWithArrayExample(t *testing.T) func TestGetSchemaValue_WhenRequestBodyParametersWithObjectExample(t *testing.T) { expected := []byte("{\"name\":\"example\"}") - openapi, operr := openapi.LoadFromData( + openapiContract, operr := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {post: {requestBody: {content: {'application/json': {schema: {type: object, properties: {name: {type: string, example: example}}}}}}, responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, operr) assert.NoError(t, err) @@ -249,13 +248,13 @@ func TestGetSchemaValue_WhenRequestBodyParametersWithObjectExample(t *testing.T) func TestGetSchemaValue_WhenRequestBodyParametersWithObjectExampleAndArrayExample(t *testing.T) { expected := []byte("{\"name\":[\"example\"]}") - openapi, operr := openapi.LoadFromData( + openapiContract, operr := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {post: {requestBody: {content: {'application/json': {schema: {type: object, properties: {name: {type: array, items: {type: string, example: example}}}}}}}, responses: {'204': {description: successful operation}}}}}}`), ) - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, err := openapi.Operations(nil, securitySchemesMap) + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, err := openapiContract.Operations(nil, securitySchemesMap) assert.NoError(t, operr) assert.NoError(t, err) diff --git a/openapi/security_scheme.go b/openapi/security_scheme.go index 2dcc29be..6cd8b79c 100644 --- a/openapi/security_scheme.go +++ b/openapi/security_scheme.go @@ -1,11 +1,13 @@ package openapi import ( + "context" "fmt" "strings" "github.com/cerberauth/vulnapi/internal/auth" "github.com/getkin/kin-openapi/openapi3" + "go.opentelemetry.io/otel/codes" ) const ( @@ -36,31 +38,37 @@ func NewErrUnsupportedSecuritySchemeType(schemeType string) error { return fmt.Errorf("unsupported security scheme type: %s", schemeType) } -func mapHTTPSchemeType(name string, scheme *openapi3.SecuritySchemeRef, securitySchemeValue *string) (auth.SecurityScheme, error) { +func mapHTTPSchemeType(name string, scheme *openapi3.SecuritySchemeRef, securitySchemeValue *string) (*auth.SecurityScheme, error) { schemeScheme := strings.ToLower(scheme.Value.Scheme) switch schemeScheme { case BearerScheme: - bearerFormat := strings.ToLower(scheme.Value.BearerFormat) - if bearerFormat == "" { - return auth.NewAuthorizationBearerSecurityScheme(name, securitySchemeValue), nil + securityScheme, err := auth.NewAuthorizationBearerSecurityScheme(name, securitySchemeValue) + if err != nil { + return nil, err } + bearerFormat := strings.ToLower(scheme.Value.BearerFormat) switch bearerFormat { + case "": + return securityScheme, nil case "jwt": - return auth.NewAuthorizationJWTBearerSecurityScheme(name, securitySchemeValue) + err := securityScheme.SetTokenFormat(auth.JWTTokenFormat) + if err != nil { + return nil, err + } + return securityScheme, nil default: return nil, NewErrUnsupportedBearerFormat(bearerFormat) } - default: return nil, NewErrUnsupportedScheme(schemeScheme) } } -func mapOAuth2SchemeType(name string, scheme *openapi3.SecuritySchemeRef, securitySchemeValue *string) (auth.SecurityScheme, error) { +func mapOAuth2SchemeType(name string, scheme *openapi3.SecuritySchemeRef, securitySchemeValue *auth.OAuthValue) (*auth.SecurityScheme, error) { if scheme.Value.Flows == nil { - return auth.NewOAuthSecurityScheme(name, securitySchemeValue, nil), nil + return auth.NewOAuthSecurityScheme(name, nil, securitySchemeValue, nil) } var cfg *auth.OAuthConfig @@ -82,10 +90,13 @@ func mapOAuth2SchemeType(name string, scheme *openapi3.SecuritySchemeRef, securi } } - return auth.NewOAuthSecurityScheme(name, securitySchemeValue, cfg), nil + return auth.NewOAuthSecurityScheme(name, nil, securitySchemeValue, cfg) } -func (openapi *OpenAPI) SecuritySchemeMap(values *auth.SecuritySchemeValues) (auth.SecuritySchemesMap, error) { +func (openapi *OpenAPI) SecuritySchemeMap(values *SecuritySchemeValues) (auth.SecuritySchemesMap, error) { + _, span := tracer.Start(context.Background(), "SecuritySchemeMap") + defer span.End() + var err error var securitySchemeValue interface{} @@ -93,7 +104,7 @@ func (openapi *OpenAPI) SecuritySchemeMap(values *auth.SecuritySchemeValues) (au return nil, nil } - securitySchemes := map[string]auth.SecurityScheme{} + securitySchemes := map[string]*auth.SecurityScheme{} for name, scheme := range openapi.Doc.Components.SecuritySchemes { securitySchemeValue = values.Get(name) @@ -107,12 +118,18 @@ func (openapi *OpenAPI) SecuritySchemeMap(values *auth.SecuritySchemeValues) (au case HttpSchemeType: securitySchemes[name], err = mapHTTPSchemeType(name, scheme, value) case OAuth2SchemeType, OpenIdConnectSchemeType: - securitySchemes[name], err = mapOAuth2SchemeType(name, scheme, value) + var oauthValue *auth.OAuthValue + if value != nil { + oauthValue = auth.NewOAuthValue(*value, nil, nil, nil) + } + securitySchemes[name], err = mapOAuth2SchemeType(name, scheme, oauthValue) default: err = NewErrUnsupportedSecuritySchemeType(schemeType) } if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, err } } diff --git a/openapi/security_scheme_test.go b/openapi/security_scheme_test.go index cb89a918..071383cb 100644 --- a/openapi/security_scheme_test.go +++ b/openapi/security_scheme_test.go @@ -11,12 +11,12 @@ import ( ) func TestSecuritySchemeMap_WithoutSecurityComponents(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.NoError(t, err) assert.Nil(t, result) @@ -24,12 +24,12 @@ func TestSecuritySchemeMap_WithoutSecurityComponents(t *testing.T) { func TestSecuritySchemeMap_WithUnknownSchemeType(t *testing.T) { expectedErr := openapi.NewErrUnsupportedSecuritySchemeType("other") - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{bearer_auth: []}]}}}, components: {securitySchemes: {bearer_auth: {type: other}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.Error(t, err) assert.Equal(t, expectedErr, err) @@ -38,12 +38,12 @@ func TestSecuritySchemeMap_WithUnknownSchemeType(t *testing.T) { func TestSecuritySchemeMap_WithUnknownScheme(t *testing.T) { expectedErr := openapi.NewErrUnsupportedScheme("other") - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{bearer_auth: []}]}}}, components: {securitySchemes: {bearer_auth: {type: http, scheme: other}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.Error(t, err) assert.Equal(t, expectedErr, err) @@ -52,12 +52,12 @@ func TestSecuritySchemeMap_WithUnknownScheme(t *testing.T) { func TestSecuritySchemeMap_WithUnknownBearerFormat(t *testing.T) { expectedErr := openapi.NewErrUnsupportedBearerFormat("other") - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{bearer_auth: []}]}}}, components: {securitySchemes: {bearer_auth: {type: http, scheme: bearer, bearerFormat: other}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.Error(t, err) assert.Equal(t, expectedErr, err) @@ -65,125 +65,144 @@ func TestSecuritySchemeMap_WithUnknownBearerFormat(t *testing.T) { } func TestSecuritySchemeMap_WithHTTPJWTBearer(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{bearer_auth: []}]}}}, components: {securitySchemes: {bearer_auth: {type: http, scheme: bearer, bearerFormat: JWT}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.NoError(t, err) assert.NotNil(t, result) - assert.IsType(t, &auth.JWTBearerSecurityScheme{}, result["bearer_auth"]) + assert.Equal(t, auth.HttpType, result["bearer_auth"].GetType()) + assert.Equal(t, auth.BearerScheme, result["bearer_auth"].GetScheme()) + assert.Equal(t, auth.JWTTokenFormat, *result["bearer_auth"].GetTokenFormat()) } func TestSecuritySchemeMap_WithHTTPBearer(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{bearer_auth: []}]}}}, components: {securitySchemes: {bearer_auth: {type: http, scheme: bearer}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.NoError(t, err) assert.NotNil(t, result) - assert.IsType(t, &auth.BearerSecurityScheme{}, result["bearer_auth"]) + assert.Equal(t, auth.HttpType, result["bearer_auth"].GetType()) + assert.Equal(t, auth.BearerScheme, result["bearer_auth"].GetScheme()) } func TestSecuritySchemeMap_WithoutHTTPJWTBearerAndDefaultValue(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{bearer_auth: []}]}}}, components: {securitySchemes: {bearer_auth: {type: http, scheme: bearer, bearerFormat: JWT}}}}`), ) token := jwt.FakeJWT - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues().WithDefault(&token)) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues().WithDefault(&token)) assert.NoError(t, err) assert.NotNil(t, result) - assert.IsType(t, &auth.JWTBearerSecurityScheme{}, result["bearer_auth"]) + assert.Equal(t, auth.HttpType, result["bearer_auth"].GetType()) + assert.Equal(t, auth.BearerScheme, result["bearer_auth"].GetScheme()) + assert.Equal(t, auth.JWTTokenFormat, *result["bearer_auth"].GetTokenFormat()) } func TestSecuritySchemeMap_WithInvalidValueType(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{bearer_auth: []}]}}}, components: {securitySchemes: {bearer_auth: {type: http, scheme: bearer, bearerFormat: JWT}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues().WithDefault("")) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues().WithDefault("")) assert.NoError(t, err) assert.NotNil(t, result) - assert.IsType(t, &auth.JWTBearerSecurityScheme{}, result["bearer_auth"]) + assert.Equal(t, auth.HttpType, result["bearer_auth"].GetType()) + assert.Equal(t, auth.BearerScheme, result["bearer_auth"].GetScheme()) + assert.Equal(t, auth.JWTTokenFormat, *result["bearer_auth"].GetTokenFormat()) } func TestSecuritySchemeMap_WithOAuth(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{oauth_auth: []}]}}}, components: {securitySchemes: {oauth_auth: {type: oauth2}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.NoError(t, err) assert.NotNil(t, result) - assert.IsType(t, &auth.OAuthSecurityScheme{}, result["oauth_auth"]) + assert.Equal(t, auth.OAuth2, result["oauth_auth"].GetType()) + assert.Equal(t, auth.OAuthScheme, result["oauth_auth"].GetScheme()) } func TestSecuritySchemeMap_WithOAuthAndAuthorizationCodeFlow(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{oauth_auth: []}]}}}, components: {securitySchemes: {oauth_auth: {type: oauth2, flows: {authorizationCode: {tokenUrl: 'http://localhost:8080/token', refreshUrl: 'http://localhost:8080/refresh'}}}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.NoError(t, err) assert.NotNil(t, result) - assert.IsType(t, &auth.OAuthSecurityScheme{}, result["oauth_auth"]) - assert.Equal(t, "http://localhost:8080/token", result["oauth_auth"].(*auth.OAuthSecurityScheme).Config.TokenURL) - assert.Equal(t, "http://localhost:8080/refresh", result["oauth_auth"].(*auth.OAuthSecurityScheme).Config.RefreshURL) + assert.Equal(t, auth.OAuth2, result["oauth_auth"].GetType()) + assert.Equal(t, auth.OAuthScheme, result["oauth_auth"].GetScheme()) + assert.Equal(t, &auth.OAuthConfig{ + TokenURL: "http://localhost:8080/token", + RefreshURL: "http://localhost:8080/refresh", + }, result["oauth_auth"].GetConfig()) } func TestSecuritySchemeMap_WithOAuthAndImplicitFlow(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{oauth_auth: []}]}}}, components: {securitySchemes: {oauth_auth: {type: oauth2, flows: {implicit: {tokenUrl: 'http://localhost:8080/token', refreshUrl: 'http://localhost:8080/refresh'}}}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.NoError(t, err) assert.NotNil(t, result) - assert.IsType(t, &auth.OAuthSecurityScheme{}, result["oauth_auth"]) - assert.Equal(t, "http://localhost:8080/token", result["oauth_auth"].(*auth.OAuthSecurityScheme).Config.TokenURL) - assert.Equal(t, "http://localhost:8080/refresh", result["oauth_auth"].(*auth.OAuthSecurityScheme).Config.RefreshURL) + assert.Equal(t, auth.OAuth2, result["oauth_auth"].GetType()) + assert.Equal(t, auth.OAuthScheme, result["oauth_auth"].GetScheme()) + assert.Equal(t, &auth.OAuthConfig{ + TokenURL: "http://localhost:8080/token", + RefreshURL: "http://localhost:8080/refresh", + }, result["oauth_auth"].GetConfig()) } func TestSecuritySchemeMap_WithOAuthAndClientCredentialsFlow(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{oauth_auth: []}]}}}, components: {securitySchemes: {oauth_auth: {type: oauth2, flows: {clientCredentials: {tokenUrl: 'http://localhost:8080/token', refreshUrl: 'http://localhost:8080/refresh'}}}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.NoError(t, err) assert.NotNil(t, result) - assert.IsType(t, &auth.OAuthSecurityScheme{}, result["oauth_auth"]) - assert.Equal(t, "http://localhost:8080/token", result["oauth_auth"].(*auth.OAuthSecurityScheme).Config.TokenURL) - assert.Equal(t, "http://localhost:8080/refresh", result["oauth_auth"].(*auth.OAuthSecurityScheme).Config.RefreshURL) + assert.Equal(t, auth.OAuth2, result["oauth_auth"].GetType()) + assert.Equal(t, auth.OAuthScheme, result["oauth_auth"].GetScheme()) + assert.Equal(t, &auth.OAuthConfig{ + TokenURL: "http://localhost:8080/token", + RefreshURL: "http://localhost:8080/refresh", + }, result["oauth_auth"].GetConfig()) } func TestSecuritySchemeMap_WithOpenIDConnect(t *testing.T) { - openapi, _ := openapi.LoadFromData( + openapiContract, _ := openapi.LoadFromData( context.Background(), []byte(`{openapi: 3.0.2, servers: [{url: 'http://localhost:8080'}], paths: {/: {get: {parameters: [], responses: {'204': {description: successful operation}}, security: [{oidc_auth: []}]}}}, components: {securitySchemes: {oidc_auth: {type: openIdConnect}}}}`), ) - result, err := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) + result, err := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) assert.NoError(t, err) assert.NotNil(t, result) - assert.IsType(t, &auth.OAuthSecurityScheme{}, result["oidc_auth"]) + assert.Equal(t, auth.OAuth2, result["oidc_auth"].GetType()) + assert.Equal(t, auth.OAuthScheme, result["oidc_auth"].GetScheme()) + assert.Nil(t, result["oidc_auth"].GetConfig()) } diff --git a/internal/auth/security_scheme_values.go b/openapi/security_scheme_values.go similarity index 98% rename from internal/auth/security_scheme_values.go rename to openapi/security_scheme_values.go index b9911f0e..f8dd2f44 100644 --- a/internal/auth/security_scheme_values.go +++ b/openapi/security_scheme_values.go @@ -1,4 +1,4 @@ -package auth +package openapi type SecuritySchemeValues struct { Default interface{} diff --git a/internal/auth/security_scheme_values_test.go b/openapi/security_scheme_values_test.go similarity index 72% rename from internal/auth/security_scheme_values_test.go rename to openapi/security_scheme_values_test.go index 8ce4cece..9a3ba772 100644 --- a/internal/auth/security_scheme_values_test.go +++ b/openapi/security_scheme_values_test.go @@ -1,9 +1,9 @@ -package auth_test +package openapi_test import ( "testing" - "github.com/cerberauth/vulnapi/internal/auth" + "github.com/cerberauth/vulnapi/openapi" "github.com/stretchr/testify/assert" ) @@ -11,7 +11,7 @@ func TestNewSecuritySchemeValues(t *testing.T) { values := map[string]interface{}{ "key": "value", } - securitySchemeValues := auth.NewSecuritySchemeValues(values) + securitySchemeValues := openapi.NewSecuritySchemeValues(values) assert.Nil(t, securitySchemeValues.Default) assert.NotNil(t, securitySchemeValues.Values) @@ -19,7 +19,7 @@ func TestNewSecuritySchemeValues(t *testing.T) { } func TestNewEmptySecuritySchemeValues(t *testing.T) { - securitySchemeValues := auth.NewEmptySecuritySchemeValues() + securitySchemeValues := openapi.NewEmptySecuritySchemeValues() assert.Nil(t, securitySchemeValues.Default) assert.NotNil(t, securitySchemeValues.Values) @@ -27,21 +27,21 @@ func TestNewEmptySecuritySchemeValues(t *testing.T) { } func TestSecuritySchemeValues_WithDefault(t *testing.T) { - securitySchemeValues := auth.NewEmptySecuritySchemeValues() + securitySchemeValues := openapi.NewEmptySecuritySchemeValues() securitySchemeValues.WithDefault("default") assert.Equal(t, "default", securitySchemeValues.Default) } func TestSecuritySchemeValues_GetDefault(t *testing.T) { - securitySchemeValues := auth.NewEmptySecuritySchemeValues() + securitySchemeValues := openapi.NewEmptySecuritySchemeValues() securitySchemeValues.WithDefault("default") assert.Equal(t, "default", securitySchemeValues.GetDefault()) } func TestSecuritySchemeValues_Get(t *testing.T) { - securitySchemeValues := auth.NewEmptySecuritySchemeValues() + securitySchemeValues := openapi.NewEmptySecuritySchemeValues() securitySchemeValues.WithDefault("default") securitySchemeValues.Set("key", "value") @@ -49,14 +49,14 @@ func TestSecuritySchemeValues_Get(t *testing.T) { } func TestSecuritySchemeValues_Get_WhenNotExist(t *testing.T) { - securitySchemeValues := auth.NewEmptySecuritySchemeValues() + securitySchemeValues := openapi.NewEmptySecuritySchemeValues() securitySchemeValues.WithDefault("default") assert.Equal(t, "default", securitySchemeValues.Get("key")) } func TestSecuritySchemeValues_Set(t *testing.T) { - securitySchemeValues := auth.NewEmptySecuritySchemeValues() + securitySchemeValues := openapi.NewEmptySecuritySchemeValues() securitySchemeValues.Set("key", "value") assert.Equal(t, "value", securitySchemeValues.Get("key")) diff --git a/report.json b/report.json new file mode 100644 index 00000000..22c2e761 --- /dev/null +++ b/report.json @@ -0,0 +1 @@ +{"$schema":"https://schemas.cerberauth.com/vulnapi/draft/2024-10/report.schema.json","options":{},"curl":{"method":"GET","url":"http://localhost:8080","data":"","headers":{"Authorization":["Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"]},"securitySchemes":[{"type":"http","scheme":"Bearer","in":"header","token_format":"jwt","name":"default"}],"issues":[{"id":"broken_authentication.alg_none","name":"JWT Algorithm None is accepted","url":"https://vulnapi.cerberauth.com/docs/vulnerabilities/broken-authentication/jwt-alg-none?utm_source=vulnapi","cvss":{"version":4,"vector":"CVSS:4.0/AV:N/AC:L/AT:N/PR:N/UI:N/VC:H/VI:H/VA:N/SC:N/SI:N/SA:N","score":9.3},"classifications":{"owasp":"API2:2023 Broken Authentication","cwe":"CWE-345: Insufficient Verification of Data Authenticity"},"status":"failed"}]},"reports":[{"id":"jwt.alg_none","name":"JWT None Algorithm","startTime":"2024-11-17T18:46:14.75459302+01:00","endTime":"2024-11-17T18:46:15.044591343+01:00","operation":{"id":"getRoot"},"data":{"alg":"none"},"scans":[{"request":{"method":"GET","url":"http://localhost:8080","headers":{"Authorization":["Bearer eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ."],"User-Agent":["vulnapi"]}},"response":{"statusCode":204,"body":"","headers":{"Date":["Sun, 17 Nov 2024 17:46:15 GMT"]}}}],"issues":[{"id":"broken_authentication.alg_none","name":"JWT Algorithm None is accepted","url":"https://vulnapi.cerberauth.com/docs/vulnerabilities/broken-authentication/jwt-alg-none?utm_source=vulnapi","cvss":{"version":4,"vector":"CVSS:4.0/AV:N/AC:L/AT:N/PR:N/UI:N/VC:H/VI:H/VA:N/SC:N/SI:N/SA:N","score":9.3},"classifications":{"owasp":"API2:2023 Broken Authentication","cwe":"CWE-345: Insufficient Verification of Data Authenticity"},"status":"failed"}]}]} \ No newline at end of file diff --git a/report/curl_report.go b/report/curl_report.go index c7f0068a..72a4eedf 100644 --- a/report/curl_report.go +++ b/report/curl_report.go @@ -18,10 +18,10 @@ type CurlReport struct { Issues []*IssueReport `json:"issues" yaml:"issues"` } -func NewCurlReport(method string, url string, data interface{}, header http.Header, cookies []*http.Cookie, securitySchemes []auth.SecurityScheme) *CurlReport { +func NewCurlReport(method string, url string, data interface{}, header http.Header, cookies []*http.Cookie, securitySchemes []*auth.SecurityScheme) *CurlReport { reportSecuritySchemes := []OperationSecurityScheme{} - for _, ss := range securitySchemes { - reportSecuritySchemes = append(reportSecuritySchemes, NewOperationSecurityScheme(ss)) + for _, securityScheme := range securitySchemes { + reportSecuritySchemes = append(reportSecuritySchemes, NewOperationSecurityScheme(securityScheme)) } return &CurlReport{ diff --git a/report/curl_report_test.go b/report/curl_report_test.go index 4cad15f3..407b7465 100644 --- a/report/curl_report_test.go +++ b/report/curl_report_test.go @@ -18,8 +18,8 @@ func TestNewCurlReport(t *testing.T) { header := http.Header{"Content-Type": []string{"application/json"}} cookies := []*http.Cookie{{Name: "session_id", Value: "abc123"}} value := jwt.FakeJWT - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &value) - securitySchemes := []auth.SecurityScheme{securityScheme} + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &value) + securitySchemes := []*auth.SecurityScheme{securityScheme} curlReport := report.NewCurlReport(method, url, data, header, cookies, securitySchemes) @@ -41,7 +41,7 @@ func Test_CurlReport_AddReport(t *testing.T) { data := map[string]interface{}{"key": "value"} header := http.Header{"Content-Type": []string{"application/json"}} cookies := []*http.Cookie{{Name: "session_id", Value: "abc123"}} - securitySchemes := []auth.SecurityScheme{} + securitySchemes := []*auth.SecurityScheme{} curlReport := report.NewCurlReport(method, url, data, header, cookies, securitySchemes) @@ -70,7 +70,7 @@ func TestAddReport_WhenScanReportHasNoFailedIssueReport(t *testing.T) { data := map[string]interface{}{"key": "value"} header := http.Header{"Content-Type": []string{"application/json"}} cookies := []*http.Cookie{{Name: "session_id", Value: "abc123"}} - securitySchemes := []auth.SecurityScheme{} + securitySchemes := []*auth.SecurityScheme{} curlReport := report.NewCurlReport(method, url, data, header, cookies, securitySchemes) diff --git a/report/graphql_report.go b/report/graphql_report.go index d1f46ee3..266bda15 100644 --- a/report/graphql_report.go +++ b/report/graphql_report.go @@ -26,7 +26,7 @@ type GraphQLReport struct { Mutations GraphQLOperationsMethods `json:"mutations" yaml:"mutations"` } -func NewGraphQLReport(url string, securitySchemes []auth.SecurityScheme) *GraphQLReport { +func NewGraphQLReport(url string, securitySchemes []*auth.SecurityScheme) *GraphQLReport { queries := GraphQLOperationsMethods{} mutations := GraphQLOperationsMethods{} diff --git a/report/issue_report.go b/report/issue_report.go index 1499c126..8ef51a1c 100644 --- a/report/issue_report.go +++ b/report/issue_report.go @@ -35,7 +35,7 @@ type IssueReport struct { Status IssueReportStatus `json:"status" yaml:"status"` Operation *operation.Operation `json:"-" yaml:"-"` - SecurityScheme auth.SecurityScheme `json:"-" yaml:"-"` + SecurityScheme *auth.SecurityScheme `json:"-" yaml:"-"` } func NewIssueReport(issue Issue) *IssueReport { @@ -50,8 +50,8 @@ func (vr *IssueReport) WithOperation(operation *operation.Operation) *IssueRepor return vr } -func (vr *IssueReport) WithSecurityScheme(ss auth.SecurityScheme) *IssueReport { - vr.SecurityScheme = ss +func (vr *IssueReport) WithSecurityScheme(securityScheme *auth.SecurityScheme) *IssueReport { + vr.SecurityScheme = securityScheme return vr } diff --git a/report/issue_report_test.go b/report/issue_report_test.go index 9fc08a6d..803d0e20 100644 --- a/report/issue_report_test.go +++ b/report/issue_report_test.go @@ -53,7 +53,7 @@ func TestIssueReport_WithSecurityScheme(t *testing.T) { } vr := report.NewIssueReport(issue) value := jwt.FakeJWT - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &value) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &value) vr.WithSecurityScheme(securityScheme) assert.Equal(t, jwt.FakeJWT, vr.SecurityScheme.GetValidValue()) } diff --git a/report/openapi_report_test.go b/report/openapi_report_test.go index 84fcd578..7fdfe2c7 100644 --- a/report/openapi_report_test.go +++ b/report/openapi_report_test.go @@ -5,7 +5,6 @@ import ( "net/http" "testing" - "github.com/cerberauth/vulnapi/internal/auth" "github.com/cerberauth/vulnapi/openapi" "github.com/cerberauth/vulnapi/report" "github.com/stretchr/testify/assert" @@ -13,7 +12,7 @@ import ( func TestNewOpenAPIReportOperation(t *testing.T) { doc, _ := openapi.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer.openapi.json") - securitySchemesMap, _ := doc.SecuritySchemeMap(&auth.SecuritySchemeValues{}) + securitySchemesMap, _ := doc.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) operations, _ := doc.Operations(nil, securitySchemesMap) securitySchemes := operations[0].GetSecuritySchemes() @@ -29,7 +28,7 @@ func TestNewOpenAPIReportOperation(t *testing.T) { func TestNewOpenAPIReport(t *testing.T) { doc, _ := openapi.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer.openapi.json") - securitySchemesMap, _ := doc.SecuritySchemeMap(&auth.SecuritySchemeValues{}) + securitySchemesMap, _ := doc.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) operations, _ := doc.Operations(nil, securitySchemesMap) r := report.NewOpenAPIReport(doc.Doc, operations) @@ -43,7 +42,7 @@ func TestNewOpenAPIReport(t *testing.T) { func Test_OpenAPIReport_AddReport(t *testing.T) { doc, _ := openapi.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer.openapi.json") - securitySchemesMap, _ := doc.SecuritySchemeMap(&auth.SecuritySchemeValues{}) + securitySchemesMap, _ := doc.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) operations, _ := doc.Operations(nil, securitySchemesMap) r := report.NewOpenAPIReport(doc.Doc, operations) @@ -66,7 +65,7 @@ func Test_OpenAPIReport_AddReport(t *testing.T) { func Test_OpenAPIReport_AddReport_NoFailedIssue(t *testing.T) { doc, _ := openapi.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer.openapi.json") - securitySchemesMap, _ := doc.SecuritySchemeMap(&auth.SecuritySchemeValues{}) + securitySchemesMap, _ := doc.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) operations, _ := doc.Operations(nil, securitySchemesMap) r := report.NewOpenAPIReport(doc.Doc, operations) diff --git a/report/report.go b/report/report.go index 871f7324..4c39c807 100644 --- a/report/report.go +++ b/report/report.go @@ -11,18 +11,22 @@ import ( ) type OperationSecurityScheme struct { - Type auth.Type `json:"type" yaml:"type"` - Scheme auth.SchemeName `json:"scheme" yaml:"scheme"` - In *auth.SchemeIn `json:"in,omitempty" yaml:"in,omitempty"` - Name string `json:"name" yaml:"name"` + Type auth.Type `json:"type" yaml:"type"` + Scheme auth.SchemeName `json:"scheme" yaml:"scheme"` + In *auth.SchemeIn `json:"in" yaml:"in"` + TokenFormat *auth.TokenFormat `json:"token_format" yaml:"token_format"` + + Name string `json:"name" yaml:"name"` } -func NewOperationSecurityScheme(ss auth.SecurityScheme) OperationSecurityScheme { +func NewOperationSecurityScheme(securityScheme *auth.SecurityScheme) OperationSecurityScheme { return OperationSecurityScheme{ - Type: ss.GetType(), - Scheme: ss.GetScheme(), - In: ss.GetIn(), - Name: ss.GetName(), + Type: securityScheme.GetType(), + Scheme: securityScheme.GetScheme(), + In: securityScheme.GetIn(), + TokenFormat: securityScheme.GetTokenFormat(), + + Name: securityScheme.GetName(), } } diff --git a/report/report_test.go b/report/report_test.go index f9547c18..8f639492 100644 --- a/report/report_test.go +++ b/report/report_test.go @@ -16,35 +16,40 @@ import ( func TestNewOperationSecurityScheme(t *testing.T) { inHeader := auth.InHeader + value := "test" + noneTokenFormat := auth.NoneTokenFormat + tests := []struct { - name string - ss auth.SecurityScheme - want report.OperationSecurityScheme + name string + securityScheme *auth.SecurityScheme + want report.OperationSecurityScheme }{ { - name: "No Auth", - ss: auth.NewNoAuthSecurityScheme(), + name: "No Auth", + securityScheme: auth.MustNewNoAuthSecurityScheme(), want: report.OperationSecurityScheme{ Type: auth.None, Scheme: auth.NoneScheme, In: nil, + Name: "no_auth", }, }, { - name: "Bearer Token", - ss: auth.NewAuthorizationBearerSecurityScheme("test", nil), + name: "Bearer Token", + securityScheme: auth.MustNewAuthorizationBearerSecurityScheme("test", &value), want: report.OperationSecurityScheme{ - Type: auth.HttpType, - Scheme: auth.BearerScheme, - In: &inHeader, - Name: "test", + Type: auth.HttpType, + Scheme: auth.BearerScheme, + In: &inHeader, + TokenFormat: &noneTokenFormat, + Name: "test", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := report.NewOperationSecurityScheme(tt.ss) + got := report.NewOperationSecurityScheme(tt.securityScheme) assert.Equal(t, tt.want, got) }) } diff --git a/report/reporter.go b/report/reporter.go index 72905212..4e248855 100644 --- a/report/reporter.go +++ b/report/reporter.go @@ -29,7 +29,7 @@ func NewReporter() *Reporter { } } -func NewReporterWithCurl(method string, url string, data interface{}, header http.Header, cookies []*http.Cookie, securitySchemes []auth.SecurityScheme) *Reporter { +func NewReporterWithCurl(method string, url string, data interface{}, header http.Header, cookies []*http.Cookie, securitySchemes []*auth.SecurityScheme) *Reporter { return &Reporter{ Schema: reporterSchema, @@ -49,7 +49,7 @@ func NewReporterWithOpenAPIDoc(openapi *openapi3.T, operations operation.Operati } } -func NewReporterWithGraphQL(url string, securitySchemes []auth.SecurityScheme) *Reporter { +func NewReporterWithGraphQL(url string, securitySchemes []*auth.SecurityScheme) *Reporter { return &Reporter{ Schema: reporterSchema, diff --git a/report/reporter_test.go b/report/reporter_test.go index 75c6c919..093971cd 100644 --- a/report/reporter_test.go +++ b/report/reporter_test.go @@ -7,6 +7,7 @@ import ( "github.com/cerberauth/vulnapi/internal/auth" "github.com/cerberauth/vulnapi/internal/operation" + "github.com/cerberauth/vulnapi/openapi" openapilib "github.com/cerberauth/vulnapi/openapi" "github.com/cerberauth/vulnapi/report" "github.com/stretchr/testify/assert" @@ -19,15 +20,16 @@ func TestNewReporterWithCurl(t *testing.T) { header := http.Header{"Content-Type": []string{"application/json"}} cookies := []*http.Cookie{{Name: "session_id", Value: "abc123"}} token := "abc123" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) - securitySchemes := []auth.SecurityScheme{securityScheme} + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) + securitySchemes := []*auth.SecurityScheme{securityScheme} reportSecuritySchemes := []report.OperationSecurityScheme{ { - Type: securityScheme.GetType(), - Scheme: securityScheme.GetScheme(), - In: securityScheme.GetIn(), - Name: securityScheme.GetName(), + Type: securityScheme.GetType(), + Scheme: securityScheme.GetScheme(), + In: securityScheme.GetIn(), + TokenFormat: securityScheme.GetTokenFormat(), + Name: securityScheme.GetName(), }, } @@ -51,8 +53,8 @@ func TestNewReporterWithCurl_AddReport(t *testing.T) { header := http.Header{"Content-Type": []string{"application/json"}} cookies := []*http.Cookie{{Name: "session_id", Value: "abc123"}} token := "abc123" - securityScheme := auth.SecurityScheme(auth.NewAuthorizationBearerSecurityScheme("token", &token)) - securitySchemes := []auth.SecurityScheme{securityScheme} + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) + securitySchemes := []*auth.SecurityScheme{securityScheme} reporter := report.NewReporterWithCurl(method, url, data, header, cookies, securitySchemes) @@ -76,11 +78,11 @@ func TestNewReporterWithCurl_AddReport(t *testing.T) { } func TestNewReporterWithOpenAPIDoc(t *testing.T) { - openapi, _ := openapilib.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer.openapi.json") - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, _ := openapi.Operations(nil, securitySchemesMap) + openapiContract, _ := openapilib.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer.openapi.json") + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, _ := openapiContract.Operations(nil, securitySchemesMap) - reporter := report.NewReporterWithOpenAPIDoc(openapi.Doc, operations) + reporter := report.NewReporterWithOpenAPIDoc(openapiContract.Doc, operations) assert.NotNil(t, reporter) assert.NotNil(t, reporter.OpenAPI) @@ -88,10 +90,10 @@ func TestNewReporterWithOpenAPIDoc(t *testing.T) { } func TestReporterWithOpenAPIDoc_AddReport(t *testing.T) { - openapi, _ := openapilib.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer.openapi.json") - securitySchemesMap, _ := openapi.SecuritySchemeMap(auth.NewEmptySecuritySchemeValues()) - operations, _ := openapi.Operations(nil, securitySchemesMap) - reporter := report.NewReporterWithOpenAPIDoc(openapi.Doc, operations) + openapiContract, _ := openapilib.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer.openapi.json") + securitySchemesMap, _ := openapiContract.SecuritySchemeMap(openapi.NewEmptySecuritySchemeValues()) + operations, _ := openapiContract.Operations(nil, securitySchemesMap) + reporter := report.NewReporterWithOpenAPIDoc(openapiContract.Doc, operations) issue := report.Issue{ ID: "id", diff --git a/scan/broken_authentication/authentication_bypass/authentication_bypass.go b/scan/broken_authentication/authentication_bypass/authentication_bypass.go index 63735caf..4ce5effe 100644 --- a/scan/broken_authentication/authentication_bypass/authentication_bypass.go +++ b/scan/broken_authentication/authentication_bypass/authentication_bypass.go @@ -27,16 +27,16 @@ var issue = report.Issue{ }, } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(AcceptsUnauthenticatedOperationScanID, AcceptsUnauthenticatedOperationScanName, op) - if _, ok := securityScheme.(*auth.NoAuthSecurityScheme); ok { + if securityScheme.GetType() == auth.None { return r.AddIssueReport(vulnReport.Skip()).End(), nil } - noAuthSecurityScheme := auth.SecurityScheme(auth.NewNoAuthSecurityScheme()) - vsa, err := scan.ScanURL(op, &noAuthSecurityScheme) + noAuthSecurityScheme := auth.MustNewNoAuthSecurityScheme() + vsa, err := scan.ScanURL(op, noAuthSecurityScheme) if err != nil { return r, err } diff --git a/scan/broken_authentication/authentication_bypass/authentication_bypass_test.go b/scan/broken_authentication/authentication_bypass/authentication_bypass_test.go index 5890a745..d4a281ce 100644 --- a/scan/broken_authentication/authentication_bypass/authentication_bypass_test.go +++ b/scan/broken_authentication/authentication_bypass/authentication_bypass_test.go @@ -14,7 +14,7 @@ import ( ) func TestAuthenticationByPassScanHandler_Skipped_WhenNoAuthSecurityScheme(t *testing.T) { - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := authenticationbypass.ScanHandler(operation, securityScheme) @@ -29,7 +29,7 @@ func TestAuthenticationByPassScanHandler_Failed_WhenAuthIsByPassed(t *testing.T) defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) @@ -45,7 +45,7 @@ func TestAuthenticationByPassScanHandler_Passed_WhenAuthIsNotByPassed(t *testing defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) diff --git a/scan/broken_authentication/jwt/alg_none/alg_none.go b/scan/broken_authentication/jwt/alg_none/alg_none.go index 0a4e18ee..53bcd55b 100644 --- a/scan/broken_authentication/jwt/alg_none/alg_none.go +++ b/scan/broken_authentication/jwt/alg_none/alg_none.go @@ -37,16 +37,8 @@ var issue = report.Issue{ }, } -func ShouldBeScanned(securitySheme auth.SecurityScheme) bool { - if securitySheme == nil { - return false - } - - if _, ok := securitySheme.(*auth.JWTBearerSecurityScheme); !ok { - return false - } - - return true +func ShouldBeScanned(securityScheme *auth.SecurityScheme) bool { + return securityScheme != nil && securityScheme.GetType() != auth.None && (securityScheme.GetTokenFormat() == nil || *securityScheme.GetTokenFormat() == auth.JWTTokenFormat) } var algs = []string{ @@ -56,7 +48,7 @@ var algs = []string{ "nOnE", } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { issueReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(AlgNoneJwtScanID, AlgNoneJwtScanName, op) @@ -66,18 +58,24 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* return r, nil } - var valueWriter *jwt.JWTWriter + var token string if securityScheme.HasValidValue() { - valueWriter = securityScheme.GetValidValueWriter().(*jwt.JWTWriter) - if valueWriter.GetToken().Method.Alg() == jwtlib.SigningMethodNone.Alg() { - return r, nil - } - - valueWriter = jwt.NewJWTWriterWithValidClaims(valueWriter) + token = securityScheme.GetToken() } else { - valueWriter, _ = jwt.NewJWTWriter(jwt.FakeJWT) + token = jwt.FakeJWT + } + + valueWriter, err := jwt.NewJWTWriter(token) + if err != nil { + return r, err } + if valueWriter.GetToken().Method.Alg() == jwtlib.SigningMethodNone.Alg() { + r.AddIssueReport(issueReport.Fail()).End() + return r, nil + } + valueWriter = jwt.NewJWTWriterWithValidClaims(valueWriter) + method := &signingMethodNone{} for _, alg := range algs { method.SetAlg(alg) @@ -94,19 +92,21 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* } } - r.End() - r.AddIssueReport(issueReport) - + r.AddIssueReport(issueReport).End() return r, nil } -func scanWithAlg(method jwtlib.SigningMethod, valueWriter *jwt.JWTWriter, securityScheme auth.SecurityScheme, op *operation.Operation) (*scan.IssueScanAttempt, error) { +func scanWithAlg(method jwtlib.SigningMethod, valueWriter *jwt.JWTWriter, securityScheme *auth.SecurityScheme, op *operation.Operation) (*scan.IssueScanAttempt, error) { newToken, err := valueWriter.SignWithMethodAndKey(method, jwtlib.UnsafeAllowNoneSignatureType) if err != nil { return nil, err } - securityScheme.SetAttackValue(newToken) - vsa, err := scan.ScanURL(op, &securityScheme) + + if err = securityScheme.SetAttackValue(newToken); err != nil { + return nil, err + } + + vsa, err := scan.ScanURL(op, securityScheme) if err != nil { return nil, err } diff --git a/scan/broken_authentication/jwt/alg_none/alg_none_test.go b/scan/broken_authentication/jwt/alg_none/alg_none_test.go index 51ffcb39..006d19ad 100644 --- a/scan/broken_authentication/jwt/alg_none/alg_none_test.go +++ b/scan/broken_authentication/jwt/alg_none/alg_none_test.go @@ -14,7 +14,11 @@ import ( ) func TestAlgNoneJwtScanHandler_WithoutSecurityScheme(t *testing.T) { - securityScheme := auth.NewNoAuthSecurityScheme() + client := request.GetDefaultClient() + httpmock.ActivateNonDefault(client.Client) + defer httpmock.DeactivateAndReset() + + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := algnone.ScanHandler(operation, securityScheme) @@ -29,7 +33,7 @@ func TestAlgNoneJwtScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *testin httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", nil) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", nil) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) @@ -47,7 +51,7 @@ func TestAlgNoneJwtScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T) { defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) @@ -58,13 +62,25 @@ func TestAlgNoneJwtScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T) { assert.True(t, report.Issues[0].HasPassed()) } +func TestAlgNoneJwtScanHandler_Failed_WhenValidValueUseNoneAlg(t *testing.T) { + token := "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOjE3MzE4NjkwNTEsImlhdCI6MTczMTg2NTQ1MSwibmFtZSI6IkpvaG4gRG9lIiwic3ViIjoiMmNiMzA3YmEtYmI0Ni00MTk0LTg1NGYtNDc3NDA0NmQ5YzliIn0." + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) + operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) + + report, err := algnone.ScanHandler(operation, securityScheme) + + require.NoError(t, err) + assert.Equal(t, 0, len(report.GetScanAttempts())) + assert.True(t, report.Issues[0].HasFailed()) +} + func TestAlgNoneJwtScanHandler_Failed_WhenOKResponse(t *testing.T) { client := request.GetDefaultClient() httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) @@ -84,7 +100,7 @@ func TestAlgNoneJwtScanHandler_Failed_WhenOKResponseAndAlgNone(t *testing.T) { defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), func(req *http.Request) (*http.Response, error) { switch req.Header.Get("Authorization") { diff --git a/scan/broken_authentication/jwt/blank_secret/blank_secret.go b/scan/broken_authentication/jwt/blank_secret/blank_secret.go index d6344a9a..3c5d53ef 100644 --- a/scan/broken_authentication/jwt/blank_secret/blank_secret.go +++ b/scan/broken_authentication/jwt/blank_secret/blank_secret.go @@ -30,19 +30,11 @@ var issue = report.Issue{ }, } -func ShouldBeScanned(securitySheme auth.SecurityScheme) bool { - if securitySheme == nil { - return false - } - - if _, ok := securitySheme.(*auth.JWTBearerSecurityScheme); !ok { - return false - } - - return true +func ShouldBeScanned(securityScheme *auth.SecurityScheme) bool { + return securityScheme != nil && securityScheme.GetType() != auth.None && (securityScheme.GetTokenFormat() == nil || *securityScheme.GetTokenFormat() == auth.JWTTokenFormat) } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(BlankSecretVulnerabilityScanID, BlankSecretVulnerabilityScanName, op) @@ -50,11 +42,16 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* return r.AddIssueReport(vulnReport.Skip()).End(), nil } - var valueWriter *jwt.JWTWriter + var token string if securityScheme.HasValidValue() { - valueWriter = jwt.NewJWTWriterWithValidClaims(securityScheme.GetValidValueWriter().(*jwt.JWTWriter)) + token = securityScheme.GetToken() } else { - valueWriter, _ = jwt.NewJWTWriter(jwt.FakeJWT) + token = jwt.FakeJWT + } + + valueWriter, err := jwt.NewJWTWriter(token) + if err != nil { + return r, err } newToken, err := valueWriter.SignWithKey([]byte("")) @@ -62,7 +59,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* return r, err } securityScheme.SetAttackValue(newToken) - vsa, err := scan.ScanURL(op, &securityScheme) + vsa, err := scan.ScanURL(op, securityScheme) if err != nil { return r, err } diff --git a/scan/broken_authentication/jwt/blank_secret/blank_secret_test.go b/scan/broken_authentication/jwt/blank_secret/blank_secret_test.go index e4aa0dd8..d102623b 100644 --- a/scan/broken_authentication/jwt/blank_secret/blank_secret_test.go +++ b/scan/broken_authentication/jwt/blank_secret/blank_secret_test.go @@ -14,7 +14,7 @@ import ( ) func TestBlankSecretScanHandler_WithoutSecurityScheme(t *testing.T) { - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := blanksecret.ScanHandler(operation, securityScheme) @@ -28,7 +28,7 @@ func TestBlankSecretScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *testi httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", nil) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", nil) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) @@ -44,7 +44,7 @@ func TestBlankSecretScanHandler_Passed_WhenNoJWTAndOKResponse(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", nil) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", nil) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) @@ -61,7 +61,7 @@ func TestBlankSecretScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T) { defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) @@ -78,7 +78,7 @@ func TestBlankSecretScanHandler_Failed_WhenOKResponse(t *testing.T) { defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) diff --git a/scan/broken_authentication/jwt/not_verified/not_verified.go b/scan/broken_authentication/jwt/not_verified/not_verified.go index 55ca8197..4c247566 100644 --- a/scan/broken_authentication/jwt/not_verified/not_verified.go +++ b/scan/broken_authentication/jwt/not_verified/not_verified.go @@ -29,23 +29,11 @@ var issue = report.Issue{ }, } -func ShouldBeScanned(securitySheme auth.SecurityScheme) bool { - if securitySheme == nil { - return false - } - - if _, ok := securitySheme.(*auth.JWTBearerSecurityScheme); !ok { - return false - } - - if !securitySheme.HasValidValue() { - return false - } - - return true +func ShouldBeScanned(securityScheme *auth.SecurityScheme) bool { + return securityScheme != nil && securityScheme.GetType() != auth.None && securityScheme.GetTokenFormat() != nil && *securityScheme.GetTokenFormat() == auth.JWTTokenFormat } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(NotVerifiedJwtScanID, NotVerifiedJwtScanName, op) @@ -54,11 +42,16 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* return r, nil } - var valueWriter *jwt.JWTWriter + var token string if securityScheme.HasValidValue() { - valueWriter = jwt.NewJWTWriterWithValidClaims(securityScheme.GetValidValueWriter().(*jwt.JWTWriter)) + token = securityScheme.GetToken() } else { - valueWriter, _ = jwt.NewJWTWriter(jwt.FakeJWT) + token = jwt.FakeJWT + } + + valueWriter, err := jwt.NewJWTWriter(token) + if err != nil { + return r, err } newToken, err := valueWriter.SignWithMethodAndRandomKey(valueWriter.GetToken().Method) @@ -67,7 +60,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* } securityScheme.SetAttackValue(securityScheme.GetValidValue()) - attemptOne, err := scan.ScanURL(op, &securityScheme) + attemptOne, err := scan.ScanURL(op, securityScheme) if err != nil { return r, err } @@ -79,7 +72,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* } securityScheme.SetAttackValue(newToken) - attemptTwo, err := scan.ScanURL(op, &securityScheme) + attemptTwo, err := scan.ScanURL(op, securityScheme) if err != nil { return r, err } diff --git a/scan/broken_authentication/jwt/not_verified/not_verified_test.go b/scan/broken_authentication/jwt/not_verified/not_verified_test.go index 6eadab10..6f94d684 100644 --- a/scan/broken_authentication/jwt/not_verified/not_verified_test.go +++ b/scan/broken_authentication/jwt/not_verified/not_verified_test.go @@ -14,7 +14,7 @@ import ( ) func TestNotVerifiedScanHandler_WithoutSecurityScheme(t *testing.T) { - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := notverified.ScanHandler(operation, securityScheme) @@ -23,8 +23,8 @@ func TestNotVerifiedScanHandler_WithoutSecurityScheme(t *testing.T) { assert.True(t, report.Issues[0].HasBeenSkipped()) } -func TestNotVerifiedScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *testing.T) { - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", nil) +func TestNotVerifiedScanHandler_Skipped_WhenNoJWTAndUnauthorizedResponse(t *testing.T) { + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", nil) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := notverified.ScanHandler(operation, securityScheme) @@ -39,7 +39,7 @@ func TestNotVerifiedScanHandler_Failed_WhenUnauthorizedThenOK(t *testing.T) { defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.ResponderFromMultipleResponses( @@ -61,7 +61,7 @@ func TestNotVerifiedScanHandler_Skipped_WhenOKFirstRequest(t *testing.T) { defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.ResponderFromMultipleResponses( @@ -83,7 +83,7 @@ func TestNotVerifiedScanHandler_Failed_WhenUnauthorizedThenUnauthorized(t *testi defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.ResponderFromMultipleResponses( diff --git a/scan/broken_authentication/jwt/null_signature/null_signature.go b/scan/broken_authentication/jwt/null_signature/null_signature.go index 56a01531..0d653e96 100644 --- a/scan/broken_authentication/jwt/null_signature/null_signature.go +++ b/scan/broken_authentication/jwt/null_signature/null_signature.go @@ -30,19 +30,11 @@ var issue = report.Issue{ }, } -func ShouldBeScanned(securitySheme auth.SecurityScheme) bool { - if securitySheme == nil { - return false - } - - if _, ok := securitySheme.(*auth.JWTBearerSecurityScheme); !ok { - return false - } - - return true +func ShouldBeScanned(securityScheme *auth.SecurityScheme) bool { + return securityScheme != nil && securityScheme.GetType() != auth.None && (securityScheme.GetTokenFormat() == nil || *securityScheme.GetTokenFormat() == auth.JWTTokenFormat) } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(NullSignatureScanID, NullSignatureScanName, op) @@ -51,11 +43,16 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* return r, nil } - var valueWriter *jwt.JWTWriter + var token string if securityScheme.HasValidValue() { - valueWriter = jwt.NewJWTWriterWithValidClaims(securityScheme.GetValidValueWriter().(*jwt.JWTWriter)) + token = securityScheme.GetToken() } else { - valueWriter, _ = jwt.NewJWTWriter(jwt.FakeJWT) + token = jwt.FakeJWT + } + + valueWriter, err := jwt.NewJWTWriter(token) + if err != nil { + return r, err } newToken, err := valueWriter.WithoutSignature() @@ -63,7 +60,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* return r, err } securityScheme.SetAttackValue(newToken) - vsa, err := scan.ScanURL(op, &securityScheme) + vsa, err := scan.ScanURL(op, securityScheme) if err != nil { return r, err } diff --git a/scan/broken_authentication/jwt/null_signature/null_signature_test.go b/scan/broken_authentication/jwt/null_signature/null_signature_test.go index 7425bd40..94db2f34 100644 --- a/scan/broken_authentication/jwt/null_signature/null_signature_test.go +++ b/scan/broken_authentication/jwt/null_signature/null_signature_test.go @@ -14,7 +14,7 @@ import ( ) func TestNullSignatureScanHandler_WithoutSecurityScheme(t *testing.T) { - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := nullsignature.ScanHandler(operation, securityScheme) @@ -28,7 +28,7 @@ func TestNullSignatureScanHandler_Passed_WhenNoJWTAndUnauthorizedResponse(t *tes httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", nil) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", nil) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) @@ -45,7 +45,7 @@ func TestNullSignatureScanHandler_Passed_WhenUnauthorizedResponse(t *testing.T) defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) @@ -62,7 +62,7 @@ func TestNullSignatureScanHandler_Failed_WhenOKResponse(t *testing.T) { defer httpmock.DeactivateAndReset() token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) diff --git a/scan/broken_authentication/jwt/weak_secret/weak_secret.go b/scan/broken_authentication/jwt/weak_secret/weak_secret.go index feaddf41..1b36a6c7 100644 --- a/scan/broken_authentication/jwt/weak_secret/weak_secret.go +++ b/scan/broken_authentication/jwt/weak_secret/weak_secret.go @@ -35,20 +35,16 @@ var issue = report.Issue{ }, } -func ShouldBeScanned(securitySheme auth.SecurityScheme) bool { - if securitySheme == nil { +func ShouldBeScanned(securityScheme *auth.SecurityScheme) bool { + if !(securityScheme != nil && securityScheme.GetType() != auth.None && (securityScheme.GetTokenFormat() == nil || *securityScheme.GetTokenFormat() == auth.JWTTokenFormat)) { return false } - if _, ok := securitySheme.(*auth.JWTBearerSecurityScheme); !ok { + valueWriter, err := jwt.NewJWTWriter(securityScheme.GetToken()) + if err != nil { return false } - if !securitySheme.HasValidValue() { - return false - } - - valueWriter := securitySheme.GetValidValueWriter().(*jwt.JWTWriter) return valueWriter.IsHMACAlg() } @@ -56,7 +52,7 @@ var defaultJwtSecretDictionary = []string{"secret", "password", "123456", "chang const jwtSecretDictionarySeclistUrl = "https://raw.githubusercontent.com/danielmiessler/SecLists/master/Passwords/scraped-JWT-secrets.txt" -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(WeakSecretVulnerabilityScanID, WeakSecretVulnerabilityScanName, op) @@ -70,8 +66,12 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* jwtSecretDictionary = secretDictionnaryFromSeclist.Items } + valueWriter, err := jwt.NewJWTWriter(securityScheme.GetToken()) + if err != nil { + return r, err + } + secretFound := false - valueWriter := securityScheme.GetValidValueWriter().(*jwt.JWTWriter) currentToken := valueWriter.GetToken().Raw for _, secret := range jwtSecretDictionary { if secret == "" { @@ -93,7 +93,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* } securityScheme.SetAttackValue(newValidToken) - vsa, err := scan.ScanURL(op, &securityScheme) + vsa, err := scan.ScanURL(op, securityScheme) if err != nil { return r, err } diff --git a/scan/broken_authentication/jwt/weak_secret/weak_secret_test.go b/scan/broken_authentication/jwt/weak_secret/weak_secret_test.go index 1e95f44f..1d3d8efc 100644 --- a/scan/broken_authentication/jwt/weak_secret/weak_secret_test.go +++ b/scan/broken_authentication/jwt/weak_secret/weak_secret_test.go @@ -14,7 +14,7 @@ import ( ) func TestWeakHMACSecretScanHandler_WithoutSecurityScheme(t *testing.T) { - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := weaksecret.ScanHandler(operation, securityScheme) @@ -25,7 +25,7 @@ func TestWeakHMACSecretScanHandler_WithoutSecurityScheme(t *testing.T) { func TestWeakHMACSecretScanHandler_WithJWTUsingOtherAlg(t *testing.T) { token := "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhYmMxMjMifQ.vLBmArLmAKEshqJa3px6qYfrkAfiwBrKPs5dCMxqj9bdiEKR5W4o0Srxt6VHZKzsxIGMTTsqpW21lKnYsLw5DA" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := weaksecret.ScanHandler(operation, securityScheme) @@ -35,7 +35,7 @@ func TestWeakHMACSecretScanHandler_WithJWTUsingOtherAlg(t *testing.T) { } func TestWeakHMACSecretScanHandler_WithoutJWT(t *testing.T) { - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", nil) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", nil) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := weaksecret.ScanHandler(operation, securityScheme) @@ -51,7 +51,7 @@ func TestWeakHMACSecretScanHandler_Failed_WithWeakJWT(t *testing.T) { secret := "secret" token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.t-IDcSemACt8x4iTMCda8Yhe3iZaWbvV5XKSTbuAn0M" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) @@ -71,7 +71,7 @@ func TestWeakHMACSecretScanHandler_Failed_WithExpiredJWTSignedWithWeakSecret(t * secret := "secret" token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE1MTYyMzkwMjJ9.7BbIenT4-HobiMHaMUQdNcJ6lD_QQkKnImP9IprJFvU" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) @@ -86,7 +86,7 @@ func TestWeakHMACSecretScanHandler_Failed_WithExpiredJWTSignedWithWeakSecret(t * func TestWeakHMACSecretScanHandler_Passed_WithStrongerJWT(t *testing.T) { token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.MWUarT7Q4e5DqnZbdr7VKw3rx9VW-CrvoVkfpllS4CY" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) @@ -105,7 +105,7 @@ func TestWeakHMACSecretScanHandler_Failed_WithUnorderedClaims(t *testing.T) { secret := "secret" token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJuYmYiOjIwMTYyMzkwMjJ9.ymnE0GznV0dMkjANTQl8IqBSlTi9RFWfBeT42jBNrU4" - securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("token", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) diff --git a/scan/discover/accept_unauthenticated/accept_unauthenticated_operation.go b/scan/discover/accept_unauthenticated/accept_unauthenticated_operation.go index 7dca4144..57bd9d2a 100644 --- a/scan/discover/accept_unauthenticated/accept_unauthenticated_operation.go +++ b/scan/discover/accept_unauthenticated/accept_unauthenticated_operation.go @@ -26,12 +26,11 @@ var issue = report.Issue{ }, } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(NoAuthOperationScanID, NoAuthOperationScanName, op) - _, ok := securityScheme.(*auth.NoAuthSecurityScheme) - r.AddIssueReport(vulnReport.WithBooleanStatus(!ok)).End() + r.AddIssueReport(vulnReport.WithBooleanStatus(securityScheme.GetType() != auth.None)).End() r.End() return r, nil diff --git a/scan/discover/accept_unauthenticated/accept_unauthenticated_operation_test.go b/scan/discover/accept_unauthenticated/accept_unauthenticated_operation_test.go index 7d3d2828..84b03a4b 100644 --- a/scan/discover/accept_unauthenticated/accept_unauthenticated_operation_test.go +++ b/scan/discover/accept_unauthenticated/accept_unauthenticated_operation_test.go @@ -12,7 +12,7 @@ import ( ) func TestAcceptUnauthenticatedScanHandler_Failed_WhenNoAuthSecurityScheme(t *testing.T) { - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() op := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := acceptunauthenticated.ScanHandler(op, securityScheme) @@ -23,7 +23,7 @@ func TestAcceptUnauthenticatedScanHandler_Failed_WhenNoAuthSecurityScheme(t *tes func TestCheckNoAuthOperationScanHandler_Passed_WhenAuthConfigured(t *testing.T) { token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) op := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) report, err := acceptunauthenticated.ScanHandler(op, securityScheme) diff --git a/scan/discover/discoverable_graphql/discoverable_graphql.go b/scan/discover/discoverable_graphql/discoverable_graphql.go index daeee4e5..7eeedcdc 100644 --- a/scan/discover/discoverable_graphql/discoverable_graphql.go +++ b/scan/discover/discoverable_graphql/discoverable_graphql.go @@ -39,7 +39,7 @@ var potentialGraphQLEndpoints = []string{ "/v1/graphiql", } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(DiscoverableGraphQLPathScanID, DiscoverableGraphQLPathScanName, op) handler := discover.CreateURLScanHandler("GraphQL", graphqlSeclistUrl, potentialGraphQLEndpoints, r, vulnReport) diff --git a/scan/discover/discoverable_graphql/discoverable_graphql_test.go b/scan/discover/discoverable_graphql/discoverable_graphql_test.go index a1c6694f..2534508c 100644 --- a/scan/discover/discoverable_graphql/discoverable_graphql_test.go +++ b/scan/discover/discoverable_graphql/discoverable_graphql_test.go @@ -24,7 +24,7 @@ func TestDiscoverableScanner_Passed_WhenNoDiscoverableGraphqlPathFound(t *testin httpmock.RegisterResponder(op.Method, op.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) httpmock.RegisterNoResponder(httpmock.NewBytesResponder(http.StatusNotFound, nil)) - report, err := discoverablegraphql.ScanHandler(op, auth.NewNoAuthSecurityScheme()) + report, err := discoverablegraphql.ScanHandler(op, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Greater(t, httpmock.GetTotalCallCount(), 7) @@ -42,7 +42,7 @@ func TestDiscoverableScanner_Failed_WhenOneGraphQLPathFound(t *testing.T) { httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) httpmock.RegisterNoResponder(httpmock.NewBytesResponder(http.StatusNotFound, nil)) - report, err := discoverablegraphql.ScanHandler(operation, auth.NewNoAuthSecurityScheme()) + report, err := discoverablegraphql.ScanHandler(operation, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Greater(t, httpmock.GetTotalCallCount(), 0) diff --git a/scan/discover/discoverable_openapi/discoverable_openapi.go b/scan/discover/discoverable_openapi/discoverable_openapi.go index 663dddea..d1c829fa 100644 --- a/scan/discover/discoverable_openapi/discoverable_openapi.go +++ b/scan/discover/discoverable_openapi/discoverable_openapi.go @@ -38,7 +38,7 @@ var potentialOpenAPIPaths = []string{ "/.well-known/openapi.yml", } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(DiscoverableOpenAPIScanID, DiscoverableOpenAPIScanName, op) handler := discover.CreateURLScanHandler("OpenAPI", openapiSeclistUrl, potentialOpenAPIPaths, r, vulnReport) diff --git a/scan/discover/discoverable_openapi/discoverable_openapi_test.go b/scan/discover/discoverable_openapi/discoverable_openapi_test.go index 56bb5189..45198963 100644 --- a/scan/discover/discoverable_openapi/discoverable_openapi_test.go +++ b/scan/discover/discoverable_openapi/discoverable_openapi_test.go @@ -24,7 +24,7 @@ func TestDiscoverableScanner_Passed_WhenNoDiscoverableGraphqlPathFound(t *testin httpmock.RegisterResponder(op.Method, op.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil).HeaderAdd(http.Header{"Server": []string{"Apache/2.4.29 (Ubuntu)"}})) httpmock.RegisterNoResponder(httpmock.NewBytesResponder(http.StatusNotFound, nil)) - report, err := discoverableopenapi.ScanHandler(op, auth.NewNoAuthSecurityScheme()) + report, err := discoverableopenapi.ScanHandler(op, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Greater(t, httpmock.GetTotalCallCount(), 10) @@ -42,7 +42,7 @@ func TestDiscoverableScanner_Failed_WhenOneOpenAPIFound(t *testing.T) { httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) httpmock.RegisterNoResponder(httpmock.NewBytesResponder(http.StatusNotFound, nil)) - report, err := discoverableopenapi.ScanHandler(operation, auth.NewNoAuthSecurityScheme()) + report, err := discoverableopenapi.ScanHandler(operation, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Greater(t, httpmock.GetTotalCallCount(), 0) diff --git a/scan/discover/fingerprint/fingerprint.go b/scan/discover/fingerprint/fingerprint.go index 4dc6b076..37525e96 100644 --- a/scan/discover/fingerprint/fingerprint.go +++ b/scan/discover/fingerprint/fingerprint.go @@ -58,11 +58,11 @@ func appendIfMissing(slice []FingerPrintApp, app FingerPrintApp) []FingerPrintAp return append(slice, app) } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(DiscoverFingerPrintScanID, DiscoverFingerPrintScanName, op) - attempt, err := scan.ScanURL(op, &securityScheme) + attempt, err := scan.ScanURL(op, securityScheme) r.AddScanAttempt(attempt) if err != nil { return r.AddIssueReport(vulnReport.Skip()).End(), err diff --git a/scan/discover/fingerprint/fingerprint_test.go b/scan/discover/fingerprint/fingerprint_test.go index 19523337..a4a71a36 100644 --- a/scan/discover/fingerprint/fingerprint_test.go +++ b/scan/discover/fingerprint/fingerprint_test.go @@ -19,7 +19,7 @@ func TestCheckSignatureHeader_Failed_WithServerSignatureHeader(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) op := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(op.Method, op.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil).HeaderAdd(http.Header{"Server": []string{"Apache/2.4.29"}})) @@ -40,7 +40,7 @@ func TestCheckSignatureHeader_Failed_WithOSSignatureHeader(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil).HeaderAdd(http.Header{"Server": []string{"Ubuntu"}})) @@ -61,7 +61,7 @@ func TestCheckSignatureHeader_Failed_WithHostingSignatureHeader(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil).HeaderAdd(http.Header{"platform": []string{"hostinger"}})) @@ -82,7 +82,7 @@ func TestCheckSignatureHeader_Failed_WithAuthenticationSignatureHeader(t *testin defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil).HeaderAdd(http.Header{"x-auth0-requestid": []string{"id"}})) @@ -103,7 +103,7 @@ func TestCheckSignatureHeader_Failed_WithCDNSignatureHeader(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil).HeaderAdd(http.Header{"cf-cache-status": []string{"HIT"}})) @@ -124,7 +124,7 @@ func TestCheckSignatureHeader_Failed_WithLanguageSignatureHeader(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil).HeaderAdd(http.Header{"x-powered-by": []string{"PHP 7.4.3"}})) @@ -145,7 +145,7 @@ func TestCheckSignatureHeader_Failed_WithFrameworkSignatureHeader(t *testing.T) defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil).HeaderAdd(http.Header{"x-powered-by": []string{"express"}})) @@ -167,7 +167,7 @@ func TestCheckSignatureHeader_Passed_WithoutDuplicate(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil).HeaderAdd(http.Header{"x-powered-by": []string{"next.js"}})) @@ -187,7 +187,7 @@ func TestCheckSignatureHeader_Passed_WithoutSignatureHeader(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) diff --git a/scan/discover/utils.go b/scan/discover/utils.go index 0bffa4bb..4e416460 100644 --- a/scan/discover/utils.go +++ b/scan/discover/utils.go @@ -22,8 +22,8 @@ func ExtractBaseURL(inputURL *url.URL) *url.URL { } } -func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme auth.SecurityScheme, r *report.ScanReport, vulnReport *report.IssueReport) (*report.ScanReport, error) { - securitySchemes := []auth.SecurityScheme{securityScheme} +func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme *auth.SecurityScheme, r *report.ScanReport, vulnReport *report.IssueReport) (*report.ScanReport, error) { + securitySchemes := []*auth.SecurityScheme{securityScheme} base := ExtractBaseURL(&op.URL) for _, path := range scanUrls { @@ -33,7 +33,7 @@ func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme auth.Se return r, err } - attempt, err := scan.ScanURL(newOperation, &securityScheme) + attempt, err := scan.ScanURL(newOperation, securityScheme) if err != nil { return r, err } @@ -51,13 +51,13 @@ func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme auth.Se return r, nil } -func CreateURLScanHandler(name string, seclistUrl string, defaultUrls []string, r *report.ScanReport, vulnReport *report.IssueReport) func(operation *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func CreateURLScanHandler(name string, seclistUrl string, defaultUrls []string, r *report.ScanReport, vulnReport *report.IssueReport) func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { scanUrls := defaultUrls if urlsFromSeclist, err := seclist.NewSecListFromURL(name, seclistUrl); err == nil && urlsFromSeclist != nil { scanUrls = urlsFromSeclist.Items } - return func(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { + return func(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return ScanURLs(scanUrls, op, securityScheme, r, vulnReport) } } diff --git a/scan/discover/utils_test.go b/scan/discover/utils_test.go index 76827e53..acb88d2a 100644 --- a/scan/discover/utils_test.go +++ b/scan/discover/utils_test.go @@ -49,7 +49,7 @@ func TestCreateURLScanHandler_WithTimeout(t *testing.T) { seclistUrl := "http://localhost:8080/seclist" defaultUrls := []string{"/path1", "/path2"} - securitySchemes := []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()} + securitySchemes := []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()} operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080", nil, client) operation.SetSecuritySchemes(securitySchemes) r := report.NewScanReport("test", "test", operation) @@ -71,7 +71,7 @@ func TestCreateURLScanHandler_Passed_WhenNotFoundURLs(t *testing.T) { seclistUrl := "http://localhost:8080/seclist" defaultUrls := []string{"/path1", "/path2"} - securitySchemes := []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()} + securitySchemes := []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()} operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080", nil, client) operation.SetSecuritySchemes(securitySchemes) r := report.NewScanReport("test", "test", operation) @@ -96,7 +96,7 @@ func TestCreateURLScanHandler_Failed_WhenFoundExposedURLs(t *testing.T) { seclistUrl := "http://localhost:8080/seclist" defaultUrls := []string{"/path1", "/path2"} - securitySchemes := []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()} + securitySchemes := []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()} operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080", nil, client) operation.SetSecuritySchemes(securitySchemes) r := report.NewScanReport("test", "test", operation) diff --git a/scan/graphql/introspection_enabled/introspection_enabled.go b/scan/graphql/introspection_enabled/introspection_enabled.go index 42a82f7e..6487a38b 100644 --- a/scan/graphql/introspection_enabled/introspection_enabled.go +++ b/scan/graphql/introspection_enabled/introspection_enabled.go @@ -60,8 +60,8 @@ func newGetGraphqlIntrospectionRequest(client *request.Client, endpoint url.URL) return req, nil } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { - securitySchemes := []auth.SecurityScheme{securityScheme} +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { + securitySchemes := []*auth.SecurityScheme{securityScheme} vulnReport := report.NewIssueReport(issue).WithOperation(op).WithSecurityScheme(securityScheme) r := report.NewScanReport(GraphqlIntrospectionScanID, GraphqlIntrospectionScanName, op) @@ -75,7 +75,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* } newOperation.SetSecuritySchemes(securitySchemes) - attempt, err := scan.ScanURL(newOperation, &securityScheme) + attempt, err := scan.ScanURL(newOperation, securityScheme) if err != nil { return r, err } @@ -96,7 +96,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* } newOperation.SetSecuritySchemes(securitySchemes) - attempt, err = scan.ScanURL(newOperation, &securityScheme) + attempt, err = scan.ScanURL(newOperation, securityScheme) if err != nil { return r, err } diff --git a/scan/graphql/introspection_enabled/introspection_enabled_test.go b/scan/graphql/introspection_enabled/introspection_enabled_test.go index ad8cc11d..3a6cc643 100644 --- a/scan/graphql/introspection_enabled/introspection_enabled_test.go +++ b/scan/graphql/introspection_enabled/introspection_enabled_test.go @@ -23,7 +23,7 @@ func TestGraphqlIntrospectionScanHandler_Failed_WhenRespondHTTPStatusIsOK(t *tes httpmock.RegisterResponder(http.MethodPost, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, resBody)) httpmock.RegisterResponder(http.MethodGet, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, resBody)) - report, err := introspectionenabled.ScanHandler(operation, auth.NewNoAuthSecurityScheme()) + report, err := introspectionenabled.ScanHandler(operation, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Equal(t, 1, httpmock.GetTotalCallCount()) @@ -40,7 +40,7 @@ func TestGraphqlIntrospectionScanHandler_Failed_WhenRespond_GETMethodOnly_HTTPSt httpmock.RegisterResponder(http.MethodPost, operation.URL.String(), httpmock.NewBytesResponder(http.StatusBadRequest, nil)) httpmock.RegisterResponder(http.MethodGet, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, resBody)) - report, err := introspectionenabled.ScanHandler(operation, auth.NewNoAuthSecurityScheme()) + report, err := introspectionenabled.ScanHandler(operation, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Equal(t, 2, httpmock.GetTotalCallCount()) @@ -56,7 +56,7 @@ func TestGraphqlIntrospectionScanHandler_Passed_WhenBadRequestStatus(t *testing. httpmock.RegisterResponder(http.MethodPost, operation.URL.String(), httpmock.NewBytesResponder(http.StatusBadRequest, nil)) httpmock.RegisterResponder(http.MethodGet, operation.URL.String(), httpmock.NewBytesResponder(http.StatusBadRequest, nil)) - report, err := introspectionenabled.ScanHandler(operation, auth.NewNoAuthSecurityScheme()) + report, err := introspectionenabled.ScanHandler(operation, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Equal(t, 2, httpmock.GetTotalCallCount()) @@ -72,7 +72,7 @@ func TestGraphqlIntrospectionScanHandler_Passed_WhenOKStatusButNoQuery(t *testin httpmock.RegisterResponder(http.MethodPost, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) httpmock.RegisterResponder(http.MethodGet, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) - report, err := introspectionenabled.ScanHandler(operation, auth.NewNoAuthSecurityScheme()) + report, err := introspectionenabled.ScanHandler(operation, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Equal(t, 2, httpmock.GetTotalCallCount()) diff --git a/scan/misconfiguration/http_cookies/http_cookies.go b/scan/misconfiguration/http_cookies/http_cookies.go index 408799eb..0a5a64ab 100644 --- a/scan/misconfiguration/http_cookies/http_cookies.go +++ b/scan/misconfiguration/http_cookies/http_cookies.go @@ -103,14 +103,14 @@ var withoutExpiresIssue = report.Issue{ }, } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { httpOnlyVulnReport := report.NewIssueReport(httpNotHttpOnlyIssue).WithOperation(op).WithSecurityScheme(securityScheme) notSecureVulnReport := report.NewIssueReport(notSecureIssue).WithOperation(op).WithSecurityScheme(securityScheme) sameSiteNoneVulnReport := report.NewIssueReport(sameSiteNoneIssue).WithOperation(op).WithSecurityScheme(securityScheme) withoutSameSiteVulnReport := report.NewIssueReport(withoutSameSiteIssue).WithOperation(op).WithSecurityScheme(securityScheme) withoutExpiresVulnReport := report.NewIssueReport(withoutExpiresIssue).WithOperation(op).WithSecurityScheme(securityScheme) - attempt, err := scan.ScanURL(op, &securityScheme) + attempt, err := scan.ScanURL(op, securityScheme) r := report.NewScanReport(HTTPCookiesScanID, HTTPCookiesScanName, op) if err != nil { return r, err diff --git a/scan/misconfiguration/http_cookies/http_cookies_test.go b/scan/misconfiguration/http_cookies/http_cookies_test.go index a74a0d2b..34a2ed15 100644 --- a/scan/misconfiguration/http_cookies/http_cookies_test.go +++ b/scan/misconfiguration/http_cookies/http_cookies_test.go @@ -19,7 +19,7 @@ func TestHTTPCookiesScanHandler_Skipped_WhenNoCookies(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() op := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(op.Method, op.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) @@ -37,7 +37,7 @@ func TestHTTPCookiesScanHandler_Passed_WhenNoUnsecurePractices(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) resp := httpmock.NewStringResponse(http.StatusOK, "OK") cookie := &http.Cookie{ @@ -66,7 +66,7 @@ func TestHTTPCookiesScanHandler_Failed_WhenNotHttpOnly(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) resp := httpmock.NewStringResponse(http.StatusOK, "OK") cookie := &http.Cookie{ @@ -95,7 +95,7 @@ func TestHTTPCookiesScanHandlerFailed_WhenNotSecure(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) resp := httpmock.NewStringResponse(http.StatusOK, "OK") cookie := &http.Cookie{ @@ -124,7 +124,7 @@ func TestHTTPCookiesScanHandler_Failed_WhenSameSiteNone(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) resp := httpmock.NewStringResponse(http.StatusOK, "OK") cookie := &http.Cookie{ @@ -153,7 +153,7 @@ func TestHTTPCookiesScanHandler_Failed_WhithoutSameSite(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) resp := httpmock.NewStringResponse(http.StatusOK, "OK") cookie := &http.Cookie{ @@ -181,7 +181,7 @@ func TestHTTPCookiesScanHandler_Failed_WhenExpiresNotSet(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) resp := httpmock.NewStringResponse(http.StatusOK, "OK") cookie := &http.Cookie{ diff --git a/scan/misconfiguration/http_headers/http_headers.go b/scan/misconfiguration/http_headers/http_headers.go index 31c4882d..7410d5ba 100644 --- a/scan/misconfiguration/http_headers/http_headers.go +++ b/scan/misconfiguration/http_headers/http_headers.go @@ -156,7 +156,7 @@ func CheckCSPFrameAncestors(cspHeader string) bool { return false } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { contentOptionsMissing := report.NewIssueReport(contentOptionsMissingIssue).WithOperation(op).WithSecurityScheme(securityScheme) corsMissing := report.NewIssueReport(corsMissingIssue).WithOperation(op).WithSecurityScheme(securityScheme) corsWildcard := report.NewIssueReport(corsWildcardIssue).WithOperation(op).WithSecurityScheme(securityScheme) @@ -165,7 +165,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* frameOptionsMissing := report.NewIssueReport(frameOptionsMissingIssue).WithOperation(op).WithSecurityScheme(securityScheme) hstsMissing := report.NewIssueReport(hstsMissingIssue).WithOperation(op).WithSecurityScheme(securityScheme) - attempt, err := scan.ScanURL(op, &securityScheme) + attempt, err := scan.ScanURL(op, securityScheme) r := report.NewScanReport(HTTPHeadersScanID, HTTPHeadersScanName, op) if err != nil { return r, err diff --git a/scan/misconfiguration/http_headers/http_headers_test.go b/scan/misconfiguration/http_headers/http_headers_test.go index 15b494af..b7271dd7 100644 --- a/scan/misconfiguration/http_headers/http_headers_test.go +++ b/scan/misconfiguration/http_headers/http_headers_test.go @@ -30,7 +30,7 @@ func TestHTTPHeadersScanHandler_Passed(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil).HeaderAdd(getValidHTTPHeaders(operation))) @@ -54,7 +54,7 @@ func TestHTTPHeadersBestPracticesWithoutCSPScanHandler(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) header := getValidHTTPHeaders(operation) header.Del(httpheaders.CSPHTTPHeader) @@ -73,7 +73,7 @@ func TestHTTPHeadersBestPracticesWithoutFrameAncestorsCSPDirectiveScanHandler(t defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) header := getValidHTTPHeaders(operation) header.Set(httpheaders.CSPHTTPHeader, "default-src 'self' http://example.com; connect-src 'none'") @@ -92,7 +92,7 @@ func TestHTTPHeadersBestPracticesWithNotNoneFrameAncestorsCSPDirectiveScanHandle defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) header := getValidHTTPHeaders(operation) header.Set(httpheaders.CSPHTTPHeader, "default-src 'self' http://example.com; connect-src 'none'; frame-ancestors 'http://example.com'") @@ -111,7 +111,7 @@ func TestHTTPHeadersBestPracticesWithoutCORSScanHandler(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) header := getValidHTTPHeaders(operation) header.Del(httpheaders.CORSOriginHTTPHeader) @@ -130,7 +130,7 @@ func TestHTTPHeadersBestPracticesWithPermissiveCORSScanHandler(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) header := getValidHTTPHeaders(operation) header.Set(httpheaders.CORSOriginHTTPHeader, "*") @@ -149,7 +149,7 @@ func TestHTTPHeadersBestPracticesWithoutHSTSScanHandler(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) header := getValidHTTPHeaders(operation) header.Del(httpheaders.HSTSHTTPHeader) @@ -168,7 +168,7 @@ func TestHTTPHeadersBestPracticesWithoutXContentTypeOptionsScanHandler(t *testin defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) header := getValidHTTPHeaders(operation) header.Del(httpheaders.XContentTypeOptionsHTTPHeader) @@ -187,7 +187,7 @@ func TestHTTPHeadersBestPracticesWithoutXFrameOptionsScanHandler(t *testing.T) { defer httpmock.DeactivateAndReset() token := "token" - securityScheme := auth.NewAuthorizationBearerSecurityScheme("default", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("default", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) header := getValidHTTPHeaders(operation) header.Del(httpheaders.XFrameOptionsHTTPHeader) diff --git a/scan/misconfiguration/http_method_override/http_method_override.go b/scan/misconfiguration/http_method_override/http_method_override.go index 9f16eb61..750e372f 100644 --- a/scan/misconfiguration/http_method_override/http_method_override.go +++ b/scan/misconfiguration/http_method_override/http_method_override.go @@ -72,7 +72,7 @@ var methodOverrideQueryParams = []string{ "_httpMethod", } -func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(op *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { var err error var newOperation *operation.Operation @@ -85,7 +85,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* return r, err } - initialAttempt, err := scan.ScanURL(newOperation, &securityScheme) + initialAttempt, err := scan.ScanURL(newOperation, securityScheme) if err != nil { return r, err } @@ -108,7 +108,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* } newOperation.Method = method - methodAttempt, err = scan.ScanURL(newOperation, &securityScheme) + methodAttempt, err = scan.ScanURL(newOperation, securityScheme) if methodAttempt != nil { r.AddScanAttempt(methodAttempt) } @@ -139,7 +139,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* newOperation.Header.Set(header, op.Method) newOperation.Method = newOperationMethod - attempt, err = scan.ScanURL(newOperation, &securityScheme) + attempt, err = scan.ScanURL(newOperation, securityScheme) if attempt != nil { r.AddScanAttempt(attempt) } @@ -161,7 +161,7 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* newOperationQueryValues.Set(queryParam, op.Method) newOperation.URL.RawQuery = newOperationQueryValues.Encode() newOperation.Method = newOperationMethod - attempt, err = scan.ScanURL(newOperation, &securityScheme) + attempt, err = scan.ScanURL(newOperation, securityScheme) if attempt != nil { r.AddScanAttempt(attempt) } @@ -179,12 +179,11 @@ func ScanHandler(op *operation.Operation, securityScheme auth.SecurityScheme) (* } r.AddIssueReport(httpMethodOverrideIssueReport.Fail()) - if _, ok := securityScheme.(*auth.NoAuthSecurityScheme); ok { + if securityScheme.GetType() == auth.None { return r.AddIssueReport(httpMethodOverrideAuthenticationByPassIssueReport.Skip()).End(), nil } - noAuthSecurityScheme := auth.SecurityScheme(auth.NewNoAuthSecurityScheme()) - attempt, err = scan.ScanURL(newOperation, &noAuthSecurityScheme) + attempt, err = scan.ScanURL(newOperation, auth.MustNewNoAuthSecurityScheme()) if err != nil { return r, err } diff --git a/scan/misconfiguration/http_method_override/http_method_override_test.go b/scan/misconfiguration/http_method_override/http_method_override_test.go index 5bb4f767..6ac24d0e 100644 --- a/scan/misconfiguration/http_method_override/http_method_override_test.go +++ b/scan/misconfiguration/http_method_override/http_method_override_test.go @@ -21,22 +21,22 @@ func TestHTTPMethodOverrideScanHandler(t *testing.T) { tests := []struct { name string operation *operation.Operation - securityScheme auth.SecurityScheme + securityScheme *auth.SecurityScheme }{ { name: "MethodNotAllowed", operation: operation.MustNewOperation(http.MethodGet, "http://example.com", nil, nil), - securityScheme: auth.NewNoAuthSecurityScheme(), + securityScheme: auth.MustNewNoAuthSecurityScheme(), }, { name: "MethodOverrideDetected", operation: operation.MustNewOperation(http.MethodPost, "http://example.com/test", nil, nil), - securityScheme: auth.NewNoAuthSecurityScheme(), + securityScheme: auth.MustNewNoAuthSecurityScheme(), }, { name: "AuthenticationBypassDetected", operation: operation.MustNewOperation(http.MethodPost, "http://example.com/test", nil, nil), - securityScheme: auth.MustNewAuthorizationJWTBearerSecurityScheme("securityScheme", &value), + securityScheme: auth.MustNewAuthorizationBearerSecurityScheme("securityScheme", &value), }, } @@ -59,7 +59,7 @@ func TestHTTPMethodOverrideScanHandler_When_Error(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) @@ -78,7 +78,7 @@ func TestHTTPMethodOverrideScanHandler_Passed(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) httpmock.RegisterResponder(http.MethodHead, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) @@ -99,7 +99,7 @@ func TestHTTPMethodOverrideScanHandler_Failed_With_Header(t *testing.T) { httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) httpmock.RegisterResponder(http.MethodHead, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) @@ -125,7 +125,7 @@ func TestHTTPMethodOverrideScanHandler_Failed_With_Query_Parameter(t *testing.T) httpmock.ActivateNonDefault(client.Client) defer httpmock.DeactivateAndReset() - securityScheme := auth.NewNoAuthSecurityScheme() + securityScheme := auth.MustNewNoAuthSecurityScheme() operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) httpmock.RegisterResponder(http.MethodHead, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) @@ -153,7 +153,7 @@ func TestHTTPMethodOverrideScanHandler_Authentication_ByPass_Passed(t *testing.T defer httpmock.DeactivateAndReset() token := jwt.FakeJWT - securityScheme := auth.NewAuthorizationBearerSecurityScheme("securityScheme", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("securityScheme", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) httpmock.RegisterResponder(http.MethodHead, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) @@ -183,7 +183,7 @@ func TestHTTPMethodOverrideScanHandler_Authentication_ByPass_Failed(t *testing.T defer httpmock.DeactivateAndReset() token := jwt.FakeJWT - securityScheme := auth.NewAuthorizationBearerSecurityScheme("securityScheme", &token) + securityScheme := auth.MustNewAuthorizationBearerSecurityScheme("securityScheme", &token) operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(operation.Method, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) httpmock.RegisterResponder(http.MethodHead, operation.URL.String(), httpmock.NewBytesResponder(http.StatusNoContent, nil)) diff --git a/scan/misconfiguration/http_trace/http_trace_method.go b/scan/misconfiguration/http_trace/http_trace_method.go index 897db9a5..470789ab 100644 --- a/scan/misconfiguration/http_trace/http_trace_method.go +++ b/scan/misconfiguration/http_trace/http_trace_method.go @@ -31,7 +31,7 @@ var issue = report.Issue{ }, } -func ScanHandler(operation *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(operation).WithSecurityScheme(securityScheme) r := report.NewScanReport(HTTPTraceScanID, HTTPTraceScanName, operation) @@ -41,7 +41,7 @@ func ScanHandler(operation *operation.Operation, securityScheme auth.SecuritySch } newOperation.Method = http.MethodTrace - attempt, err := scan.ScanURL(newOperation, &securityScheme) + attempt, err := scan.ScanURL(newOperation, securityScheme) r.AddScanAttempt(attempt).End().AddIssueReport(vulnReport.WithBooleanStatus(err != nil || attempt.Response.GetStatusCode() != http.StatusOK)) return r, nil diff --git a/scan/misconfiguration/http_trace/http_trace_method_test.go b/scan/misconfiguration/http_trace/http_trace_method_test.go index 3fe1146c..d496bd83 100644 --- a/scan/misconfiguration/http_trace/http_trace_method_test.go +++ b/scan/misconfiguration/http_trace/http_trace_method_test.go @@ -21,7 +21,7 @@ func TestHTTPTraceMethodScanHandler_Passed_WhenNotOKResponse(t *testing.T) { operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(http.MethodTrace, operation.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) - report, err := httptrace.ScanHandler(operation, auth.NewNoAuthSecurityScheme()) + report, err := httptrace.ScanHandler(operation, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Equal(t, 1, httpmock.GetTotalCallCount()) @@ -36,7 +36,7 @@ func TestHTTPTraceMethodScanHandler_Failed_WhenTraceIsEnabled(t *testing.T) { operation := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(http.MethodTrace, operation.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) - report, err := httptrace.ScanHandler(operation, auth.NewNoAuthSecurityScheme()) + report, err := httptrace.ScanHandler(operation, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Equal(t, 1, httpmock.GetTotalCallCount()) diff --git a/scan/misconfiguration/http_track/http_track_method.go b/scan/misconfiguration/http_track/http_track_method.go index c1c079ea..d10563bf 100644 --- a/scan/misconfiguration/http_track/http_track_method.go +++ b/scan/misconfiguration/http_track/http_track_method.go @@ -33,7 +33,7 @@ var issue = report.Issue{ const TrackMethod = "TRACK" -func ScanHandler(operation *operation.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) { +func ScanHandler(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { vulnReport := report.NewIssueReport(issue).WithOperation(operation).WithSecurityScheme(securityScheme) r := report.NewScanReport(HTTPTrackScanID, HTTPTrackScanName, operation) @@ -43,7 +43,7 @@ func ScanHandler(operation *operation.Operation, securityScheme auth.SecuritySch } newOperation.Method = TrackMethod - attempt, err := scan.ScanURL(newOperation, &securityScheme) + attempt, err := scan.ScanURL(newOperation, securityScheme) r.AddScanAttempt(attempt).End().AddIssueReport(vulnReport.WithBooleanStatus(err != nil || attempt.Response.GetStatusCode() != http.StatusOK)) return r, nil diff --git a/scan/misconfiguration/http_track/http_track_method_test.go b/scan/misconfiguration/http_track/http_track_method_test.go index c2b11f63..dfaa8dce 100644 --- a/scan/misconfiguration/http_track/http_track_method_test.go +++ b/scan/misconfiguration/http_track/http_track_method_test.go @@ -21,7 +21,7 @@ func TestHTTPTrackMethodScanHandler_Passed_WhenNotOKResponse(t *testing.T) { op := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(httptrack.TrackMethod, op.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil)) - report, err := httptrack.ScanHandler(op, auth.NewNoAuthSecurityScheme()) + report, err := httptrack.ScanHandler(op, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Equal(t, 1, httpmock.GetTotalCallCount()) @@ -36,7 +36,7 @@ func TestHTTPTrackMethodScanHandler_Failed_WhenTrackIsEnabled(t *testing.T) { op := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, client) httpmock.RegisterResponder(httptrack.TrackMethod, op.URL.String(), httpmock.NewBytesResponder(http.StatusOK, nil)) - report, err := httptrack.ScanHandler(op, auth.NewNoAuthSecurityScheme()) + report, err := httptrack.ScanHandler(op, auth.MustNewNoAuthSecurityScheme()) require.NoError(t, err) assert.Equal(t, 1, httpmock.GetTotalCallCount()) diff --git a/scan/operation_scan.go b/scan/operation_scan.go index 4425acb3..656cdfbc 100644 --- a/scan/operation_scan.go +++ b/scan/operation_scan.go @@ -6,7 +6,7 @@ import ( "github.com/cerberauth/vulnapi/report" ) -type OperationScanHandlerFunc func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) +type OperationScanHandlerFunc func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) type OperationScanHandler struct { ID string diff --git a/scan/operation_scan_test.go b/scan/operation_scan_test.go index 55138f72..5cc70875 100644 --- a/scan/operation_scan_test.go +++ b/scan/operation_scan_test.go @@ -11,7 +11,7 @@ import ( ) func TestNewOperationScanHandler(t *testing.T) { - handlerFunc := func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + handlerFunc := func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return &report.ScanReport{ID: "test-report"}, nil } handlerID := "test-handler" diff --git a/scan/scan_test.go b/scan/scan_test.go index 219915dd..39565cc8 100644 --- a/scan/scan_test.go +++ b/scan/scan_test.go @@ -74,7 +74,7 @@ func TestScanGetOperationsScans(t *testing.T) { op := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) operations := operation.Operations{op} s, _ := scan.NewScan(operations, nil) - s.AddOperationScanHandler(scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + s.AddOperationScanHandler(scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return nil, nil })) @@ -99,7 +99,7 @@ func TestScanExecuteWithHandler(t *testing.T) { op := operation.MustNewOperation(http.MethodGet, "http://localhost:8080/", nil, nil) operations := operation.Operations{op} s, _ := scan.NewScan(operations, nil) - handler := scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + handler := scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return &report.ScanReport{ID: "test-report"}, nil }) s.AddOperationScanHandler(handler) @@ -118,7 +118,7 @@ func TestScanExecuteWithIncludeScans(t *testing.T) { s, _ := scan.NewScan(operations, &scan.ScanOptions{ IncludeScans: []string{"test-handler"}, }) - handler := scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + handler := scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return &report.ScanReport{ID: "test-report"}, nil }) s.AddOperationScanHandler(handler) @@ -137,7 +137,7 @@ func TestScanExecuteWithEmptyStringIncludeScans(t *testing.T) { s, _ := scan.NewScan(operations, &scan.ScanOptions{ IncludeScans: []string{""}, }) - handler := scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + handler := scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return &report.ScanReport{ID: "test-report"}, nil }) s.AddOperationScanHandler(handler) @@ -156,7 +156,7 @@ func TestScanExecuteWithMatchStringIncludeScans(t *testing.T) { s, _ := scan.NewScan(operations, &scan.ScanOptions{ IncludeScans: []string{"category.*"}, }) - handler := scan.NewOperationScanHandler("category.test-handler", func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + handler := scan.NewOperationScanHandler("category.test-handler", func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return &report.ScanReport{ID: "test-report"}, nil }) s.AddOperationScanHandler(handler) @@ -175,7 +175,7 @@ func TestScanExecuteWithWrongMatchStringIncludeScans(t *testing.T) { s, _ := scan.NewScan(operations, &scan.ScanOptions{ IncludeScans: []string{"wrong-category.*"}, }) - handler := scan.NewOperationScanHandler("category.test-handler", func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + handler := scan.NewOperationScanHandler("category.test-handler", func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return &report.ScanReport{ID: "test-report"}, nil }) s.AddOperationScanHandler(handler) @@ -193,7 +193,7 @@ func TestScanExecuteWithExcludeScans(t *testing.T) { s, _ := scan.NewScan(operations, &scan.ScanOptions{ ExcludeScans: []string{"test-handler"}, }) - handler := scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + handler := scan.NewOperationScanHandler("test-handler", func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return &report.ScanReport{ID: "test-report"}, nil }) s.AddOperationScanHandler(handler) @@ -211,7 +211,7 @@ func TestScanExecuteWithMatchStringExcludeScans(t *testing.T) { s, _ := scan.NewScan(operations, &scan.ScanOptions{ ExcludeScans: []string{"category.*"}, }) - handler := scan.NewOperationScanHandler("category.test-handler", func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + handler := scan.NewOperationScanHandler("category.test-handler", func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return &report.ScanReport{ID: "test-report"}, nil }) s.AddOperationScanHandler(handler) @@ -229,7 +229,7 @@ func TestScanExecuteWithWrongMatchStringExcludeScans(t *testing.T) { s, _ := scan.NewScan(operations, &scan.ScanOptions{ ExcludeScans: []string{"wrong-category.*"}, }) - handler := scan.NewOperationScanHandler("category.test-handler", func(operation *operation.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) { + handler := scan.NewOperationScanHandler("category.test-handler", func(operation *operation.Operation, securityScheme *auth.SecurityScheme) (*report.ScanReport, error) { return &report.ScanReport{ID: "test-report"}, nil }) s.AddOperationScanHandler(handler) diff --git a/scenario/graphql.go b/scenario/graphql.go index c895208a..61ff0875 100644 --- a/scenario/graphql.go +++ b/scenario/graphql.go @@ -21,11 +21,11 @@ func NewGraphQLScan(url string, client *request.Client, opts *scan.ScanOptions) return nil, err } - var securitySchemes []auth.SecurityScheme + var securitySchemes []*auth.SecurityScheme if securityScheme != nil { - securitySchemes = []auth.SecurityScheme{securityScheme} + securitySchemes = []*auth.SecurityScheme{securityScheme} } else { - securitySchemes = []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()} + securitySchemes = []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()} } url = addDefaultProtocolWhenMissing(url) diff --git a/scenario/graphql_test.go b/scenario/graphql_test.go index a905a4a4..25e5d1af 100644 --- a/scenario/graphql_test.go +++ b/scenario/graphql_test.go @@ -24,7 +24,7 @@ func TestNewGraphQLScan(t *testing.T) { require.NoError(t, err) assert.Equal(t, server.URL, s.Operations[0].URL.String()) assert.Equal(t, http.MethodPost, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()}, s.Operations[0].SecuritySchemes) } func TestNewGraphQLScanWithoutURLProto(t *testing.T) { @@ -39,7 +39,7 @@ func TestNewGraphQLScanWithoutURLProto(t *testing.T) { require.NoError(t, err) assert.Equal(t, "https://"+url, s.Operations[0].URL.String()) assert.Equal(t, http.MethodPost, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()}, s.Operations[0].SecuritySchemes) } func TestNewGraphQLScanWhenNotReachable(t *testing.T) { @@ -67,7 +67,7 @@ func TestNewGraphQLScanWithUpperCaseAuthorizationHeader(t *testing.T) { require.NoError(t, err) assert.Equal(t, server.URL, s.Operations[0].URL.String()) assert.Equal(t, http.MethodPost, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) } func TestNewGraphQLScanWithUpperCaseAuthorizationAndLowerCaseBearerHeader(t *testing.T) { @@ -88,7 +88,7 @@ func TestNewGraphQLScanWithUpperCaseAuthorizationAndLowerCaseBearerHeader(t *tes require.NoError(t, err) assert.Equal(t, server.URL, s.Operations[0].URL.String()) assert.Equal(t, http.MethodPost, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) } func TestNewGraphQLScanWithLowerCaseAuthorizationHeader(t *testing.T) { @@ -109,5 +109,5 @@ func TestNewGraphQLScanWithLowerCaseAuthorizationHeader(t *testing.T) { require.NoError(t, err) assert.Equal(t, server.URL, s.Operations[0].URL.String()) assert.Equal(t, http.MethodPost, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) } diff --git a/scenario/openapi.go b/scenario/openapi.go index c1160e90..8ed1ae99 100644 --- a/scenario/openapi.go +++ b/scenario/openapi.go @@ -1,14 +1,13 @@ package scenario import ( - "github.com/cerberauth/vulnapi/internal/auth" "github.com/cerberauth/vulnapi/internal/request" "github.com/cerberauth/vulnapi/openapi" "github.com/cerberauth/vulnapi/report" "github.com/cerberauth/vulnapi/scan" ) -func NewOpenAPIScan(openapi *openapi.OpenAPI, securitySchemesValues *auth.SecuritySchemeValues, client *request.Client, opts *scan.ScanOptions) (*scan.Scan, error) { +func NewOpenAPIScan(openapi *openapi.OpenAPI, securitySchemesValues *openapi.SecuritySchemeValues, client *request.Client, opts *scan.ScanOptions) (*scan.Scan, error) { if client == nil { client = request.GetDefaultClient() } diff --git a/scenario/openapi_test.go b/scenario/openapi_test.go index 5a94039f..df5364d8 100644 --- a/scenario/openapi_test.go +++ b/scenario/openapi_test.go @@ -45,7 +45,7 @@ func TestMain(m *testing.M) { func TestNewOpenAPIScanWithHttpBearer(t *testing.T) { token := "token" doc, _ := openapi.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer.openapi.json") - securitySchemeValues := auth.NewSecuritySchemeValues(map[string]interface{}{ + securitySchemeValues := openapi.NewSecuritySchemeValues(map[string]interface{}{ "bearer_auth": &token, }) @@ -55,14 +55,14 @@ func TestNewOpenAPIScanWithHttpBearer(t *testing.T) { assert.Equal(t, 1, len(s.Operations)) assert.Equal(t, "http://localhost:8080/", s.Operations[0].URL.String()) assert.Equal(t, http.MethodGet, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("bearer_auth", &token)}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewAuthorizationBearerSecurityScheme("bearer_auth", &token)}, s.Operations[0].SecuritySchemes) } func TestNewOpenAPIScanWithJWTHttpBearer(t *testing.T) { token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U" doc, _ := openapi.LoadOpenAPI(context.Background(), "../test/stub/simple_http_bearer_jwt.openapi.json") - expectedSecurityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("bearer_auth", &token) - securitySchemeValues := auth.NewSecuritySchemeValues(map[string]interface{}{ + expectedSecurityScheme := auth.MustNewAuthorizationBearerSecurityScheme("bearer_auth", &token) + securitySchemeValues := openapi.NewSecuritySchemeValues(map[string]interface{}{ "bearer_auth": &token, }) @@ -72,7 +72,7 @@ func TestNewOpenAPIScanWithJWTHttpBearer(t *testing.T) { assert.Equal(t, 1, len(s.Operations)) assert.Equal(t, "http://localhost:8080/", s.Operations[0].URL.String()) assert.Equal(t, http.MethodGet, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{expectedSecurityScheme}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{expectedSecurityScheme}, s.Operations[0].SecuritySchemes) } func TestNewOpenAPIScanWithMultipleOperations(t *testing.T) { @@ -80,8 +80,8 @@ func TestNewOpenAPIScanWithMultipleOperations(t *testing.T) { token := "token" doc, _ := openapi.LoadOpenAPI(context.Background(), "../test/stub/basic_http_bearer.openapi.json") - securitySchemes := []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("bearer_auth", &token)} - securitySchemeValues := auth.NewSecuritySchemeValues(map[string]interface{}{ + securitySchemes := []*auth.SecurityScheme{auth.MustNewAuthorizationBearerSecurityScheme("bearer_auth", &token)} + securitySchemeValues := openapi.NewSecuritySchemeValues(map[string]interface{}{ "bearer_auth": &token, }) @@ -99,8 +99,8 @@ func TestNewOpenAPIScanWithoutParamsExample(t *testing.T) { token := "token" doc, _ := openapi.LoadOpenAPI(context.Background(), "../test/stub/basic_http_bearer.openapi.json") - securitySchemes := []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("bearer_auth", &token)} - securitySchemeValues := auth.NewSecuritySchemeValues(map[string]interface{}{ + securitySchemes := []*auth.SecurityScheme{auth.MustNewAuthorizationBearerSecurityScheme("bearer_auth", &token)} + securitySchemeValues := openapi.NewSecuritySchemeValues(map[string]interface{}{ "bearer_auth": &token, }) diff --git a/scenario/url.go b/scenario/url.go index 16032b5d..79b61cb4 100644 --- a/scenario/url.go +++ b/scenario/url.go @@ -22,11 +22,11 @@ func NewURLScan(method string, url string, data string, client *request.Client, return nil, err } - var securitySchemes []auth.SecurityScheme + var securitySchemes []*auth.SecurityScheme if securityScheme != nil { - securitySchemes = []auth.SecurityScheme{securityScheme} + securitySchemes = []*auth.SecurityScheme{securityScheme} } else { - securitySchemes = []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()} + securitySchemes = []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()} } body := bytes.NewBuffer([]byte(data)) diff --git a/scenario/url_test.go b/scenario/url_test.go index 39799ddd..3ab86c4f 100644 --- a/scenario/url_test.go +++ b/scenario/url_test.go @@ -23,7 +23,7 @@ func TestNewURLScan(t *testing.T) { require.NoError(t, err) assert.Equal(t, server.URL, s.Operations[0].URL.String()) assert.Equal(t, http.MethodGet, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewNoAuthSecurityScheme()}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewNoAuthSecurityScheme()}, s.Operations[0].SecuritySchemes) } func TestNewURLScanWithUpperCaseAuthorizationHeader(t *testing.T) { @@ -44,7 +44,7 @@ func TestNewURLScanWithUpperCaseAuthorizationHeader(t *testing.T) { require.NoError(t, err) assert.Equal(t, server.URL, s.Operations[0].URL.String()) assert.Equal(t, http.MethodGet, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) } func TestNewURLScanWithUpperCaseAuthorizationAndLowerCaseBearerHeader(t *testing.T) { @@ -65,7 +65,7 @@ func TestNewURLScanWithUpperCaseAuthorizationAndLowerCaseBearerHeader(t *testing require.NoError(t, err) assert.Equal(t, server.URL, s.Operations[0].URL.String()) assert.Equal(t, http.MethodGet, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) } func TestNewURLScanWithLowerCaseAuthorizationHeader(t *testing.T) { @@ -86,5 +86,5 @@ func TestNewURLScanWithLowerCaseAuthorizationHeader(t *testing.T) { require.NoError(t, err) assert.Equal(t, server.URL, s.Operations[0].URL.String()) assert.Equal(t, http.MethodGet, s.Operations[0].Method) - assert.Equal(t, []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) + assert.Equal(t, []*auth.SecurityScheme{auth.MustNewAuthorizationBearerSecurityScheme("default", &token)}, s.Operations[0].SecuritySchemes) } diff --git a/scenario/utils.go b/scenario/utils.go index 074982b2..a251ff50 100644 --- a/scenario/utils.go +++ b/scenario/utils.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/cerberauth/vulnapi/internal/auth" - "github.com/cerberauth/vulnapi/jwt" ) const bearerPrefix = auth.BearerPrefix + " " @@ -36,7 +35,7 @@ func getBearerToken(authHeader string) string { return "" } -func detectSecurityScheme(header http.Header) (auth.SecurityScheme, error) { +func detectSecurityScheme(header http.Header) (*auth.SecurityScheme, error) { authHeader := detectAuthorizationHeader(header) if authHeader == "" { return nil, nil @@ -47,12 +46,7 @@ func detectSecurityScheme(header http.Header) (auth.SecurityScheme, error) { return nil, fmt.Errorf("empty authorization header") } - _, err := jwt.NewJWTWriter(token) - if err != nil { - return auth.NewAuthorizationBearerSecurityScheme("default", &token), nil - } else { - return auth.NewAuthorizationJWTBearerSecurityScheme("default", &token) - } + return auth.NewAuthorizationBearerSecurityScheme("default", &token) } func addDefaultProtocolWhenMissing(url string) string {