diff --git a/prober/dns.go b/prober/dns.go index 184fa949f..281401529 100644 --- a/prober/dns.go +++ b/prober/dns.go @@ -259,8 +259,9 @@ func ProbeDNS(ctx context.Context, target string, module config.Module, registry msg.Question[0] = dns.Question{dns.Fqdn(module.DNS.QueryName), qt, qc} logger.Info("Making DNS query", "target", targetIP, "dial_protocol", dialProtocol, "query", module.DNS.QueryName, "type", qt, "class", qc) - timeoutDeadline, _ := ctx.Deadline() - client.Timeout = time.Until(timeoutDeadline) + if timeoutDeadline, ok := ctx.Deadline(); ok { + client.Timeout = time.Until(timeoutDeadline) + } requestStart := time.Now() response, rtt, err := client.Exchange(msg, targetIP) // The rtt value returned from client.Exchange includes only the time to diff --git a/prober/http.go b/prober/http.go index 5ffc029f7..83f82e85a 100644 --- a/prober/http.go +++ b/prober/http.go @@ -503,29 +503,54 @@ func ProbeHTTP(ctx context.Context, target string, module config.Module, registr } } - // Since the configuration specifies a compression algorithm, blindly treat the response body as a - // compressed payload; if we cannot decompress it it's a failure because the configuration says we - // should expect the response to be compressed in that way. + originalBody := resp.Body + if !requestErrored { + defer func(c io.Closer) { + err := c.Close() + if err != nil { + logger.Error("Error while closing response from server", "err", err) + } + }(originalBody) + } + + var dec io.ReadCloser if httpConfig.Compression != "" { - dec, err := getDecompressionReader(httpConfig.Compression, resp.Body) - if err != nil { - logger.Info("Failed to get decompressor for HTTP response body", "err", err) + dec, err = getDecompressionReader(httpConfig.Compression, resp.Body) + } else { + var acceptEncoding string + if request.Header != nil { + acceptEncoding = request.Header.Get("Accept-Encoding") + } + respEncoding := resp.Header.Get("Content-Encoding") + logger.Debug("Response encoding", "encoding", respEncoding, "Accept-Encoding", acceptEncoding) + if responseEncodingIfInAcceptHeader(respEncoding, acceptEncoding) { + dec, err = getDecompressionReader(respEncoding, resp.Body) + } else { + logger.Warn( + "Response encoding not in Accept-Encoding header", + "encoding", respEncoding, "Accept-Encoding", acceptEncoding) + err = fmt.Errorf("response encoding not in Accept-Encoding header: %s not in %s", respEncoding, acceptEncoding) success = false - } else if dec != nil { - // Since we are replacing the original resp.Body with the decoder, we need to make sure - // we close the original body. We cannot close it right away because the decompressor - // might not have read it yet. + } + } + if err != nil { + logger.Error("Failed to get decompressor for HTTP response body", "err", err) + success = false + } + + if dec != nil { + if originalBody != dec { + // If the decompressor is different from the original body, + // we need to close the decompressor. defer func(c io.Closer) { err := c.Close() if err != nil { - // At this point we cannot really do anything with this error, but log - // it in case it contains useful information as to what's the problem. - logger.Info("Error while closing response from server", "err", err) + logger.Error("Error while closing decompressor", "err", err) } - }(resp.Body) - - resp.Body = dec + }(dec) } + + resp.Body = dec } // If there's a configured body_size_limit, wrap the body in the response in a http.MaxBytesReader. @@ -682,3 +707,18 @@ func getDecompressionReader(algorithm string, origBody io.ReadCloser) (io.ReadCl return nil, errors.New("unsupported compression algorithm") } } + +func responseEncodingIfInAcceptHeader(encoding string, acceptHeader string) bool { + if encoding == "" || encoding == "identity" { + return true + } + + acceptHeaderParts := strings.Split(acceptHeader, ",") + for _, part := range acceptHeaderParts { + if strings.TrimSpace(part) == encoding { + return true + } + } + + return false +} diff --git a/prober/icmp.go b/prober/icmp.go index 883fbcf73..9f45f7249 100644 --- a/prober/icmp.go +++ b/prober/icmp.go @@ -293,11 +293,13 @@ func ProbeICMP(ctx context.Context, target string, module config.Module, registr } rb := make([]byte, 65536) - deadline, _ := ctx.Deadline() - if icmpConn != nil { - err = icmpConn.SetReadDeadline(deadline) - } else { - err = v4RawConn.SetReadDeadline(deadline) + + if deadline, ok := ctx.Deadline(); ok { + if icmpConn != nil { + err = icmpConn.SetReadDeadline(deadline) + } else { + err = v4RawConn.SetReadDeadline(deadline) + } } if err != nil { logger.Error("Error setting socket deadline", "err", err) diff --git a/prober/tcp.go b/prober/tcp.go index 98262813a..99eaf577f 100644 --- a/prober/tcp.go +++ b/prober/tcp.go @@ -80,8 +80,10 @@ func dialTCP(ctx context.Context, target string, module config.Module, registry // via tlsConfig to enable hostname verification. tlsConfig.ServerName = targetAddress } - timeoutDeadline, _ := ctx.Deadline() - dialer.Deadline = timeoutDeadline + + if timeoutDeadline, ok := ctx.Deadline(); ok { + dialer.Deadline = timeoutDeadline + } logger.Info("Dialing TCP with TLS") return tls.DialWithDialer(dialer, dialProtocol, dialTarget, tlsConfig) @@ -124,7 +126,6 @@ func ProbeTCP(ctx context.Context, target string, module config.Module, registry Help: "Indicates if probe failed due to regex", }) registry.MustRegister(probeFailedDueToRegex) - deadline, _ := ctx.Deadline() conn, err := dialTCP(ctx, target, module, registry, logger) if err != nil { @@ -134,13 +135,13 @@ func ProbeTCP(ctx context.Context, target string, module config.Module, registry defer conn.Close() logger.Info("Successfully dialed") - // Set a deadline to prevent the following code from blocking forever. - // If a deadline cannot be set, better fail the probe by returning an error - // now rather than blocking forever. - if err := conn.SetDeadline(deadline); err != nil { - logger.Error("Error setting deadline", "err", err) - return false + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + logger.Error("Error setting deadline", "err", err) + return false + } } + if module.TCP.TLS { state := conn.(*tls.Conn).ConnectionState() registry.MustRegister(probeSSLEarliestCertExpiry, probeTLSVersion, probeSSLLastChainExpiryTimestampSeconds, probeSSLLastInformation)