Skip to content

Commit

Permalink
add DisableDNSResolution for TCPDialer. Sometimes, users do not need …
Browse files Browse the repository at this point in the history
…to use DNS resolution because they have already determined that the requested address is a list of IP addresses. (#1702)

Co-authored-by: wangzhengkai.wzk <[email protected]>
  • Loading branch information
xuxiao415 and wangzhengkai.wzk authored Feb 10, 2024
1 parent 48dd2d0 commit dfb7e62
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions tcpdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ type TCPDialer struct {
// }
Resolver Resolver

// DisableDNSResolution may be used to disable DNS resolution
DisableDNSResolution bool
// DNSCacheDuration may be used to override the default DNS cache duration (DefaultDNSCacheDuration)
DNSCacheDuration time.Duration

Expand Down Expand Up @@ -277,23 +279,26 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne
d.DNSCacheDuration = DefaultDNSCacheDuration
}

go d.tcpAddrsClean()
if !d.DisableDNSResolution {
go d.tcpAddrsClean()
}
})

deadline := time.Now().Add(timeout)
addrs, idx, err := d.getTCPAddrs(addr, dualStack, deadline)
if err != nil {
return nil, err
}
network := "tcp4"
if dualStack {
network = "tcp"
}

if d.DisableDNSResolution {
return d.tryDial(network, addr, deadline, d.concurrencyCh)
}
addrs, idx, err := d.getTCPAddrs(addr, dualStack, deadline)
if err != nil {
return nil, err
}
var conn net.Conn
n := uint32(len(addrs))
for n > 0 {
conn, err = d.tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh)
conn, err = d.tryDial(network, addrs[idx%n].String(), deadline, d.concurrencyCh)
if err == nil {
return conn, nil
}
Expand All @@ -307,7 +312,7 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne
}

func (d *TCPDialer) tryDial(
network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{},
network string, addr string, deadline time.Time, concurrencyCh chan struct{},
) (net.Conn, error) {
timeout := time.Until(deadline)
if timeout <= 0 {
Expand Down Expand Up @@ -340,7 +345,7 @@ func (d *TCPDialer) tryDial(

ctx, cancelCtx := context.WithDeadline(context.Background(), deadline)
defer cancelCtx()
conn, err := dialer.DialContext(ctx, network, addr.String())
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil && ctx.Err() == context.DeadlineExceeded {
return nil, ErrDialTimeout
}
Expand Down

0 comments on commit dfb7e62

Please sign in to comment.