Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Add reverse dns zone #3217

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions client/internal/dns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package internal

import (
"net"
"slices"
"strings"

"github.com/miekg/dns"
log "github.com/sirupsen/logrus"

nbdns "github.com/netbirdio/netbird/dns"
)

func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData)
if ip == nil || ip.To4() == nil {
return nbdns.SimpleRecord{}, false
}

if !ipNet.Contains(ip) {
return nbdns.SimpleRecord{}, false
}

ipOctets := strings.Split(ip.String(), ".")
slices.Reverse(ipOctets)
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")

return nbdns.SimpleRecord{
Name: rdnsName,
Type: int(dns.TypePTR),
Class: aRecord.Class,
TTL: aRecord.TTL,
RData: dns.Fqdn(aRecord.Name),
}, true
}

func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
networkIP := ipNet.IP.Mask(ipNet.Mask)

maskOnes, _ := ipNet.Mask.Size()
// round up to nearest byte
octetsToUse := (maskOnes + 7) / 8

octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
log.Warnf("invalid network mask size for reverse DNS: %d", maskOnes)
return
}

reverseOctets := make([]string, octetsToUse)
for i := 0; i < octetsToUse; i++ {
reverseOctets[octetsToUse-1-i] = octets[i]
}

zoneName := dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa")

for _, zone := range config.CustomZones {
if zone.Domain == zoneName {
log.Debugf("reverse DNS zone %s already exists", zoneName)
return
}
}

var records []nbdns.SimpleRecord

for _, zone := range config.CustomZones {
for _, record := range zone.Records {
if record.Type != int(dns.TypeA) {
continue
}

if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
records = append(records, ptrRecord)
}
}
}

reverseZone := nbdns.CustomZone{
Domain: zoneName,
Records: records,
}

config.CustomZones = append(config.CustomZones, reverseZone)
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
}
10 changes: 6 additions & 4 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,11 +380,11 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {

localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
return fmt.Errorf("local handler updater: %w", err)
}
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
return fmt.Errorf("upstream handler updater: %w", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic

Expand Down Expand Up @@ -425,7 +425,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)

for _, customZone := range customZones {
if len(customZone.Records) == 0 {
return nil, nil, fmt.Errorf("received an empty list of records")
log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain)
continue
}

muxUpdates = append(muxUpdates, muxUpdate{
Expand All @@ -437,7 +438,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
for _, record := range customZone.Records {
var class uint16 = dns.ClassINET
if record.Class != nbdns.DefaultClass {
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
key := buildRecordKey(record.Name, class, uint16(record.Type))
localRecords[key] = record
Expand Down
4 changes: 2 additions & 2 deletions client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Fail",
name: "Invalid Custom Zone Records list Should Skip",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
Expand All @@ -239,7 +239,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
shouldFail: true,
expectedUpstreamMap: registeredHandlerMap{".": dummyHandler},
},
{
name: "Empty Config Should Succeed and Clean Maps",
Expand Down
11 changes: 8 additions & 3 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{}
}

if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}

Expand Down Expand Up @@ -1041,7 +1041,7 @@ func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
return dnsRoutes
}

func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
Expand Down Expand Up @@ -1081,6 +1081,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
}
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
}

if len(dnsUpdate.CustomZones) > 0 {
addReverseZone(&dnsUpdate, network)
}

return dnsUpdate
}

Expand Down Expand Up @@ -1387,7 +1392,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
return nil, nil, err
}
routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
return routes, &dnsCfg, nil
}

Expand Down
15 changes: 15 additions & 0 deletions client/internal/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemovePeerFunc: func(peerKey string) error {
return nil
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Expand Down Expand Up @@ -692,6 +701,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
},
},
},
{
Domain: "0.66.100.in-addr.arpa.",
},
},
NameServerGroups: []*mgmtProto.NameServerGroup{
{
Expand Down Expand Up @@ -721,6 +733,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
},
},
},
{
Domain: "0.66.100.in-addr.arpa.",
},
},
expectedNSGroupsLen: 1,
expectedNSGroups: []*nbdns.NameServerGroup{
Expand Down
Loading