From a200f6e8c9aa148c561455cfd44a15eed737e64d Mon Sep 17 00:00:00 2001 From: Emmanuel Gautier Date: Wed, 11 Oct 2023 23:18:32 +0200 Subject: [PATCH] feat: add scan reporter --- cmd/scan/root.go | 8 ++--- internal/request/request.go | 29 ++++++++++-------- report/report.go | 59 +++++++++++++++++++++++++++++++++++++ report/reporter.go | 48 ++++++++++++++++++++++++++++++ report/vuln.go | 40 +++++++++++++++++++++++++ scan/jwt/alg_none.go | 34 +++++++++++++-------- scan/jwt/not_verified.go | 51 ++++++++++++++++---------------- scan/jwt/null_signature.go | 35 ++++++++++++++-------- scan/jwt/weak_secret.go | 42 ++++++++++++++++---------- scan/rest_api/request.go | 25 ++++++++++++++++ scan/scan.go | 48 ++++++++++++++++-------------- 11 files changed, 313 insertions(+), 106 deletions(-) create mode 100644 report/report.go create mode 100644 report/reporter.go create mode 100644 report/vuln.go create mode 100644 scan/rest_api/request.go diff --git a/cmd/scan/root.go b/cmd/scan/root.go index 1804d97..68107e3 100644 --- a/cmd/scan/root.go +++ b/cmd/scan/root.go @@ -32,17 +32,17 @@ func NewScanCmd() (scanCmd *cobra.Command) { jwt = stdin } - reports, err := scan.NewScanner(url, &jwt).WithAllScans().Execute() + rpr, _, err := scan.NewScanner(url, &jwt).WithAllScans().Execute() if err != nil { log.Fatal(err) } - if len(reports) == 0 { + if !rpr.HasVulnerability() { println("Congratulations! No vulnerability has been discovered!") } - for _, report := range reports { - log.Println(report) + for _, r := range rpr.GetVulnerabilityReports() { + log.Println(r) } }, } diff --git a/internal/request/request.go b/internal/request/request.go index eb5ab86..db40481 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -2,31 +2,34 @@ package request import ( "fmt" - "io" "net/http" ) -func SendRequestWithBearerAuth(url string, token string) (int, []byte, error) { - client := &http.Client{} - - req, err := http.NewRequest("GET", url, nil) +func prepareVulnAPIRequest(method string, url string) (*http.Request, error) { + req, err := http.NewRequest(method, url, nil) if err != nil { - return 0, nil, err + return nil, err } - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) req.Header.Set("User-Agent", "vulnapi/0.1") - resp, err := client.Do(req) + return req, nil +} + +func SendRequestWithBearerAuth(url string, token string) (*http.Request, *http.Response, error) { + req, err := prepareVulnAPIRequest("GET", url) if err != nil { - return 0, nil, err + return req, nil, err } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { - return 0, nil, err + return req, resp, err } + defer resp.Body.Close() - return resp.StatusCode, body, nil + return req, resp, nil } diff --git a/report/report.go b/report/report.go new file mode 100644 index 0000000..bce3186 --- /dev/null +++ b/report/report.go @@ -0,0 +1,59 @@ +package report + +import ( + "net/http" + "time" +) + +type VulnerabilityScanAttempt struct { + Request *http.Request + Response *http.Response + + Err error +} + +type ScanReport struct { + scans []*VulnerabilityScanAttempt + vulns []*VulnerabilityReport + + startTime time.Time + endTime time.Time +} + +func NewScanReport() *ScanReport { + return &ScanReport{ + startTime: time.Now(), + } +} + +func (sc *ScanReport) Start() *ScanReport { + sc.startTime = time.Now() + return sc +} + +func (sc *ScanReport) End() *ScanReport { + sc.endTime = time.Now() + return sc +} + +func (sc *ScanReport) AddScanAttempt(a *VulnerabilityScanAttempt) *ScanReport { + sc.scans = append(sc.scans, a) + return sc +} + +func (sc *ScanReport) GetScanAttempts() []*VulnerabilityScanAttempt { + return sc.scans +} + +func (sc *ScanReport) AddVulnerabilityReport(vr *VulnerabilityReport) *ScanReport { + sc.vulns = append(sc.vulns, vr) + return sc +} + +func (sc *ScanReport) GetVulnerabilityReports() []*VulnerabilityReport { + return sc.vulns +} + +func (sc *ScanReport) HasVulnerabilityReport() bool { + return len(sc.GetVulnerabilityReports()) > 0 +} diff --git a/report/reporter.go b/report/reporter.go new file mode 100644 index 0000000..c92efed --- /dev/null +++ b/report/reporter.go @@ -0,0 +1,48 @@ +package report + +type Reporter struct { + reports []*ScanReport +} + +func NewReporter() *Reporter { + return &Reporter{ + reports: []*ScanReport{}, + } +} + +func (rr *Reporter) AddReport(r *ScanReport) { + rr.reports = append(rr.reports, r) +} + +func (rr *Reporter) GetReports() []*ScanReport { + return rr.reports +} + +func (rr *Reporter) HasVulnerability() bool { + for _, r := range rr.GetReports() { + if r.HasVulnerabilityReport() { + return true + } + } + + return false +} + +func (rr *Reporter) GetVulnerabilityReports() []*VulnerabilityReport { + var vrs []*VulnerabilityReport + for _, r := range rr.GetReports() { + vrs = append(vrs, r.GetVulnerabilityReports()...) + } + + return vrs +} + +func (rr *Reporter) HasHighRiskSeverityVulnerability() bool { + for _, r := range rr.GetVulnerabilityReports() { + if r.IsHighRiskSeverity() { + return true + } + } + + return false +} diff --git a/report/vuln.go b/report/vuln.go new file mode 100644 index 0000000..3ae97eb --- /dev/null +++ b/report/vuln.go @@ -0,0 +1,40 @@ +package report + +import "fmt" + +type VulnerabilityReport struct { + SeverityLevel float64 // https://nvd.nist.gov/vuln-metrics/cvss + Name string + Description string + Url *string +} + +func (vr *VulnerabilityReport) IsLowRiskSeverity() bool { + return vr.SeverityLevel < 4 +} + +func (vr *VulnerabilityReport) IsMediumRiskSeverity() bool { + return vr.SeverityLevel > 4 && vr.SeverityLevel < 7 +} + +func (vr *VulnerabilityReport) IsHighRiskSeverity() bool { + return vr.SeverityLevel > 7 +} + +func (vr *VulnerabilityReport) String() string { + return fmt.Sprintf("[%s] %s: %s", severyLevelString(vr.SeverityLevel), vr.Name, vr.Description) +} + +func severyLevelString(severityLevel float64) string { + if severityLevel >= 9 { + return "critical" + } else if severityLevel < 9 && severityLevel >= 7 { + return "hight" + } else if severityLevel < 7 && severityLevel >= 4 { + return "medium" + } else if severityLevel < 4 && severityLevel >= 0.1 { + return "low" + } else { + return "none" + } +} diff --git a/scan/jwt/alg_none.go b/scan/jwt/alg_none.go index 3c77800..4e5fba7 100644 --- a/scan/jwt/alg_none.go +++ b/scan/jwt/alg_none.go @@ -1,26 +1,34 @@ package jwt import ( - "fmt" - - "github.com/cerberauth/vulnapi/internal/request" + "github.com/cerberauth/vulnapi/report" + restapi "github.com/cerberauth/vulnapi/scan/rest_api" "github.com/golang-jwt/jwt/v5" ) -func AlgNoneJwtScanHandler(url string, token string) []error { - newToken, err := createNewJWTWithClaimsAndMethod(token, jwt.SigningMethodNone, jwt.UnsafeAllowNoneSignatureType) - if err != nil { - return []error{err} - } +const ( + AlgNoneVulnerabilitySeverityLevel = 9 + AlgNoneVulnerabilityName = "JWT Alg None" + AlgNoneVulnerabilityDescription = "JWT accepts none algorithm and does verify jwt." +) + +func AlgNoneJwtScanHandler(url string, token string) (*report.ScanReport, error) { + r := report.NewScanReport() - statusCode, _, err := request.SendRequestWithBearerAuth(url, newToken) + newToken, err := createNewJWTWithClaimsAndMethod(token, jwt.SigningMethodNone, jwt.UnsafeAllowNoneSignatureType) if err != nil { - return []error{err} + return r, err } + vsa := restapi.ScanRestAPI(url, newToken) + r.AddScanAttempt(vsa).End() - if statusCode > 200 && statusCode <= 300 { - return []error{fmt.Errorf("unexpected status code %d with an alg none forged token", statusCode)} + if vsa.Response.StatusCode < 300 { + r.AddVulnerabilityReport(&report.VulnerabilityReport{ + SeverityLevel: AlgNoneVulnerabilitySeverityLevel, + Name: AlgNoneVulnerabilityName, + Description: AlgNoneVulnerabilityDescription, + }) } - return nil + return r, nil } diff --git a/scan/jwt/not_verified.go b/scan/jwt/not_verified.go index 31e768c..aa05f34 100644 --- a/scan/jwt/not_verified.go +++ b/scan/jwt/not_verified.go @@ -1,46 +1,45 @@ package jwt import ( - "fmt" - - "github.com/cerberauth/vulnapi/internal/request" + "github.com/cerberauth/vulnapi/report" + restapi "github.com/cerberauth/vulnapi/scan/rest_api" "github.com/golang-jwt/jwt/v5" ) -func NotVerifiedScanHandler(url string, token string) []error { +const ( + NotVerifiedVulnerabilitySeverityLevel = 9 + NotVerifiedVulnerabilityName = "JWT Not Verified" + NotVerifiedVulnerabilityDescription = "JWT is not verified." +) + +func NotVerifiedScanHandler(url string, token string) (*report.ScanReport, error) { + r := report.NewScanReport() + newTokenA, err := createNewJWTWithClaimsAndMethod(token, jwt.SigningMethodHS256, []byte("a")) if err != nil { - return []error{err} + return r, err } newTokenB, err := createNewJWTWithClaimsAndMethod(token, jwt.SigningMethodHS256, []byte("b")) if err != nil { - return []error{err} - } - - statusCodeA, _, errRequestA := request.SendRequestWithBearerAuth(url, newTokenA) - statusCodeB, _, errRequestB := request.SendRequestWithBearerAuth(url, newTokenB) - - var errors []error - if errRequestA != nil { - errors = append(errors, errRequestA) + return r, err } - if errRequestB != nil { - errors = append(errors, errRequestB) - } + vsa1 := restapi.ScanRestAPI(url, newTokenA) + r.AddScanAttempt(vsa1) - if statusCodeA > 200 && statusCodeA <= 300 { - errors = append(errors, fmt.Errorf("unexpected status code %d with an invalid forged token", statusCodeA)) - } + vsa2 := restapi.ScanRestAPI(url, newTokenB) + r.AddScanAttempt(vsa2) - if statusCodeA != statusCodeB { - errors = append(errors, fmt.Errorf("status code are not the same between the two attempts")) - } + r.End() - if len(errors) > 0 { - return errors + if vsa1.Response.StatusCode != vsa2.Response.StatusCode { + r.AddVulnerabilityReport(&report.VulnerabilityReport{ + SeverityLevel: NotVerifiedVulnerabilitySeverityLevel, + Name: NotVerifiedVulnerabilityName, + Description: NotVerifiedVulnerabilityDescription, + }) } - return nil + return r, nil } diff --git a/scan/jwt/null_signature.go b/scan/jwt/null_signature.go index 4321dd6..5da2dd8 100644 --- a/scan/jwt/null_signature.go +++ b/scan/jwt/null_signature.go @@ -1,10 +1,16 @@ package jwt import ( - "fmt" "strings" - "github.com/cerberauth/vulnapi/internal/request" + "github.com/cerberauth/vulnapi/report" + restapi "github.com/cerberauth/vulnapi/scan/rest_api" +) + +const ( + NullSigVulnerabilitySeverityLevel = 9 + NullSigVulnerabilityName = "JWT Null Signature" + NullSigVulnerabilityDescription = "JWT with null signature is accepted allowing to bypass authentication." ) func createNewJWTWithoutSignature(originalTokenString string) (string, error) { @@ -17,20 +23,23 @@ func createNewJWTWithoutSignature(originalTokenString string) (string, error) { return strings.Join([]string{parts[0], parts[1], ""}, "."), nil } -func NullSignatureScanHandler(url string, token string) []error { - newToken, err := createNewJWTWithoutSignature(token) - if err != nil { - return []error{err} - } +func NullSignatureScanHandler(url string, token string) (*report.ScanReport, error) { + r := report.NewScanReport() - statusCode, _, err := request.SendRequestWithBearerAuth(url, newToken) + newToken, err := createNewJWTWithoutSignature(token) if err != nil { - return []error{err} + return r, err } - - if statusCode > 200 && statusCode <= 300 { - return []error{fmt.Errorf("unexpected status code %d with a null signature", statusCode)} + vsa := restapi.ScanRestAPI(url, newToken) + r.AddScanAttempt(vsa).End() + + if vsa.Response.StatusCode < 300 { + r.AddVulnerabilityReport(&report.VulnerabilityReport{ + SeverityLevel: NullSigVulnerabilitySeverityLevel, + Name: NullSigVulnerabilityName, + Description: NullSigVulnerabilityDescription, + }) } - return nil + return r, nil } diff --git a/scan/jwt/weak_secret.go b/scan/jwt/weak_secret.go index caa3a1b..2094f90 100644 --- a/scan/jwt/weak_secret.go +++ b/scan/jwt/weak_secret.go @@ -1,31 +1,43 @@ package jwt import ( - "fmt" + "github.com/cerberauth/vulnapi/report" + restapi "github.com/cerberauth/vulnapi/scan/rest_api" +) - "github.com/cerberauth/vulnapi/internal/request" +const ( + WeakSecretVulnerabilitySeverityLevel = 9 + WeakSecretVulnerabilityName = "Weak Secret Vulnerability" + WeakSecretVulnerabilityDescription = "JWT is signed with a weak secret allowing attackers to issue valid JWT." ) -func BlankSecretScanHandler(url string, token string) []error { - newToken, err := createNewJWTWithClaims(token, []byte("")) - if err != nil { - return []error{err} - } +func BlankSecretScanHandler(url string, token string) (*report.ScanReport, error) { + r := report.NewScanReport() - statusCode, _, err := request.SendRequestWithBearerAuth(url, newToken) + newToken, err := createNewJWTWithClaims(token, []byte("")) if err != nil { - return []error{err} + return r, err } - - if statusCode > 200 && statusCode <= 300 { - return []error{fmt.Errorf("unexpected status code %d with a blank secret", statusCode)} + vsa := restapi.ScanRestAPI(url, newToken) + r.AddScanAttempt(vsa).End() + + if vsa.Response.StatusCode < 300 { + r.AddVulnerabilityReport(&report.VulnerabilityReport{ + SeverityLevel: WeakSecretVulnerabilitySeverityLevel, + Name: WeakSecretVulnerabilityName, + Description: WeakSecretVulnerabilityDescription, + }) } - return nil + return r, nil } -func DictSecretScanHandler(url string, token string) []error { +func DictSecretScanHandler(url string, token string) (*report.ScanReport, error) { + r := report.NewScanReport() + // Use a dictionary attack to try finding the secret - return nil + r.End() + + return r, nil } diff --git a/scan/rest_api/request.go b/scan/rest_api/request.go new file mode 100644 index 0000000..cd5b1d7 --- /dev/null +++ b/scan/rest_api/request.go @@ -0,0 +1,25 @@ +package restapi + +import ( + "fmt" + + "github.com/cerberauth/vulnapi/internal/request" + "github.com/cerberauth/vulnapi/report" +) + +func ScanRestAPI(url string, token string) *report.VulnerabilityScanAttempt { + req, resp, err := request.SendRequestWithBearerAuth(url, token) + if err != nil { + err = fmt.Errorf("request with url %s has an unexpected error", err) + } + + if resp.StatusCode < 200 && resp.StatusCode >= 300 { + err = fmt.Errorf("unexpected status code %d during test request", resp.StatusCode) + } + + return &report.VulnerabilityScanAttempt{ + Request: req, + Response: resp, + Err: err, + } +} diff --git a/scan/scan.go b/scan/scan.go index 0cd20ba..b47efe1 100644 --- a/scan/scan.go +++ b/scan/scan.go @@ -2,57 +2,61 @@ package scan import ( "errors" - "fmt" - "github.com/cerberauth/vulnapi/internal/request" + "github.com/cerberauth/vulnapi/report" + restapi "github.com/cerberauth/vulnapi/scan/rest_api" ) -type ScanHandler func(url string, jwt string) []error +type ScanHandler func(url string, jwt string) (*report.ScanReport, error) type Scan struct { - Url string - ValidJwt *string - PendingScans []ScanHandler + url string + validJwt *string + pendingScans []ScanHandler + reporter *report.Reporter } func NewScanner(url string, valid_jwt *string) *Scan { return &Scan{ - Url: url, - ValidJwt: valid_jwt, + reporter: report.NewReporter(), + url: url, + validJwt: valid_jwt, } } func (s *Scan) AddPendingScanHandler(sh ScanHandler) *Scan { - s.PendingScans = append(s.PendingScans, sh) + s.pendingScans = append(s.pendingScans, sh) return s } -func (s *Scan) Execute() ([]error, error) { +func (s *Scan) Execute() (*report.Reporter, []error, error) { if err := s.ValidateRequest(); err != nil { - return nil, err + return nil, nil, err } var errors []error - for i := 0; i < len(s.PendingScans); i++ { - errors = append(errors, s.PendingScans[i](s.Url, *s.ValidJwt)...) + for i := 0; i < len(s.pendingScans); i++ { + rep, err := s.pendingScans[i](s.url, *s.validJwt) + + if err != nil { + errors = append(errors, err) + } else if rep != nil { + s.reporter.AddReport(rep) + } } - return errors, nil + return s.reporter, errors, nil } func (s *Scan) ValidateRequest() error { - if s.ValidJwt == nil { + if s.validJwt == nil { return errors.New("no valid JWT provided") } - statusCode, _, err := request.SendRequestWithBearerAuth(s.Url, *s.ValidJwt) - if err != nil { - return fmt.Errorf("request with url %s has an unexpected error", err) - } - - if statusCode < 200 && statusCode >= 300 { - return fmt.Errorf("unexpected status code %d during test request", statusCode) + r := restapi.ScanRestAPI(s.url, *s.validJwt) + if r.Err != nil { + return r.Err } return nil