Skip to content

Commit

Permalink
Improve --err-first-hit handling (#596)
Browse files Browse the repository at this point in the history
* Improve --err-first-hit handling

* Make error handling uniform across commands
  • Loading branch information
tstromberg authored Nov 6, 2024
1 parent 7847a94 commit b6125e7
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
25 changes: 21 additions & 4 deletions cmd/mal/mal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package main

import (
"context"
"errors"
"fmt"
"io/fs"
"log/slog"
Expand Down Expand Up @@ -80,6 +81,16 @@ var riskMap = map[string]int{
"critical": 4,
}

func showError(err error) {
emoji := "💣"
if errors.Is(err, action.ErrMatchedCondition) {
emoji = "👋"
err = errors.Unwrap(err)
}

fmt.Fprintf(os.Stderr, "%s %s\n", emoji, err.Error())
}

//nolint:cyclop // ignore complexity of 40
func main() {
returnCode := ExitOK
Expand Down Expand Up @@ -398,7 +409,7 @@ func main() {
ps, err := action.ActiveProcesses(ctx)
if err != nil {
returnCode = ExitActionFailed
return fmt.Errorf("process paths: %w", err)
return err
}
for _, p := range ps {
// in the future, we'll also want to attach process info directly
Expand All @@ -409,7 +420,7 @@ func main() {
res, err = action.Scan(ctx, mc)
if err != nil {
returnCode = ExitActionFailed
return fmt.Errorf("scan: %w", err)
return err
}

err = renderer.Full(ctx, res)
Expand Down Expand Up @@ -530,7 +541,13 @@ func main() {
}

if err := app.Run(os.Args); err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
returnCode = ExitActionFailed
if returnCode != 0 {
returnCode = ExitActionFailed
}
if errors.Is(err, action.ErrMatchedCondition) {
returnCode = ExitOK
}

showError(err)
}
}
44 changes: 24 additions & 20 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package action

import (
"context"
"errors"
"fmt"
"io/fs"
"log/slog"
Expand All @@ -31,7 +32,8 @@ var (
// compiledRuleCache are a cache of previously compiled rules.
compiledRuleCache *yara.Rules
// compileOnce ensures that we compile rules only once even across threads.
compileOnce sync.Once
compileOnce sync.Once
ErrMatchedCondition = errors.New("matched requested condition")
)

// findFilesRecursively returns a list of files found recursively within a path.
Expand Down Expand Up @@ -233,7 +235,6 @@ func cachedRules(ctx context.Context, fss []fs.FS) (*yara.Rules, error) {
//nolint:gocognit,cyclop // ignoring complexity of 101,38
func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error) {
logger := clog.FromContext(ctx)
logger.Debug("recursive scan", slog.Any("config", c))
r := &malcontent.Report{
Files: orderedmap.New[string, *malcontent.FileReport](),
}
Expand All @@ -243,11 +244,12 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report

var scanPathFindings sync.Map

var waitErr error

for _, scanPath := range c.ScanPaths {
if c.Renderer != nil {
c.Renderer.Scanning(ctx, scanPath)
}
logger.Debug("recursive scan", slog.Any("scanPath", scanPath))
imageURI := ""
ociExtractPath := ""
var err error
Expand Down Expand Up @@ -323,18 +325,19 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report
fr, err := processFile(ctx, c, c.RuleFS, path, scanPath, trimPath, logger)
if err != nil {
scanPathFindings.Store(path, &malcontent.FileReport{})
return err
return fmt.Errorf("process: %w", err)
}
if fr != nil {
scanPathFindings.Store(path, fr)
if !c.OCI {
var frMap sync.Map
frMap.Store(path, fr)
if err := errIfHitOrMiss(&frMap, "file", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil {
logger.Debugf("match short circuit: %s", err)
scanPathFindings.Store(path, fr)
return err
}
if fr == nil {
return nil
}

scanPathFindings.Store(path, fr)
if !c.OCI {
var frMap sync.Map
frMap.Store(path, fr)
if err := errIfHitOrMiss(&frMap, "file", path, c.ErrFirstHit, c.ErrFirstMiss); err != nil {
scanPathFindings.Store(path, fr)
return fmt.Errorf("%q: %w", path, ErrMatchedCondition)
}
}
return nil
Expand All @@ -351,8 +354,7 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report
}

if err := g.Wait(); err != nil {
logger.Errorf("error with processing %v\n", err)
return nil, err
waitErr = err
}

var pathKeys []string
Expand Down Expand Up @@ -396,6 +398,11 @@ func recursiveScan(ctx context.Context, c malcontent.Config) (*malcontent.Report
}
}
}

// short-circuit out
if waitErr != nil {
return r, waitErr
}
} // loop: next scan path
return r, nil
}
Expand Down Expand Up @@ -460,9 +467,6 @@ func processFile(ctx context.Context, c malcontent.Config, ruleFS []fs.FS, path
func Scan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error) {
r, err := recursiveScan(ctx, c)
if err != nil {
if strings.Contains(err.Error(), "no matching capabilities") {
return r, nil
}
return r, err
}
for files := r.Files.Oldest(); files != nil; files = files.Next() {
Expand All @@ -473,7 +477,7 @@ func Scan(ctx context.Context, c malcontent.Config) (*malcontent.Report, error)
if c.Stats {
err = render.Statistics(r)
if err != nil {
return r, err
return r, fmt.Errorf("stats: %w", err)
}
}
return r, nil
Expand Down

0 comments on commit b6125e7

Please sign in to comment.