Skip to content

Commit

Permalink
Merge pull request #161 from cerberauth/add-include-exclude-scans-flags
Browse files Browse the repository at this point in the history
feat: add include and exclude scans flags
  • Loading branch information
emmanuelgautier authored Sep 16, 2024
2 parents 59f6c53 + d565a0f commit 9e2cea0
Show file tree
Hide file tree
Showing 19 changed files with 271 additions and 44 deletions.
2 changes: 1 addition & 1 deletion api/curl.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (h *Handler) ScanURL(ctx *gin.Context) {
return
}

reporter, _, err := s.Execute(func(operationScan *scan.OperationScan) {})
reporter, _, err := s.Execute(func(operationScan *scan.OperationScan) {}, form.Opts.Scans, form.Opts.ExcludeScans)
if err != nil {
analyticsx.TrackError(ctx, serverApiUrlTracer, err)
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
Expand Down
2 changes: 1 addition & 1 deletion api/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (h *Handler) ScanGraphQL(ctx *gin.Context) {
return
}

reporter, _, err := s.Execute(func(operationScan *scan.OperationScan) {})
reporter, _, err := s.Execute(func(operationScan *scan.OperationScan) {}, form.Opts.Scans, form.Opts.ExcludeScans)
if err != nil {
analyticsx.TrackError(ctx, serverApiGraphQLTracer, err)
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
Expand Down
8 changes: 7 additions & 1 deletion api/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (h *Handler) ScanOpenAPI(ctx *gin.Context) {
return
}

reporter, _, err := s.Execute(func(operationScan *scan.OperationScan) {})
reporter, _, err := s.Execute(func(operationScan *scan.OperationScan) {}, form.Opts.Scans, form.Opts.ExcludeScans)
if err != nil {
analyticsx.TrackError(ctx, serverApiOpenAPITracer, err)
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
Expand All @@ -81,5 +81,11 @@ func (h *Handler) ScanOpenAPI(ctx *gin.Context) {
Reports: reporter.GetReports(),
}
_, err = json.Marshal(response)
if err != nil {
analyticsx.TrackError(ctx, serverApiOpenAPITracer, err)
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

ctx.JSON(http.StatusOK, response)
}
3 changes: 3 additions & 0 deletions api/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
type ScanOptions struct {
RateLimit int `json:"rateLimit"`
ProxyURL string `json:"proxy"`

Scans []string `json:"scans"`
ExcludeScans []string `json:"excludeScans"`
}

func parseScanOptions(opts *ScanOptions) request.NewClientOptions {
Expand Down
2 changes: 1 addition & 1 deletion cmd/discover/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func NewAPICmd() (apiCmd *cobra.Command) {
bar := internalCmd.NewProgressBar(len(s.GetOperationsScans()))
reporter, _, err := s.Execute(func(operationScan *scan.OperationScan) {
bar.Add(1)
})
}, internalCmd.GetIncludeScans(), internalCmd.GetExcludeScans())
if err != nil {
analyticsx.TrackError(ctx, tracer, err)
log.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion cmd/discover/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func NewDomainCmd() (domainCmd *cobra.Command) {
bar := internalCmd.NewProgressBar(len(s.GetOperationsScans()))
reporter, _, err := s.Execute(func(operationScan *scan.OperationScan) {
bar.Add(1)
})
}, internalCmd.GetIncludeScans(), internalCmd.GetExcludeScans())
if err != nil {
analyticsx.TrackError(ctx, tracer, err)
log.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion cmd/scan/curl.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func NewCURLScanCmd() (scanCmd *cobra.Command) {
bar := internalCmd.NewProgressBar(len(s.GetOperationsScans()))
if reporter, _, err = s.Execute(func(operationScan *scan.OperationScan) {
bar.Add(1)
}); err != nil {
}, internalCmd.GetIncludeScans(), internalCmd.GetExcludeScans()); err != nil {
analyticsx.TrackError(ctx, tracer, err)
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/scan/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func NewGraphQLScanCmd() (scanCmd *cobra.Command) {
bar := internalCmd.NewProgressBar(len(s.GetOperationsScans()))
if reporter, _, err = s.Execute(func(operationScan *scan.OperationScan) {
bar.Add(1)
}); err != nil {
}, internalCmd.GetIncludeScans(), internalCmd.GetExcludeScans()); err != nil {
analyticsx.TrackError(ctx, tracer, err)
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/scan/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func NewOpenAPIScanCmd() (scanCmd *cobra.Command) {
bar := internalCmd.NewProgressBar(len(s.GetOperationsScans()))
if reporter, _, err = s.Execute(func(operationScan *scan.OperationScan) {
bar.Add(1)
}); err != nil {
}, internalCmd.GetIncludeScans(), internalCmd.GetExcludeScans()); err != nil {
analyticsx.TrackError(ctx, tracer, err)
log.Fatal(err)
}
Expand Down
18 changes: 16 additions & 2 deletions internal/cmd/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ var (
rateLimit string
proxy string

includeScans = []string{"*"}
excludeScans = []string{}

placeholderString string
placeholderBool bool
)
Expand All @@ -17,8 +20,11 @@ var defaultRateLimit = "10/s"
func AddCommonArgs(cmd *cobra.Command) {
cmd.Flags().StringVarP(&rateLimit, "rate-limit", "r", defaultRateLimit, "Rate limit for requests (e.g. 10/s, 1/m)")
cmd.Flags().StringVarP(&proxy, "proxy", "p", "", "Proxy URL for requests")
cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "Headers to include in requests")
cmd.Flags().StringArrayVarP(&cookies, "cookie", "c", nil, "Cookies to include in requests")
cmd.Flags().StringArrayVarP(&headers, "header", "H", headers, "Headers to include in requests")
cmd.Flags().StringArrayVarP(&cookies, "cookie", "c", cookies, "Cookies to include in requests")

cmd.Flags().StringArrayVarP(&includeScans, "scans", "", includeScans, "Include specific scans")
cmd.Flags().StringArrayVarP(&excludeScans, "exclude-scans", "e", excludeScans, "Exclude specific scans")
}

func AddPlaceholderArgs(cmd *cobra.Command) {
Expand Down Expand Up @@ -46,3 +52,11 @@ func GetRateLimit() string {
func GetProxy() string {
return proxy
}

func GetIncludeScans() []string {
return includeScans
}

func GetExcludeScans() []string {
return excludeScans
}
26 changes: 26 additions & 0 deletions scan/operation_scan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package scan

import (
"github.com/cerberauth/vulnapi/internal/auth"
"github.com/cerberauth/vulnapi/internal/request"
"github.com/cerberauth/vulnapi/report"
)

type OperationScanHandlerFunc func(operation *request.Operation, ss auth.SecurityScheme) (*report.ScanReport, error)

type OperationScanHandler struct {
ID string
Handler OperationScanHandlerFunc
}

type OperationScan struct {
Operation *request.Operation
ScanHandler *OperationScanHandler
}

func NewOperationScanHandler(id string, handler OperationScanHandlerFunc) *OperationScanHandler {
return &OperationScanHandler{
ID: id,
Handler: handler,
}
}
24 changes: 24 additions & 0 deletions scan/operation_scan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package scan_test

import (
"testing"

"github.com/cerberauth/vulnapi/internal/auth"
"github.com/cerberauth/vulnapi/internal/request"
"github.com/cerberauth/vulnapi/report"
"github.com/cerberauth/vulnapi/scan"
"github.com/stretchr/testify/assert"
)

func TestNewOperationScanHandler(t *testing.T) {
handlerFunc := func(operation *request.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) {
return &report.ScanReport{ID: "test-report"}, nil
}
handlerID := "test-handler"

handler := scan.NewOperationScanHandler(handlerID, handlerFunc)

assert.NotNil(t, handler)
assert.Equal(t, handlerID, handler.ID)
assert.NotNil(t, handler.Handler)
}
53 changes: 37 additions & 16 deletions scan/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,12 @@ package scan

import (
"fmt"
"regexp"

"github.com/cerberauth/vulnapi/internal/auth"
"github.com/cerberauth/vulnapi/internal/request"
"github.com/cerberauth/vulnapi/report"
)

type OperationScan struct {
Operation *request.Operation
Handler ScanHandler
}

type ScanHandler func(operation *request.Operation, ss auth.SecurityScheme) (*report.ScanReport, error)

type Scan struct {
Operations request.Operations
Reporter *report.Reporter
Expand All @@ -41,34 +34,48 @@ func (s *Scan) GetOperationsScans() []OperationScan {
return s.OperationsScans
}

func (s *Scan) AddOperationScanHandler(handler ScanHandler) *Scan {
func (s *Scan) AddOperationScanHandler(handler *OperationScanHandler) *Scan {
for _, operation := range s.Operations {
s.OperationsScans = append(s.OperationsScans, OperationScan{
Operation: operation,
Handler: handler,
Operation: operation,
ScanHandler: handler,
})
}

return s
}

func (s *Scan) AddScanHandler(handler ScanHandler) *Scan {
func (s *Scan) AddScanHandler(handler *OperationScanHandler) *Scan {
s.OperationsScans = append(s.OperationsScans, OperationScan{
Operation: s.Operations[0],
Handler: handler,
Operation: s.Operations[0],
ScanHandler: handler,
})

return s
}

func (s *Scan) Execute(scanCallback func(operationScan *OperationScan)) (*report.Reporter, []error, error) {
func (s *Scan) Execute(scanCallback func(operationScan *OperationScan), includeScans []string, excludeScans []string) (*report.Reporter, []error, error) {
if scanCallback == nil {
scanCallback = func(operationScan *OperationScan) {}
}

var errors []error
for _, scan := range s.OperationsScans {
report, err := scan.Handler(scan.Operation, scan.Operation.SecuritySchemes[0]) // TODO: handle multiple security schemes
if scan.ScanHandler == nil {
continue
}

// Check if the scan should be excluded
if len(excludeScans) > 0 && contains(excludeScans, scan.ScanHandler.ID) {
continue
}

// Check if the scan should be included
if len(includeScans) > 0 && !contains(includeScans, scan.ScanHandler.ID) {
continue
}

report, err := scan.ScanHandler.Handler(scan.Operation, scan.Operation.SecuritySchemes[0]) // TODO: handle multiple security schemes
if err != nil {
errors = append(errors, err)
}
Expand All @@ -82,3 +89,17 @@ func (s *Scan) Execute(scanCallback func(operationScan *OperationScan)) (*report

return s.Reporter, errors, nil
}

func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}

match, _ := regexp.MatchString(s, item)
if match {
return true
}
}
return false
}
Loading

0 comments on commit 9e2cea0

Please sign in to comment.