diff --git a/cmd/mal/mal.go b/cmd/mal/mal.go index 2df445cc..87336b12 100644 --- a/cmd/mal/mal.go +++ b/cmd/mal/mal.go @@ -63,6 +63,7 @@ var ( outputFlag string profileFlag bool quantityIncreasesRiskFlag bool + scannersFlag int statsFlag bool thirdPartyFlag bool verboseFlag bool @@ -92,7 +93,7 @@ func showError(err error) { fmt.Fprintf(os.Stderr, "%s %s\n", emoji, err.Error()) } -//nolint:cyclop // ignore complexity of 40 +//nolint:cyclop,gocognit // ignore complexity of 40,100 func main() { returnCode := ExitOK defer func() { os.Exit(returnCode) }() @@ -249,9 +250,14 @@ func main() { concurrency = 1 } + maxScanners := scannersFlag + if maxScanners > concurrency { + maxScanners = concurrency + } + var pool *malcontent.ScannerPool if mc.ScannerPool == nil { - pool, err = malcontent.NewScannerPool(yrs, concurrency) + pool, err = malcontent.NewScannerPool(yrs, maxScanners) if err != nil { returnCode = ExitInvalidRules } @@ -264,6 +270,7 @@ func main() { IgnoreSelf: ignoreSelfFlag, IgnoreTags: ignoreTags, IncludeDataFiles: includeDataFiles, + MaxScanners: maxScanners, MinFileRisk: minFileRisk, MinRisk: minRisk, OCI: ociFlag, @@ -372,6 +379,12 @@ func main() { Usage: "Increase file risk score based on behavior quantity", Destination: &quantityIncreasesRiskFlag, }, + &cli.IntFlag{ + Name: "scanners", + Value: runtime.NumCPU(), + Usage: "Number of scanners to create", + Destination: &scannersFlag, + }, &cli.BoolFlag{ Name: "stats", Aliases: []string{"s"}, diff --git a/pkg/action/scan.go b/pkg/action/scan.go index a9bbda6b..b6affa74 100644 --- a/pkg/action/scan.go +++ b/pkg/action/scan.go @@ -57,7 +57,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF var pool *malcontent.ScannerPool if c.ScannerPool == nil { - pool, err = malcontent.NewScannerPool(yrs, c.Concurrency) + pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners) if err != nil { return nil, fmt.Errorf("failed to create scanner pool: %w", err) } diff --git a/pkg/malcontent/malcontent.go b/pkg/malcontent/malcontent.go index e0798ee3..ff0633f1 100644 --- a/pkg/malcontent/malcontent.go +++ b/pkg/malcontent/malcontent.go @@ -11,7 +11,6 @@ import ( "runtime" "sync" "sync/atomic" - "time" yarax "github.com/VirusTotal/yara-x/go" orderedmap "github.com/wk8/go-ordered-map/v2" @@ -34,6 +33,7 @@ type Config struct { IgnoreSelf bool IgnoreTags []string IncludeDataFiles bool + MaxScanners int MinFileRisk int MinRisk int OCI bool @@ -175,23 +175,17 @@ func NewScannerPool(rules *yarax.Rules, maxScanners int) (*ScannerPool, error) { available: make(chan *yarax.Scanner, maxScanners), maxScanners: int32(maxScanners), scanners: make([]*yarax.Scanner, 0, maxScanners), + closed: atomic.Bool{}, } - pool.closed.Store(false) - - // Create a subset of the maximum number of scanners to avoid contention - initialScanners := maxScanners/2 + 1 - for i := 0; i < initialScanners; i++ { - scanner, err := pool.createScanner() - if err != nil { - pool.Cleanup() - return nil, fmt.Errorf("failed to create initial scanner: %w", err) - } - pool.scanners = append(pool.scanners, scanner) - pool.available <- scanner - atomic.AddInt32(&pool.currentCount, 1) + scanner := yarax.NewScanner(rules) + if scanner == nil { + return nil, fmt.Errorf("failed to create scanner") } + pool.available <- scanner + atomic.AddInt32(&pool.currentCount, 1) + return pool, nil } @@ -236,39 +230,27 @@ func (p *ScannerPool) Get() (*yarax.Scanner, error) { return nil, fmt.Errorf("scanner pool is closed") } + // Retrieve an existing scanner + // If none are available, create up to the maximum number of scanners select { case scanner := <-p.available: - if scanner == nil { - return nil, fmt.Errorf("received nil scanner from pool") - } return scanner, nil - case <-time.After(100 * time.Millisecond): - } - - // Create a new scanner if we aren't already running the maximum number - p.mu.Lock() - current := atomic.LoadInt32(&p.currentCount) - if current < p.maxScanners { - scanner, err := p.createScanner() - if err != nil { + default: + p.mu.Lock() + if atomic.LoadInt32(&p.currentCount) < p.maxScanners { + scanner, err := p.createScanner() + if err != nil { + p.mu.Unlock() + return nil, fmt.Errorf("create scanner: %w", err) + } + p.scanners = append(p.scanners, scanner) + atomic.AddInt32(&p.currentCount, 1) p.mu.Unlock() - return nil, fmt.Errorf("create scanner: %w", err) + return scanner, nil } - p.scanners = append(p.scanners, scanner) - atomic.AddInt32(&p.currentCount, 1) p.mu.Unlock() - return scanner, nil - } - p.mu.Unlock() - select { - case scanner := <-p.available: - if scanner == nil { - return nil, fmt.Errorf("received nil scanner from pool") - } - return scanner, nil - case <-time.After(10 * time.Second): - return nil, fmt.Errorf("timeout waiting for available scanner") + return <-p.available, nil } } @@ -277,20 +259,7 @@ func (p *ScannerPool) Put(scanner *yarax.Scanner) { if scanner == nil || p.closed.Load() { return } - - select { - case p.available <- scanner: - default: - p.mu.Lock() - defer func() { - p.mu.Unlock() - if atomic.LoadInt32(&p.currentCount) > p.maxScanners/2 { - runtime.GC() - } - }() - scanner.Destroy() - atomic.AddInt32(&p.currentCount, -1) - } + p.available <- scanner } // Cleanup destroys all scanners in the pool. diff --git a/pkg/refresh/action.go b/pkg/refresh/action.go index 6fcca755..6859bc16 100644 --- a/pkg/refresh/action.go +++ b/pkg/refresh/action.go @@ -69,6 +69,7 @@ func actionRefresh(ctx context.Context) ([]TestData, error) { c := &malcontent.Config{ Concurrency: runtime.NumCPU(), IgnoreSelf: false, + MaxScanners: runtime.NumCPU(), MinFileRisk: 0, MinRisk: 0, OCI: false, @@ -81,7 +82,7 @@ func actionRefresh(ctx context.Context) ([]TestData, error) { var pool *malcontent.ScannerPool if c.ScannerPool == nil { - pool, err = malcontent.NewScannerPool(yrs, c.Concurrency) + pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners) if err != nil { return nil, err } diff --git a/pkg/refresh/diff.go b/pkg/refresh/diff.go index a2e91d59..13a4ef52 100644 --- a/pkg/refresh/diff.go +++ b/pkg/refresh/diff.go @@ -196,6 +196,7 @@ func diffRefresh(ctx context.Context, rc Config) ([]TestData, error) { Concurrency: runtime.NumCPU(), FileRiskChange: td.riskChange, FileRiskIncrease: td.riskIncrease, + MaxScanners: runtime.NumCPU(), MinFileRisk: minFileRisk, MinRisk: minRisk, QuantityIncreasesRisk: true, @@ -207,7 +208,7 @@ func diffRefresh(ctx context.Context, rc Config) ([]TestData, error) { var pool *malcontent.ScannerPool if c.ScannerPool == nil { - pool, err = malcontent.NewScannerPool(yrs, c.Concurrency) + pool, err = malcontent.NewScannerPool(yrs, c.MaxScanners) if err != nil { return nil, err } diff --git a/pkg/refresh/refresh.go b/pkg/refresh/refresh.go index 2aa13bcf..ce4c8716 100644 --- a/pkg/refresh/refresh.go +++ b/pkg/refresh/refresh.go @@ -74,6 +74,7 @@ func newConfig(rc Config) *malcontent.Config { return &malcontent.Config{ Concurrency: runtime.NumCPU(), IgnoreTags: []string{"harmless"}, + MaxScanners: runtime.NumCPU(), MinFileRisk: 1, MinRisk: 1, QuantityIncreasesRisk: true,