diff --git a/README.md b/README.md index 5b6aed6..f45200c 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,10 @@ You can test the scanner against example [vulnerability challenges](https://gith Run `vulnapi -h` or `vulnapi help`. +This projects support two way to scan APIs: +- CURL Like scan +- OpenAPI based Scan (experimental) + ## Disclaimer This scanner is provided for educational and informational purposes only. It should not be used for malicious purposes or to attack any system without proper authorization. Always respect the security and privacy of others. diff --git a/cmd/scan/curl.go b/cmd/scan/curl.go new file mode 100644 index 0000000..3d9159e --- /dev/null +++ b/cmd/scan/curl.go @@ -0,0 +1,70 @@ +package scan + +import ( + "log" + "net/http" + "strings" + + "github.com/cerberauth/vulnapi/scan" + "github.com/spf13/cobra" +) + +var ( + url string + method string + headers []string + cookies []string +) + +func NewCURLScanCmd() (scanCmd *cobra.Command) { + scanCmd = &cobra.Command{ + Use: "curl [URL]", + Short: "URL Scan in CURL style", + Args: cobra.ExactArgs(1), + FParseErrWhitelist: cobra.FParseErrWhitelist{ + UnknownFlags: true, + }, + Run: func(cmd *cobra.Command, args []string) { + url = args[0] + + httpHeaders := http.Header{} + for _, h := range headers { + parts := strings.SplitN(h, ":", 2) + httpHeaders.Add(parts[0], strings.TrimLeft(parts[1], " ")) + } + + var httpCookies []http.Cookie + for _, c := range cookies { + parts := strings.SplitN(c, ":", 2) + httpCookies = append(httpCookies, http.Cookie{ + Name: parts[0], + Value: parts[1], + }) + } + + scan, err := scan.NewURLScan(method, url, &httpHeaders, httpCookies, nil) + if err != nil { + log.Fatal(err) + } + + rpr, _, err := scan.WithAllVulnsScans().Execute() + if err != nil { + log.Fatal(err) + } + + if !rpr.HasVulnerability() { + log.Println("Congratulations! No vulnerability has been discovered!") + } + + for _, r := range rpr.GetVulnerabilityReports() { + log.Fatalln(r) + } + }, + } + + scanCmd.PersistentFlags().StringVarP(&method, "request", "X", "GET", "Specify request method to use") + scanCmd.PersistentFlags().StringArrayVarP(&headers, "header", "H", nil, "Pass custom header(s) to target API") + scanCmd.PersistentFlags().StringArrayVarP(&cookies, "cookie", "b", nil, "Send cookies from string") + + return scanCmd +} diff --git a/cmd/scan/openapi.go b/cmd/scan/openapi.go new file mode 100644 index 0000000..8451fca --- /dev/null +++ b/cmd/scan/openapi.go @@ -0,0 +1,62 @@ +package scan + +import ( + "bufio" + "log" + "os" + + "github.com/cerberauth/vulnapi/scan" + "github.com/spf13/cobra" +) + +func isStdinOpen() bool { + stat, _ := os.Stdin.Stat() + return (stat.Mode() & os.ModeCharDevice) == 0 +} + +func readStdin() *string { + scanner := bufio.NewScanner(os.Stdin) + if scanner.Scan() { + t := scanner.Text() + return &t + } + + return nil +} + +func NewOpenAPIScanCmd() (scanCmd *cobra.Command) { + scanCmd = &cobra.Command{ + Use: "openapi [OpenAPIPAth]", + Short: "Full OpenAPI operations scan", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + openapiUrlOrPath := args[0] + + var validToken *string + if isStdinOpen() { + stdin := readStdin() + validToken = stdin + } + + scan, err := scan.NewOpenAPIScan(openapiUrlOrPath, validToken, nil) + if err != nil { + log.Fatal(err) + } + + rpr, _, err := scan.WithAllVulnsScans().Execute() + if err != nil { + log.Fatal(err) + } + + if !rpr.HasVulnerability() { + log.Println("Congratulations! No vulnerability has been discovered!") + } + + for _, r := range rpr.GetVulnerabilityReports() { + log.Fatalln(r) + } + }, + } + + return scanCmd +} diff --git a/cmd/scan/root.go b/cmd/scan/root.go index 68107e3..71a72e7 100644 --- a/cmd/scan/root.go +++ b/cmd/scan/root.go @@ -1,54 +1,16 @@ package scan import ( - "bufio" - "fmt" - "log" - - "github.com/cerberauth/vulnapi/scan" "github.com/spf13/cobra" ) -var ( - url string - jwt string -) - func NewScanCmd() (scanCmd *cobra.Command) { scanCmd = &cobra.Command{ - Use: "scan [URL]", + Use: "scan [type]", Short: "API Scan", - // Full API scan coming (not only one URL) - Run: func(cmd *cobra.Command, args []string) { - if len(args) > 0 { - url = args[0] - } - - if jwt == "" { - stdin, err := bufio.NewReader(cmd.InOrStdin()).ReadString('\n') - if err != nil { - log.Fatal(fmt.Errorf("failed process input: %v", err)) - } - jwt = stdin - } - - rpr, _, err := scan.NewScanner(url, &jwt).WithAllScans().Execute() - if err != nil { - log.Fatal(err) - } - - if !rpr.HasVulnerability() { - println("Congratulations! No vulnerability has been discovered!") - } - - for _, r := range rpr.GetVulnerabilityReports() { - log.Println(r) - } - }, } - - scanCmd.PersistentFlags().StringVarP(&url, "url", "u", "", "URL") - scanCmd.PersistentFlags().StringVarP(&jwt, "jwt", "j", "", "Valid JWT") + scanCmd.AddCommand(NewCURLScanCmd()) + scanCmd.AddCommand(NewOpenAPIScanCmd()) return scanCmd } diff --git a/go.mod b/go.mod index c48ea4f..7650d7b 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,26 @@ module github.com/cerberauth/vulnapi go 1.22 require ( + github.com/brianvoe/gofakeit/v6 v6.28.0 + github.com/getkin/kin-openapi v0.120.0 github.com/golang-jwt/jwt/v5 v5.2.0 + github.com/jarcoal/httpmock v1.3.1 github.com/spf13/cobra v1.8.0 + github.com/std-uritemplate/std-uritemplate/go v0.0.52 + github.com/stretchr/testify v1.8.4 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-openapi/jsonpointer v0.19.6 // indirect + github.com/go-openapi/swag v0.22.4 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/invopop/yaml v0.2.0 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect + github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f1968ca..d9b81e8 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,66 @@ +github.com/brianvoe/gofakeit/v6 v6.28.0 h1:Xib46XXuQfmlLS2EXRuJpqcw8St6qSZz75OUo0tgAW4= +github.com/brianvoe/gofakeit/v6 v6.28.0/go.mod h1:Xj58BMSnFqcn/fAQeSK+/PLtC5kSb7FJIq4JyGa8vEs= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/getkin/kin-openapi v0.120.0 h1:MqJcNJFrMDFNc07iwE8iFC5eT2k/NPUFDIpNeiZv8Jg= +github.com/getkin/kin-openapi v0.120.0/go.mod h1:PCWw/lfBrJY4HcdqE3jj+QFkaFK8ABoqo7PvqVhXXqw= +github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= +github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= +github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= +github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU= +github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY= +github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= +github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= +github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/std-uritemplate/std-uritemplate/go v0.0.52 h1:2r8rdugq0WZlRDkLlwH/9sKZG2iYXvFCEcKFIKmfSQQ= +github.com/std-uritemplate/std-uritemplate/go v0.0.52/go.mod h1:CLZ1543WRCuUQQjK0BvPM4QrG2toY8xNZUm8Vbt7vTc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..0a1f8d3 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,29 @@ +package auth + +import "net/http" + +type Type string + +const ( + HttpType Type = "http" + OAuth2 Type = "oauth2" + OpenIdConnect Type = "openIdConnect" + ApiKey Type = "apiKey" +) + +type SecurityScheme interface { + GetHeaders() http.Header + GetCookies() []*http.Cookie + GetValidValue() interface{} + SetAttackValue(v interface{}) + GetAttackValue() interface{} +} + +type Operation struct { + Url string + Method string + Headers *http.Header + Cookies []http.Cookie + + SecuritySchemes []SecurityScheme +} diff --git a/internal/auth/bearer.go b/internal/auth/bearer.go new file mode 100644 index 0000000..4548a2b --- /dev/null +++ b/internal/auth/bearer.go @@ -0,0 +1,53 @@ +package auth + +import ( + "fmt" + "net/http" +) + +type BearerSecurityScheme struct { + Type Type + Scheme SchemeName + In SchemeIn + Name string + ValidValue *string + AttackValue string +} + +var _ SecurityScheme = (*BearerSecurityScheme)(nil) + +func NewAuthorizationBearerSecurityScheme(name string, value *string) *BearerSecurityScheme { + return &BearerSecurityScheme{ + Type: HttpType, + Scheme: BearerScheme, + In: InHeader, + Name: name, + ValidValue: value, + AttackValue: "", + } +} + +func (ss *BearerSecurityScheme) GetHeaders() http.Header { + header := http.Header{} + if ss.ValidValue != nil { + header.Set(AuthorizationHeader, fmt.Sprintf("%s %s", BearerPrefix, *ss.ValidValue)) + } + + return header +} + +func (ss *BearerSecurityScheme) GetCookies() []*http.Cookie { + return []*http.Cookie{} +} + +func (ss *BearerSecurityScheme) GetValidValue() interface{} { + return *ss.ValidValue +} + +func (ss *BearerSecurityScheme) SetAttackValue(v interface{}) { + ss.AttackValue = v.(string) +} + +func (ss *BearerSecurityScheme) GetAttackValue() interface{} { + return ss.AttackValue +} diff --git a/internal/auth/headers.go b/internal/auth/headers.go new file mode 100644 index 0000000..993733f --- /dev/null +++ b/internal/auth/headers.go @@ -0,0 +1,4 @@ +package auth + +const AuthorizationHeader = "Authorization" +const BearerPrefix = "Bearer" diff --git a/internal/auth/scheme.go b/internal/auth/scheme.go new file mode 100644 index 0000000..1ab0f14 --- /dev/null +++ b/internal/auth/scheme.go @@ -0,0 +1,43 @@ +package auth + +import "errors" + +type SchemeName string + +// Values are registred in the IANA Authentication Scheme registry +// https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml +const ( + BasicScheme SchemeName = "basic" + BearerScheme SchemeName = "bearer" + DigestScheme SchemeName = "digest" + OAuthScheme SchemeName = "oauth" + PrivateToken SchemeName = "privateToken" + + NoneScheme SchemeName = "none" +) + +func (s *SchemeName) String() string { + return string(*s) +} + +func (s *SchemeName) Set(v string) error { + switch v { + case "basic", "bearer", "digest", "oauth", "privateToken": + *s = SchemeName(v) + return nil + default: + return errors.New(`must be one of "basic", "bearer", "digest", "oauth", "privateToken"`) + } +} + +func (e *SchemeName) Type() string { + return "scheme" +} + +type SchemeIn string + +const ( + InHeader SchemeIn = "header" + InCookie SchemeIn = "cookie" + InUnknown SchemeIn = "unknown" +) diff --git a/internal/request/request.go b/internal/request/request.go index db40481..b3cd7de 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -1,11 +1,12 @@ package request import ( - "fmt" "net/http" + + "github.com/cerberauth/vulnapi/internal/auth" ) -func prepareVulnAPIRequest(method string, url string) (*http.Request, error) { +func NewRequest(method string, url string) (*http.Request, error) { req, err := http.NewRequest(method, url, nil) if err != nil { return nil, err @@ -16,20 +17,22 @@ func prepareVulnAPIRequest(method string, url string) (*http.Request, error) { return req, nil } -func SendRequestWithBearerAuth(url string, token string) (*http.Request, *http.Response, error) { - req, err := prepareVulnAPIRequest("GET", url) - if err != nil { - return req, nil, err - } +func DoRequest(client *http.Client, req *http.Request, ss auth.SecurityScheme) (*http.Request, *http.Response, error) { + if ss != nil { + for _, c := range ss.GetCookies() { + req.AddCookie(c) + } - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) + for n, h := range ss.GetHeaders() { + req.Header.Add(n, h[0]) + } + } - client := &http.Client{} - resp, err := client.Do(req) + res, err := client.Do(req) if err != nil { - return req, resp, err + return req, res, err } - defer resp.Body.Close() + defer res.Body.Close() - return req, resp, nil + return req, res, nil } diff --git a/internal/request/request_test.go b/internal/request/request_test.go new file mode 100644 index 0000000..32ee094 --- /dev/null +++ b/internal/request/request_test.go @@ -0,0 +1,150 @@ +package request + +import ( + "net/http" + "net/url" + "testing" + + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type SecuritySchemeMock struct { + Cookies []*http.Cookie + Headers http.Header + ValidValue interface{} + AttackValue interface{} +} + +func NewSecuritySchemeMock() *SecuritySchemeMock { + return &SecuritySchemeMock{ + Cookies: []*http.Cookie{}, + Headers: http.Header{}, + ValidValue: nil, + AttackValue: nil, + } +} + +func (ss *SecuritySchemeMock) GetCookies() []*http.Cookie { + return ss.Cookies +} + +func (ss *SecuritySchemeMock) GetHeaders() http.Header { + return ss.Headers +} + +func (ss *SecuritySchemeMock) GetValidValue() interface{} { + return ss.ValidValue +} + +func (ss *SecuritySchemeMock) SetAttackValue(v interface{}) { + ss.AttackValue = v +} + +func (ss *SecuritySchemeMock) GetAttackValue() interface{} { + return ss.AttackValue +} + +var reqMethod = "GET" +var reqUrl = "http://localhost:8080" + +func setupSuite(tb testing.TB) func(tb testing.TB) { + httpmock.Activate() + httpmock.RegisterResponder(reqMethod, reqUrl, httpmock.NewBytesResponder(204, nil)) + + return func(tb testing.TB) { + defer httpmock.DeactivateAndReset() + } +} + +func TestNewRequestUserMethodAndUrl(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + req, err := NewRequest(reqMethod, reqUrl) + require.NoError(t, err) + assert.Equal(t, reqMethod, req.Method) + assert.Equal(t, &url.URL{Scheme: "http", Host: "localhost:8080"}, req.URL) + + reqMethod2 := "PUT" + + req2, err2 := NewRequest(reqMethod2, reqUrl) + require.NoError(t, err2) + assert.Equal(t, reqMethod2, req2.Method) + assert.Equal(t, &url.URL{Scheme: "http", Host: "localhost:8080"}, req2.URL) +} + +func TestNewRequestAddUserAgent(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + req, err := NewRequest(reqMethod, reqUrl) + require.NoError(t, err) + assert.Equal(t, "vulnapi/0.1", req.UserAgent()) +} + +func TestNewRequestWithWrongUrl(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + _, err := NewRequest(reqMethod, "://localhost:8080") + require.Error(t, err) +} + +func TestDoRequestWithoutSecurityScheme(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + client := &http.Client{} + req, _ := NewRequest(reqMethod, reqUrl) + + req, res, err := DoRequest(client, req, nil) + require.NoError(t, err) + assert.Equal(t, 0, len(req.Cookies())) + assert.Equal(t, req, req) + assert.NotNil(t, res) + + assert.Equal(t, 1, httpmock.GetTotalCallCount()) +} + +func TestDoRequestWithSecuritySchemeAndCookies(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + ss := NewSecuritySchemeMock() + ss.Cookies = append(ss.Cookies, &http.Cookie{ + Name: "cookie", + Value: "cookie value", + }) + client := &http.Client{} + req, _ := NewRequest(reqMethod, reqUrl) + + req, res, err := DoRequest(client, req, ss) + require.NoError(t, err) + assert.Equal(t, 1, len(req.Cookies())) + assert.Equal(t, ss.Cookies[0].Name, req.Cookies()[0].Name) + assert.Equal(t, ss.Cookies[0].Value, req.Cookies()[0].Value) + assert.NotNil(t, res) + + assert.Equal(t, 1, httpmock.GetTotalCallCount()) +} + +func TestDoRequestWithSecuritySchemeAndHeaders(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + ss := NewSecuritySchemeMock() + ss.Headers = http.Header{} + ss.Headers.Add("header1", "value1") + client := &http.Client{} + req, _ := NewRequest(reqMethod, reqUrl) + + req, res, err := DoRequest(client, req, ss) + require.NoError(t, err) + assert.Equal(t, 0, len(req.Cookies())) + assert.Equal(t, "value1", req.Header.Get("header1")) + assert.NotNil(t, res) + + assert.Equal(t, 1, httpmock.GetTotalCallCount()) +} diff --git a/internal/rest_api/loader.go b/internal/rest_api/loader.go new file mode 100644 index 0000000..19ee2c7 --- /dev/null +++ b/internal/rest_api/loader.go @@ -0,0 +1,34 @@ +package restapi + +import ( + "errors" + "fmt" + "net/url" + "os" + "regexp" + + "github.com/getkin/kin-openapi/openapi3" +) + +var urlPatternRe = regexp.MustCompile(`^(http:\/\/www\.|https:\/\/www\.|http:\/\/|https:\/\/|\/|\/\/)?[A-z0-9_-]*?[:]?[A-z0-9_-]*?[@]?[A-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,5}(:[0-9]{1,5})?(\/.*)?$`) + +func LoadOpenAPI(urlOrPath string) (*openapi3.T, error) { + if urlOrPath == "" { + return nil, errors.New("url or path must not be empty") + } + + if urlPatternRe.MatchString(urlOrPath) { + uri, urlerr := url.Parse(urlOrPath) + if urlerr != nil { + return nil, urlerr + } + + return openapi3.NewLoader().LoadFromURI(uri) + } + + if _, err := os.Stat(urlOrPath); err != nil { + return nil, fmt.Errorf("the openapi file has not been found on %s", urlOrPath) + } + + return openapi3.NewLoader().LoadFromFile(urlOrPath) +} diff --git a/internal/rest_api/request.go b/internal/rest_api/request.go new file mode 100644 index 0000000..402b09e --- /dev/null +++ b/internal/rest_api/request.go @@ -0,0 +1,36 @@ +package restapi + +import ( + "fmt" + "net/http" + + "github.com/cerberauth/vulnapi/internal/auth" + "github.com/cerberauth/vulnapi/internal/request" + "github.com/cerberauth/vulnapi/report" +) + +func ScanRestAPI(url string, ss auth.SecurityScheme) *report.VulnerabilityScanAttempt { + var req *http.Request + var res *http.Response + var err error = nil + + client := &http.Client{} + req, err = request.NewRequest("GET", url) + if err != nil { + err = fmt.Errorf("request with url %s has an unexpected error", err) + } else { + req, res, err = request.DoRequest(client, req, ss) + } + + if err != nil { + err = fmt.Errorf("request with url %s has an unexpected error", err) + } else if res.StatusCode < 200 && res.StatusCode >= 300 { + err = fmt.Errorf("unexpected status code %d during test request", res.StatusCode) + } + + return &report.VulnerabilityScanAttempt{ + Request: req, + Response: res, + Err: err, + } +} diff --git a/report/vuln.go b/report/vuln.go index 3ae97eb..369f7d7 100644 --- a/report/vuln.go +++ b/report/vuln.go @@ -6,7 +6,7 @@ type VulnerabilityReport struct { SeverityLevel float64 // https://nvd.nist.gov/vuln-metrics/cvss Name string Description string - Url *string + Url string } func (vr *VulnerabilityReport) IsLowRiskSeverity() bool { @@ -22,7 +22,7 @@ func (vr *VulnerabilityReport) IsHighRiskSeverity() bool { } func (vr *VulnerabilityReport) String() string { - return fmt.Sprintf("[%s] %s: %s", severyLevelString(vr.SeverityLevel), vr.Name, vr.Description) + return fmt.Sprintf("[%s][%s] %s: %s", severyLevelString(vr.SeverityLevel), vr.Name, vr.Url, vr.Description) } func severyLevelString(severityLevel float64) string { diff --git a/scan/jwt/alg_none.go b/scan/jwt/alg_none.go index 4e5fba7..68603bb 100644 --- a/scan/jwt/alg_none.go +++ b/scan/jwt/alg_none.go @@ -1,8 +1,9 @@ package jwt import ( + "github.com/cerberauth/vulnapi/internal/auth" + restapi "github.com/cerberauth/vulnapi/internal/rest_api" "github.com/cerberauth/vulnapi/report" - restapi "github.com/cerberauth/vulnapi/scan/rest_api" "github.com/golang-jwt/jwt/v5" ) @@ -12,14 +13,16 @@ const ( AlgNoneVulnerabilityDescription = "JWT accepts none algorithm and does verify jwt." ) -func AlgNoneJwtScanHandler(url string, token string) (*report.ScanReport, error) { +func AlgNoneJwtScanHandler(url string, ss auth.SecurityScheme) (*report.ScanReport, error) { r := report.NewScanReport() + token := ss.GetValidValue().(string) newToken, err := createNewJWTWithClaimsAndMethod(token, jwt.SigningMethodNone, jwt.UnsafeAllowNoneSignatureType) if err != nil { return r, err } - vsa := restapi.ScanRestAPI(url, newToken) + ss.SetAttackValue(newToken) + vsa := restapi.ScanRestAPI(url, ss) r.AddScanAttempt(vsa).End() if vsa.Response.StatusCode < 300 { @@ -27,6 +30,7 @@ func AlgNoneJwtScanHandler(url string, token string) (*report.ScanReport, error) SeverityLevel: AlgNoneVulnerabilitySeverityLevel, Name: AlgNoneVulnerabilityName, Description: AlgNoneVulnerabilityDescription, + Url: url, }) } diff --git a/scan/jwt/not_verified.go b/scan/jwt/not_verified.go index aa05f34..1b168b5 100644 --- a/scan/jwt/not_verified.go +++ b/scan/jwt/not_verified.go @@ -1,8 +1,9 @@ package jwt import ( + "github.com/cerberauth/vulnapi/internal/auth" + restapi "github.com/cerberauth/vulnapi/internal/rest_api" "github.com/cerberauth/vulnapi/report" - restapi "github.com/cerberauth/vulnapi/scan/rest_api" "github.com/golang-jwt/jwt/v5" ) @@ -12,8 +13,9 @@ const ( NotVerifiedVulnerabilityDescription = "JWT is not verified." ) -func NotVerifiedScanHandler(url string, token string) (*report.ScanReport, error) { +func NotVerifiedScanHandler(url string, ss auth.SecurityScheme) (*report.ScanReport, error) { r := report.NewScanReport() + token := ss.GetValidValue().(string) newTokenA, err := createNewJWTWithClaimsAndMethod(token, jwt.SigningMethodHS256, []byte("a")) if err != nil { @@ -25,10 +27,12 @@ func NotVerifiedScanHandler(url string, token string) (*report.ScanReport, error return r, err } - vsa1 := restapi.ScanRestAPI(url, newTokenA) + ss.SetAttackValue(newTokenA) + vsa1 := restapi.ScanRestAPI(url, ss) r.AddScanAttempt(vsa1) - vsa2 := restapi.ScanRestAPI(url, newTokenB) + ss.SetAttackValue(newTokenB) + vsa2 := restapi.ScanRestAPI(url, ss) r.AddScanAttempt(vsa2) r.End() @@ -38,6 +42,7 @@ func NotVerifiedScanHandler(url string, token string) (*report.ScanReport, error SeverityLevel: NotVerifiedVulnerabilitySeverityLevel, Name: NotVerifiedVulnerabilityName, Description: NotVerifiedVulnerabilityDescription, + Url: url, }) } diff --git a/scan/jwt/null_signature.go b/scan/jwt/null_signature.go index 5da2dd8..8ad5eed 100644 --- a/scan/jwt/null_signature.go +++ b/scan/jwt/null_signature.go @@ -3,8 +3,9 @@ package jwt import ( "strings" + "github.com/cerberauth/vulnapi/internal/auth" + restapi "github.com/cerberauth/vulnapi/internal/rest_api" "github.com/cerberauth/vulnapi/report" - restapi "github.com/cerberauth/vulnapi/scan/rest_api" ) const ( @@ -23,14 +24,16 @@ func createNewJWTWithoutSignature(originalTokenString string) (string, error) { return strings.Join([]string{parts[0], parts[1], ""}, "."), nil } -func NullSignatureScanHandler(url string, token string) (*report.ScanReport, error) { +func NullSignatureScanHandler(url string, ss auth.SecurityScheme) (*report.ScanReport, error) { r := report.NewScanReport() + token := ss.GetValidValue().(string) newToken, err := createNewJWTWithoutSignature(token) if err != nil { return r, err } - vsa := restapi.ScanRestAPI(url, newToken) + ss.SetAttackValue(newToken) + vsa := restapi.ScanRestAPI(url, ss) r.AddScanAttempt(vsa).End() if vsa.Response.StatusCode < 300 { @@ -38,6 +41,7 @@ func NullSignatureScanHandler(url string, token string) (*report.ScanReport, err SeverityLevel: NullSigVulnerabilitySeverityLevel, Name: NullSigVulnerabilityName, Description: NullSigVulnerabilityDescription, + Url: url, }) } diff --git a/scan/jwt/weak_secret.go b/scan/jwt/weak_secret.go index 2094f90..c40477c 100644 --- a/scan/jwt/weak_secret.go +++ b/scan/jwt/weak_secret.go @@ -1,8 +1,9 @@ package jwt import ( + "github.com/cerberauth/vulnapi/internal/auth" + restapi "github.com/cerberauth/vulnapi/internal/rest_api" "github.com/cerberauth/vulnapi/report" - restapi "github.com/cerberauth/vulnapi/scan/rest_api" ) const ( @@ -11,14 +12,16 @@ const ( WeakSecretVulnerabilityDescription = "JWT is signed with a weak secret allowing attackers to issue valid JWT." ) -func BlankSecretScanHandler(url string, token string) (*report.ScanReport, error) { +func BlankSecretScanHandler(url string, ss auth.SecurityScheme) (*report.ScanReport, error) { r := report.NewScanReport() + token := ss.GetValidValue().(string) newToken, err := createNewJWTWithClaims(token, []byte("")) if err != nil { return r, err } - vsa := restapi.ScanRestAPI(url, newToken) + ss.SetAttackValue(newToken) + vsa := restapi.ScanRestAPI(url, ss) r.AddScanAttempt(vsa).End() if vsa.Response.StatusCode < 300 { @@ -26,13 +29,14 @@ func BlankSecretScanHandler(url string, token string) (*report.ScanReport, error SeverityLevel: WeakSecretVulnerabilitySeverityLevel, Name: WeakSecretVulnerabilityName, Description: WeakSecretVulnerabilityDescription, + Url: url, }) } return r, nil } -func DictSecretScanHandler(url string, token string) (*report.ScanReport, error) { +func DictSecretScanHandler(url string, ss auth.SecurityScheme) (*report.ScanReport, error) { r := report.NewScanReport() // Use a dictionary attack to try finding the secret diff --git a/scan/openapi.go b/scan/openapi.go new file mode 100644 index 0000000..6a3ee4c --- /dev/null +++ b/scan/openapi.go @@ -0,0 +1,164 @@ +package scan + +import ( + "fmt" + "net/http" + "net/url" + "path" + + "github.com/brianvoe/gofakeit/v6" + "github.com/cerberauth/vulnapi/internal/auth" + restapi "github.com/cerberauth/vulnapi/internal/rest_api" + "github.com/cerberauth/vulnapi/report" + "github.com/getkin/kin-openapi/openapi3" + stduritemplate "github.com/std-uritemplate/std-uritemplate/go" +) + +func getBaseUrl(doc *openapi3.T) (*url.URL, error) { + var baseUrl *url.URL + var err error + for _, server := range doc.Servers { + baseUrl, err = url.Parse(server.URL) + if err != nil { + continue + } + + basePath, err := server.BasePath() + if err != nil { + continue + } + + baseUrl.Path = path.Join(baseUrl.Path, basePath) + break + } + + if baseUrl == nil { + return nil, fmt.Errorf("no valid base url has been found in OpenAPI file") + } + + return baseUrl, nil +} + +func getOperationSecuritySchemes(securityRequirements *openapi3.SecurityRequirements, securitySchemes map[string]auth.SecurityScheme) []auth.SecurityScheme { + operationsSecuritySchemes := []auth.SecurityScheme{} + for _, security := range *securityRequirements { + if len(security) == 0 { + continue + } + + keys := make([]string, 0, len(security)) + for k := range security { + keys = append(keys, k) + } + + operationSecurityScheme := securitySchemes[keys[0]] + if operationSecurityScheme == nil { + continue + } + + operationsSecuritySchemes = append(operationsSecuritySchemes, operationSecurityScheme) + } + + return operationsSecuritySchemes +} + +func getOperationPath(p string, params openapi3.Parameters) (string, error) { + subs := map[string]interface{}{} + for _, v := range params { + if v.Value.In != "path" { + continue + } + + var value interface{} + if v.Value.Example != nil { + value = v.Value.Example + } else if len(v.Value.Schema.Value.Enum) > 0 { + value = v.Value.Schema.Value.Enum[0] + } + + // if there is no example generate random param + if value == nil { + switch v.Value.Schema.Value.Type { + case "string": + value = gofakeit.Word() + case "number", "integer": + value = gofakeit.Number(0, 5) + } + } + + subs[v.Value.Name] = value + } + + return stduritemplate.Expand(p, subs) +} + +func NewOpenAPIScan(openAPIUrlOrPath string, validToken *string, reporter *report.Reporter) (*Scan, error) { + doc, err := restapi.LoadOpenAPI(openAPIUrlOrPath) + if err != nil { + return nil, err + } + + baseUrl, err := getBaseUrl(doc) + if err != nil { + return nil, err + } + + securitySchemes := map[string]auth.SecurityScheme{} + for name, scheme := range doc.Components.SecuritySchemes { + switch scheme.Value.Type { + case "http": + if scheme.Value.Scheme == string(auth.BearerScheme) { + securitySchemes[name] = auth.NewAuthorizationBearerSecurityScheme(name, validToken) + } + } + } + + operations := []auth.Operation{} + for docPath, p := range doc.Paths { + for method, o := range p.Operations() { + headers := http.Header{} + cookies := []http.Cookie{} + for _, h := range o.Parameters { + if !h.Value.Required { + continue + } + + name := h.Value.Name + value := "" + if h.Value.Example != nil { + // value = h.Value.Examples + value = "" + } + + if h.Value.In == "header" { + headers.Add(name, value) + } else if h.Value.In == "cookie" { + cookies = append(cookies, http.Cookie{ + Name: name, + Value: value, + }) + } + } + + operationsSecuritySchemes := getOperationSecuritySchemes(o.Security, securitySchemes) + operationPath, err := getOperationPath(docPath, o.Parameters) + if err != nil { + return nil, err + } + + operationUrl := *baseUrl + operationUrl.Path = path.Join(operationUrl.Path, operationPath) + + operations = append(operations, auth.Operation{ + Url: operationUrl.String(), + Method: method, + Headers: &headers, + Cookies: cookies, + + SecuritySchemes: operationsSecuritySchemes, + }) + } + } + + return NewScan(operations, reporter) +} diff --git a/scan/openapi_test.go b/scan/openapi_test.go new file mode 100644 index 0000000..affbe82 --- /dev/null +++ b/scan/openapi_test.go @@ -0,0 +1,145 @@ +package scan_test + +import ( + "net/http" + "testing" + + "github.com/brianvoe/gofakeit/v6" + "github.com/cerberauth/vulnapi/internal/auth" + "github.com/cerberauth/vulnapi/report" + "github.com/cerberauth/vulnapi/scan" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewOpenAPIScan(t *testing.T) { + token := "token" + s, err := scan.NewOpenAPIScan("../test/stub/simple_http_bearer_jwt.openapi.json", &token, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: []auth.Operation{{ + Method: "GET", + Url: "http://localhost:8080/", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("bearer_auth", &token)}, + }}, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} + +func TestNewOpenAPIScanWithPathError(t *testing.T) { + token := "" + _, err := scan.NewOpenAPIScan("../test/stub/non_existing_file.openapi.json", &token, nil) + + require.Error(t, err) +} + +func TestNewOpenAPIScanWithMultipleOperations(t *testing.T) { + gofakeit.Seed(1) + + token := "token" + securitySchemes := []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("bearer_auth", &token)} + operations := []auth.Operation{ + { + Method: "GET", + Url: "http://localhost:8080/", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: securitySchemes, + }, + + { + Method: "POST", + Url: "http://localhost:8080/", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: securitySchemes, + }, + + { + Method: "GET", + Url: "http://localhost:8080/resources/perfectly", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: securitySchemes, + }, + + { + Method: "POST", + Url: "http://localhost:8080/resources/as", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: securitySchemes, + }, + } + + s, err := scan.NewOpenAPIScan("../test/stub/basic_http_bearer_jwt.openapi.json", &token, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: operations, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} + +func TestNewOpenAPIScanWithoutParamsExample(t *testing.T) { + gofakeit.Seed(1) + + token := "token" + securitySchemes := []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("bearer_auth", &token)} + operations := []auth.Operation{ + { + Method: "GET", + Url: "http://localhost:8080/", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: securitySchemes, + }, + + { + Method: "POST", + Url: "http://localhost:8080/", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: securitySchemes, + }, + + { + Method: "GET", + Url: "http://localhost:8080/resources/perfectly", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: securitySchemes, + }, + + { + Method: "POST", + Url: "http://localhost:8080/resources/as", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: securitySchemes, + }, + } + + s, err := scan.NewOpenAPIScan("../test/stub/basic_http_bearer_jwt.openapi.json", &token, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: operations, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} diff --git a/scan/rest_api/request.go b/scan/rest_api/request.go deleted file mode 100644 index cd5b1d7..0000000 --- a/scan/rest_api/request.go +++ /dev/null @@ -1,25 +0,0 @@ -package restapi - -import ( - "fmt" - - "github.com/cerberauth/vulnapi/internal/request" - "github.com/cerberauth/vulnapi/report" -) - -func ScanRestAPI(url string, token string) *report.VulnerabilityScanAttempt { - req, resp, err := request.SendRequestWithBearerAuth(url, token) - if err != nil { - err = fmt.Errorf("request with url %s has an unexpected error", err) - } - - if resp.StatusCode < 200 && resp.StatusCode >= 300 { - err = fmt.Errorf("unexpected status code %d during test request", resp.StatusCode) - } - - return &report.VulnerabilityScanAttempt{ - Request: req, - Response: resp, - Err: err, - } -} diff --git a/scan/scan.go b/scan/scan.go index b47efe1..01886c0 100644 --- a/scan/scan.go +++ b/scan/scan.go @@ -1,63 +1,97 @@ package scan import ( - "errors" + "fmt" + "github.com/cerberauth/vulnapi/internal/auth" + restapi "github.com/cerberauth/vulnapi/internal/rest_api" "github.com/cerberauth/vulnapi/report" - restapi "github.com/cerberauth/vulnapi/scan/rest_api" ) -type ScanHandler func(url string, jwt string) (*report.ScanReport, error) +type ScanHandler func(url string, ss auth.SecurityScheme) (*report.ScanReport, error) type Scan struct { - url string - validJwt *string - pendingScans []ScanHandler - reporter *report.Reporter + Operations []auth.Operation + Handlers []ScanHandler + Reporter *report.Reporter } -func NewScanner(url string, valid_jwt *string) *Scan { - return &Scan{ - reporter: report.NewReporter(), - url: url, - validJwt: valid_jwt, +func NewScan(operations []auth.Operation, reporter *report.Reporter) (*Scan, error) { + if len(operations) == 0 { + return nil, fmt.Errorf("a scan must have at least one operation") + } + + if reporter == nil { + reporter = report.NewReporter() } + + return &Scan{ + Operations: operations, + Handlers: []ScanHandler{}, + Reporter: reporter, + }, nil } -func (s *Scan) AddPendingScanHandler(sh ScanHandler) *Scan { - s.pendingScans = append(s.pendingScans, sh) +func (s *Scan) AddScanHandler(sh ScanHandler) *Scan { + s.Handlers = append(s.Handlers, sh) return s } func (s *Scan) Execute() (*report.Reporter, []error, error) { - if err := s.ValidateRequest(); err != nil { + if len(s.Operations) == 0 { + return nil, nil, fmt.Errorf("no operations has been configured before executing scan") + } + + if err := s.ValidateOperation(&s.Operations[0]); err != nil { return nil, nil, err } var errors []error - for i := 0; i < len(s.pendingScans); i++ { - rep, err := s.pendingScans[i](s.url, *s.validJwt) + for _, o := range s.Operations { + opErrors, opError := s.ExecuteOperation(&o) + if opError != nil { + return nil, nil, opError + } + + errors = append(errors, opErrors...) + } + + return s.Reporter, errors, nil +} + +func (s *Scan) ExecuteOperation(o *auth.Operation) ([]error, error) { + if len(o.SecuritySchemes) == 0 { + return nil, fmt.Errorf("no security schemes has been configured") + } + + var errors []error + for i := 0; i < len(s.Handlers); i++ { + rep, err := s.Handlers[i](o.Url, o.SecuritySchemes[0]) if err != nil { errors = append(errors, err) } else if rep != nil { - s.reporter.AddReport(rep) + s.Reporter.AddReport(rep) } } - return s.reporter, errors, nil + return errors, nil } -func (s *Scan) ValidateRequest() error { - if s.validJwt == nil { - return errors.New("no valid JWT provided") +func (s *Scan) ValidateOperation(o *auth.Operation) error { + if len(o.SecuritySchemes) == 0 { + return fmt.Errorf("no security schemes has been configured") } - r := restapi.ScanRestAPI(s.url, *s.validJwt) + r := restapi.ScanRestAPI(o.Url, o.SecuritySchemes[0]) if r.Err != nil { return r.Err } + if r.Response.StatusCode >= 300 { + return fmt.Errorf("the request with the passed JWT should return 2xx status code") + } + return nil } diff --git a/scan/scan_test.go b/scan/scan_test.go new file mode 100644 index 0000000..95c550b --- /dev/null +++ b/scan/scan_test.go @@ -0,0 +1,59 @@ +package scan_test + +import ( + "net/http" + "testing" + + "github.com/cerberauth/vulnapi/internal/auth" + "github.com/cerberauth/vulnapi/report" + "github.com/cerberauth/vulnapi/scan" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewScanWithNoOperations(t *testing.T) { + _, err := scan.NewScan([]auth.Operation{}, nil) + + require.Error(t, err) +} + +func TestNewScan(t *testing.T) { + operations := []auth.Operation{{ + Method: "GET", + Url: "http://localhost:8080", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: []auth.SecurityScheme{}, + }} + + s, err := scan.NewScan(operations, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: operations, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} + +func TestNewScanWithReporter(t *testing.T) { + operations := []auth.Operation{{ + Method: "GET", + Url: "http://localhost:8080", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + + SecuritySchemes: []auth.SecurityScheme{}, + }} + reporter := report.NewReporter() + + s, err := scan.NewScan(operations, reporter) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: operations, + Handlers: []scan.ScanHandler{}, + Reporter: reporter, + }, s) +} diff --git a/scan/url.go b/scan/url.go new file mode 100644 index 0000000..332a3f3 --- /dev/null +++ b/scan/url.go @@ -0,0 +1,64 @@ +package scan + +import ( + "net/http" + "strings" + + "github.com/cerberauth/vulnapi/internal/auth" + "github.com/cerberauth/vulnapi/report" +) + +const bearerPrefix = auth.BearerPrefix + " " + +func detectAuthorizationHeader(headers *http.Header) string { + if h := headers.Get(auth.AuthorizationHeader); h != "" { + return h + } + + if h := headers.Get(strings.ToLower(auth.AuthorizationHeader)); h != "" { + return h + } + + return "" +} + +func getBearerToken(authHeader string) string { + if strings.HasPrefix(authHeader, bearerPrefix) { + return strings.TrimPrefix(authHeader, bearerPrefix) + } + + lowerCasePrefix := strings.ToLower(bearerPrefix) + if strings.HasPrefix(authHeader, lowerCasePrefix) { + return strings.TrimPrefix(authHeader, lowerCasePrefix) + } + + return "" +} + +func detectSecurityScheme(headers *http.Header, cookies []http.Cookie) auth.SecurityScheme { + if authHeader := detectAuthorizationHeader(headers); authHeader != "" { + if token := getBearerToken(authHeader); token != "" { + return auth.NewAuthorizationBearerSecurityScheme("default", &token) + } + } + + return nil +} + +func NewURLScan(method string, url string, headers *http.Header, cookies []http.Cookie, reporter *report.Reporter) (*Scan, error) { + var securitySchemes []auth.SecurityScheme + if securityScheme := detectSecurityScheme(headers, cookies); securityScheme != nil { + securitySchemes = append(securitySchemes, securityScheme) + } + + operations := []auth.Operation{{ + Url: url, + Method: method, + Headers: headers, + Cookies: cookies, + + SecuritySchemes: securitySchemes, + }} + + return NewScan(operations, reporter) +} diff --git a/scan/url_test.go b/scan/url_test.go new file mode 100644 index 0000000..e2c02a6 --- /dev/null +++ b/scan/url_test.go @@ -0,0 +1,134 @@ +package scan_test + +import ( + "net/http" + "testing" + + "github.com/cerberauth/vulnapi/internal/auth" + "github.com/cerberauth/vulnapi/report" + "github.com/cerberauth/vulnapi/scan" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewURLScan(t *testing.T) { + s, err := scan.NewURLScan("GET", "http://localhost:8080", &http.Header{}, []http.Cookie{}, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: []auth.Operation{{ + Method: "GET", + Url: "http://localhost:8080", + Headers: &http.Header{}, + Cookies: []http.Cookie{}, + }}, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} + +func TestNewURLScanWithHeaders(t *testing.T) { + headers := http.Header{} + headers.Add("Cache-Control", "no-cache") + + s, err := scan.NewURLScan("GET", "http://localhost:8080", &headers, []http.Cookie{}, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: []auth.Operation{{ + Method: "GET", + Url: "http://localhost:8080", + Headers: &headers, + Cookies: []http.Cookie{}, + }}, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} + +func TestNewURLScanWithCookies(t *testing.T) { + cookies := []http.Cookie{{ + Name: "name", + Value: "value", + }} + + s, err := scan.NewURLScan("GET", "http://localhost:8080", &http.Header{}, cookies, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: []auth.Operation{{ + Method: "GET", + Url: "http://localhost:8080", + Headers: &http.Header{}, + Cookies: cookies, + }}, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} + +func TestNewURLScanWithUpperCaseAuthorizationHeader(t *testing.T) { + headers := http.Header{} + headers.Add("Authorization", "Bearer token") + token := "token" + + s, err := scan.NewURLScan("GET", "http://localhost:8080", &headers, []http.Cookie{}, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: []auth.Operation{{ + Method: "GET", + Url: "http://localhost:8080", + Headers: &headers, + Cookies: []http.Cookie{}, + + SecuritySchemes: []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("default", &token)}, + }}, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} + +func TestNewURLScanWithUpperCaseAuthorizationAndLowerCaseBearerHeader(t *testing.T) { + headers := http.Header{} + headers.Add("Authorization", "bearer token") + token := "token" + + s, err := scan.NewURLScan("GET", "http://localhost:8080", &headers, []http.Cookie{}, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: []auth.Operation{{ + Method: "GET", + Url: "http://localhost:8080", + Headers: &headers, + Cookies: []http.Cookie{}, + + SecuritySchemes: []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("default", &token)}, + }}, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} + +func TestNewURLScanWithLowerCaseAuthorizationHeader(t *testing.T) { + headers := http.Header{} + headers.Add("authorization", "Bearer token") + token := "token" + + s, err := scan.NewURLScan("GET", "http://localhost:8080", &headers, []http.Cookie{}, nil) + + require.NoError(t, err) + assert.Equal(t, &scan.Scan{ + Operations: []auth.Operation{{ + Method: "GET", + Url: "http://localhost:8080", + Headers: &headers, + Cookies: []http.Cookie{}, + + SecuritySchemes: []auth.SecurityScheme{auth.NewAuthorizationBearerSecurityScheme("default", &token)}, + }}, + Handlers: []scan.ScanHandler{}, + Reporter: report.NewReporter(), + }, s) +} diff --git a/scan/scans.go b/scan/vulns.go similarity index 53% rename from scan/scans.go rename to scan/vulns.go index 7572d78..8b0458f 100644 --- a/scan/scans.go +++ b/scan/vulns.go @@ -3,21 +3,21 @@ package scan import "github.com/cerberauth/vulnapi/scan/jwt" func (s *Scan) WithAlgNoneJwtScan() *Scan { - return s.AddPendingScanHandler(jwt.AlgNoneJwtScanHandler) + return s.AddScanHandler(jwt.AlgNoneJwtScanHandler) } func (s *Scan) WithNotVerifiedJwtScan() *Scan { - return s.AddPendingScanHandler(jwt.NotVerifiedScanHandler) + return s.AddScanHandler(jwt.NotVerifiedScanHandler) } func (s *Scan) WithJWTNullSignatureScan() *Scan { - return s.AddPendingScanHandler(jwt.NullSignatureScanHandler) + return s.AddScanHandler(jwt.NullSignatureScanHandler) } func (s *Scan) WithWeakJwtSecretScan() *Scan { - return s.AddPendingScanHandler(jwt.BlankSecretScanHandler).AddPendingScanHandler(jwt.DictSecretScanHandler) + return s.AddScanHandler(jwt.BlankSecretScanHandler).AddScanHandler(jwt.DictSecretScanHandler) } -func (s *Scan) WithAllScans() *Scan { +func (s *Scan) WithAllVulnsScans() *Scan { return s.WithAlgNoneJwtScan().WithNotVerifiedJwtScan().WithJWTNullSignatureScan().WithWeakJwtSecretScan() } diff --git a/test/stub/basic_http_bearer_jwt.openapi.json b/test/stub/basic_http_bearer_jwt.openapi.json new file mode 100644 index 0000000..86a0434 --- /dev/null +++ b/test/stub/basic_http_bearer_jwt.openapi.json @@ -0,0 +1,93 @@ +{ + "openapi": "3.0.2", + "servers": [ + { + "url": "http://localhost:8080" + } + ], + "paths": { + "/": { + "get": { + "parameters": [], + "responses": { + "204": { + "description": "successful operation" + } + }, + "security": [ + { + "bearer_auth": [] + } + ] + }, + "post": { + "parameters": [], + "responses": { + "204": { + "description": "successful operation" + } + }, + "security": [ + { + "bearer_auth": [] + } + ] + } + }, + "/resources/{id}": { + "get": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "204": { + "description": "successful operation" + } + }, + "security": [ + { + "bearer_auth": [] + } + ] + }, + "post": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "204": { + "description": "successful operation" + } + }, + "security": [ + { + "bearer_auth": [] + } + ] + } + } + }, + "components": { + "securitySchemes": { + "bearer_auth": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT" + } + } + } +} diff --git a/test/stub/simple_http_bearer_jwt.openapi.json b/test/stub/simple_http_bearer_jwt.openapi.json new file mode 100644 index 0000000..07b2af2 --- /dev/null +++ b/test/stub/simple_http_bearer_jwt.openapi.json @@ -0,0 +1,34 @@ +{ + "openapi": "3.0.2", + "servers": [ + { + "url": "http://localhost:8080" + } + ], + "paths": { + "/": { + "get": { + "parameters": [], + "responses": { + "204": { + "description": "successful operation" + } + }, + "security": [ + { + "bearer_auth": [] + } + ] + } + } + }, + "components": { + "securitySchemes": { + "bearer_auth": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT" + } + } + } +}