Skip to content

Commit

Permalink
Merge pull request #242 from cerberauth/fix-nil-pointer-issues
Browse files Browse the repository at this point in the history
fix: nil pointer exceptions
  • Loading branch information
emmanuelgautier authored Jan 17, 2025
2 parents 256db46 + ca1a921 commit 334861d
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 24 deletions.
7 changes: 7 additions & 0 deletions internal/request/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package request

import "errors"

func NilResponseError() error {
return errors.New("response is nil")
}
18 changes: 14 additions & 4 deletions internal/request/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,30 @@ import (
)

type Response struct {
Body bytes.Buffer
Body *bytes.Buffer
HttpResponse *http.Response
}

func NewResponse(response *http.Response) (*Response, error) {
if response == nil {
return nil, NilResponseError()
}

if response.Body == nil {
return &Response{
Body: nil,
HttpResponse: response,
}, nil
}

defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}

return &Response{
Body: *bytes.NewBuffer(body),

Body: bytes.NewBuffer(body),
HttpResponse: response,
}, nil
}
Expand All @@ -30,7 +40,7 @@ func (response *Response) GetStatusCode() int {
}

func (response *Response) GetBody() *bytes.Buffer {
return &response.Body
return response.Body
}

func (response *Response) GetHeader() http.Header {
Expand Down
21 changes: 21 additions & 0 deletions internal/request/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,24 @@ func TestNewResponse(t *testing.T) {
assert.Equal(t, httpResponse.Header, res.GetHeader())
assert.Equal(t, httpResponse.Cookies(), res.GetCookies())
}

func TestNewResponseNil(t *testing.T) {
_, err := request.NewResponse(nil)

assert.Error(t, err)
}

func TestNewResponseNilBody(t *testing.T) {
httpResponse := &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
}

res, err := request.NewResponse(httpResponse)

assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.GetStatusCode())
assert.Equal(t, httpResponse.Header, res.GetHeader())
assert.Equal(t, httpResponse.Cookies(), res.GetCookies())
assert.Nil(t, res.Body)
}
8 changes: 1 addition & 7 deletions internal/scan/scan_url.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package scan

import (
"errors"

"github.com/cerberauth/vulnapi/internal/auth"
"github.com/cerberauth/vulnapi/internal/operation"
"github.com/cerberauth/vulnapi/internal/request"
Expand All @@ -27,13 +25,9 @@ func ScanURL(operation *operation.Operation, securityScheme *auth.SecurityScheme
}

res, err := req.Do()
if err != nil {
return nil, errors.New("request has an unexpected error")
}

return &IssueScanAttempt{
Request: req,
Response: res,
Err: err,
}, nil
}, err
}
12 changes: 6 additions & 6 deletions internal/scan/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"github.com/cerberauth/vulnapi/internal/request"
)

func IsUnauthorizedStatusCodeOrSimilar(resp *request.Response) bool {
return resp.GetStatusCode() == http.StatusUnauthorized ||
resp.GetStatusCode() == http.StatusForbidden ||
resp.GetStatusCode() == http.StatusBadRequest ||
resp.GetStatusCode() == http.StatusNotFound ||
resp.GetStatusCode() == http.StatusInternalServerError
func IsUnauthorizedStatusCodeOrSimilar(res *request.Response) bool {
return res.GetStatusCode() == http.StatusUnauthorized ||
res.GetStatusCode() == http.StatusForbidden ||
res.GetStatusCode() == http.StatusBadRequest ||
res.GetStatusCode() == http.StatusNotFound ||
res.GetStatusCode() == http.StatusInternalServerError
}
12 changes: 8 additions & 4 deletions scan/discover/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme *auth.S
base := ExtractBaseURL(&op.URL)
chunkSize := 20
results := make(chan *scan.IssueScanAttempt, len(scanUrls))
errors := make(chan error, len(scanUrls))
errs := make(chan error, len(scanUrls))

for i := 0; i < len(scanUrls); i += chunkSize {
end := i + chunkSize
Expand All @@ -42,13 +42,13 @@ func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme *auth.S
newOperation, err := operation.NewOperation(http.MethodGet, base.ResolveReference(&url.URL{Path: path}).String(), nil, op.Client)
newOperation.SetSecuritySchemes(securitySchemes)
if err != nil {
errors <- err
errs <- err
return
}

attempt, err := scan.ScanURL(newOperation, securityScheme)
if err != nil {
errors <- err
errs <- err
return
}

Expand All @@ -62,10 +62,14 @@ func ScanURLs(scanUrls []string, op *operation.Operation, securityScheme *auth.S
select {
case attempt := <-results:
r.AddScanAttempt(attempt)
if attempt.Err != nil {
errs <- attempt.Err
continue
}
if attempt.Response.GetStatusCode() == http.StatusOK { // TODO: check if the response contains the expected content
data = append(data, struct{ URL string }{URL: attempt.Request.GetURL()})
}
case err := <-errors:
case err := <-errs:
log.Printf("Error scanning URL: %v", err)
continue
}
Expand Down
6 changes: 3 additions & 3 deletions seclist/seclist.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ func (s *SecList) DownloadFromURL(url string) error {
return err
}

resp, err := req.Do()
res, err := req.Do()
if err != nil {
return err
}

if resp.GetStatusCode() != http.StatusOK {
if res.GetStatusCode() != http.StatusOK {
return errors.New("sec list download failed")
}

Expand All @@ -115,7 +115,7 @@ func (s *SecList) DownloadFromURL(url string) error {
}
defer tempFile.Close()

_, err = io.Copy(tempFile, resp.GetBody())
_, err = io.Copy(tempFile, res.GetBody())
if err != nil {
return err
}
Expand Down

0 comments on commit 334861d

Please sign in to comment.