From ca1a9214a66264787ce5a67e87e8749e158bb1f1 Mon Sep 17 00:00:00 2001 From: Emmanuel Gautier Date: Sat, 18 Jan 2025 00:30:18 +0100 Subject: [PATCH] fix: nil pointer exceptions --- internal/request/error.go | 7 +++++++ internal/request/response.go | 18 ++++++++++++++---- internal/request/response_test.go | 21 +++++++++++++++++++++ internal/scan/scan_url.go | 8 +------- internal/scan/utils.go | 12 ++++++------ scan/discover/utils.go | 12 ++++++++---- seclist/seclist.go | 6 +++--- 7 files changed, 60 insertions(+), 24 deletions(-) create mode 100644 internal/request/error.go diff --git a/internal/request/error.go b/internal/request/error.go new file mode 100644 index 0000000..1522f45 --- /dev/null +++ b/internal/request/error.go @@ -0,0 +1,7 @@ +package request + +import "errors" + +func NilResponseError() error { + return errors.New("response is nil") +} diff --git a/internal/request/response.go b/internal/request/response.go index a9a5291..3a0f3f7 100644 --- a/internal/request/response.go +++ b/internal/request/response.go @@ -7,11 +7,22 @@ import ( ) type Response struct { - Body bytes.Buffer + Body *bytes.Buffer HttpResponse *http.Response } func NewResponse(response *http.Response) (*Response, error) { + if response == nil { + return nil, NilResponseError() + } + + if response.Body == nil { + return &Response{ + Body: nil, + HttpResponse: response, + }, nil + } + defer response.Body.Close() body, err := io.ReadAll(response.Body) if err != nil { @@ -19,8 +30,7 @@ func NewResponse(response *http.Response) (*Response, error) { } return &Response{ - Body: *bytes.NewBuffer(body), - + Body: bytes.NewBuffer(body), HttpResponse: response, }, nil } @@ -30,7 +40,7 @@ func (response *Response) GetStatusCode() int { } func (response *Response) GetBody() *bytes.Buffer { - return &response.Body + return response.Body } func (response *Response) GetHeader() http.Header { diff --git a/internal/request/response_test.go b/internal/request/response_test.go index 369b607..9c0b547 100644 --- a/internal/request/response_test.go +++ b/internal/request/response_test.go @@ -27,3 +27,24 @@ func TestNewResponse(t *testing.T) { assert.Equal(t, httpResponse.Header, res.GetHeader()) assert.Equal(t, httpResponse.Cookies(), res.GetCookies()) } + +func TestNewResponseNil(t *testing.T) { + _, err := request.NewResponse(nil) + + assert.Error(t, err) +} + +func TestNewResponseNilBody(t *testing.T) { + httpResponse := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + } + + res, err := request.NewResponse(httpResponse) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.GetStatusCode()) + assert.Equal(t, httpResponse.Header, res.GetHeader()) + assert.Equal(t, httpResponse.Cookies(), res.GetCookies()) + assert.Nil(t, res.Body) +} diff --git a/internal/scan/scan_url.go b/internal/scan/scan_url.go index fa764c7..bf07bc8 100644 --- a/internal/scan/scan_url.go +++ b/internal/scan/scan_url.go @@ -1,8 +1,6 @@ package scan import ( - "errors" - "github.com/cerberauth/vulnapi/internal/auth" "github.com/cerberauth/vulnapi/internal/operation" "github.com/cerberauth/vulnapi/internal/request" @@ -27,13 +25,9 @@ func ScanURL(operation *operation.Operation, securityScheme *auth.SecurityScheme } res, err := req.Do() - if err != nil { - return nil, errors.New("request has an unexpected error") - } - return &IssueScanAttempt{ Request: req, Response: res, Err: err, - }, nil + }, err } diff --git a/internal/scan/utils.go b/internal/scan/utils.go index ce429d0..ed8c1d8 100644 --- a/internal/scan/utils.go +++ b/internal/scan/utils.go @@ -6,10 +6,10 @@ import ( "github.com/cerberauth/vulnapi/internal/request" ) -func IsUnauthorizedStatusCodeOrSimilar(resp *request.Response) bool { - return resp.GetStatusCode() == http.StatusUnauthorized || - resp.GetStatusCode() == http.StatusForbidden || - resp.GetStatusCode() == http.StatusBadRequest || - resp.GetStatusCode() == http.StatusNotFound || - resp.GetStatusCode() == http.StatusInternalServerError +func IsUnauthorizedStatusCodeOrSimilar(res *request.Response) bool { + return res.GetStatusCode() == http.StatusUnauthorized || + res.GetStatusCode() == http.StatusForbidden || + res.GetStatusCode() == http.StatusBadRequest || + res.GetStatusCode() == http.StatusNotFound || + res.GetStatusCode() == http.StatusInternalServerError } diff --git a/scan/discover/utils.go b/scan/discover/utils.go index d10f6d5..10f19a0 100644 --- a/scan/discover/utils.go +++ b/scan/discover/utils.go @@ -28,7 +28,7 @@ func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme *auth.S base := ExtractBaseURL(&op.URL) chunkSize := 20 results := make(chan *scan.IssueScanAttempt, len(scanUrls)) - errors := make(chan error, len(scanUrls)) + errs := make(chan error, len(scanUrls)) for i := 0; i < len(scanUrls); i += chunkSize { end := i + chunkSize @@ -42,13 +42,13 @@ func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme *auth.S newOperation, err := operation.NewOperation(http.MethodGet, base.ResolveReference(&url.URL{Path: path}).String(), nil, op.Client) newOperation.SetSecuritySchemes(securitySchemes) if err != nil { - errors <- err + errs <- err return } attempt, err := scan.ScanURL(newOperation, securityScheme) if err != nil { - errors <- err + errs <- err return } @@ -62,10 +62,14 @@ func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme *auth.S select { case attempt := <-results: r.AddScanAttempt(attempt) + if attempt.Err != nil { + errs <- attempt.Err + continue + } if attempt.Response.GetStatusCode() == http.StatusOK { // TODO: check if the response contains the expected content data = append(data, struct{ URL string }{URL: attempt.Request.GetURL()}) } - case err := <-errors: + case err := <-errs: log.Printf("Error scanning URL: %v", err) continue } diff --git a/seclist/seclist.go b/seclist/seclist.go index 78d4407..ee8ae96 100644 --- a/seclist/seclist.go +++ b/seclist/seclist.go @@ -100,12 +100,12 @@ func (s *SecList) DownloadFromURL(url string) error { return err } - resp, err := req.Do() + res, err := req.Do() if err != nil { return err } - if resp.GetStatusCode() != http.StatusOK { + if res.GetStatusCode() != http.StatusOK { return errors.New("sec list download failed") } @@ -115,7 +115,7 @@ func (s *SecList) DownloadFromURL(url string) error { } defer tempFile.Close() - _, err = io.Copy(tempFile, resp.GetBody()) + _, err = io.Copy(tempFile, res.GetBody()) if err != nil { return err }