-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.go
170 lines (150 loc) · 4.14 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package main
import (
"flag"
"fmt"
"log"
"net"
"os"
"strings"
"time"
"github.com/lanrat/allxfr/zone"
"github.com/lanrat/allxfr/psl"
"github.com/miekg/dns"
"golang.org/x/sync/errgroup"
)
// Command-line flags. Each variable is a pointer returned by the flag package;
// the pointed-to value is only meaningful after flag.Parse has run in main.
var (
parallel = flag.Uint("parallel", 10, "number of parallel zone transfers to perform")
saveDir = flag.String("out", "zones", "directory to save found zones in")
verbose = flag.Bool("verbose", false, "enable verbose output")
zonefile = flag.String("zonefile", "", "use the provided zonefile instead of getting the root zonefile")
// ns may be empty, in which case getNameserver falls back to /etc/resolv.conf.
// main refuses to run with -psl unless ns is set.
ns = flag.String("ns", "", "nameserver to use for manually querying of records not in zone file")
saveAll = flag.Bool("save-all", false, "attempt AXFR from every nameserver for a given zone and save all answers")
usePSL = flag.Bool("psl", false, "attempt AXFR from zones listed in the public suffix list, requires -ns flag")
ixfr = flag.Bool("ixfr", false, "attempt an IXFR instead of AXFR")
dryRun = flag.Bool("dry-run", false, "only test if xfr is allowed by retrieving one envelope")
// main rejects values of retry below 1.
retry = flag.Int("retry", 3, "number of times to retry failed operations")
overwrite = flag.Bool("overwrite", false, "if zone already exists on disk, overwrite it with newer data")
)
// Package-level state shared across the program.
var (
// localNameserver is assigned once in main from the result of getNameserver.
localNameserver string
// totalXFR is reported by main as the number of zones transferred.
// NOTE(review): nothing in this file writes to it; presumably it is updated by
// the transfer code elsewhere in the package. Confirm how access is synchronized.
totalXFR uint32
)
const (
// globalTimeout is not referenced in this file.
// NOTE(review): presumably the timeout applied to DNS operations elsewhere in
// the package; confirm against its users.
globalTimeout = 15 * time.Second
)
// main validates the command-line flags, obtains a zone (the root zone via
// AXFR, or the file named by -zonefile), optionally adds the public suffix
// list domains to it, and then runs -parallel workers that attempt a zone
// transfer for every name produced by the zone. Any fatal condition
// terminates the process through log.Fatal.
func main() {
	//log.SetFlags(0)
	flag.Parse()
	if *usePSL && len(*ns) == 0 {
		log.Fatal("must pass nameserver with -ns when using -psl")
	}
	if *retry < 1 {
		log.Fatal("retry must be positive")
	}
	// With zero workers nothing would ever read from the zone channel and the
	// run would silently report "0 / N transferred".
	if *parallel == 0 {
		log.Fatal("parallel must be positive")
	}
	if flag.NArg() > 0 {
		log.Fatalf("unexpected arguments %v", flag.Args())
	}

	var err error
	localNameserver, err = getNameserver()
	check(err)
	v("using initial nameserver %s", localNameserver)

	start := time.Now()
	var z zone.Zone
	if len(*zonefile) == 0 {
		rootNameservers, err := zone.GetRootServers(localNameserver)
		check(err)
		// get zone file from root AXFR
		// not all the root nameservers allow AXFR, try them until we find one that does
		for _, rootNS := range rootNameservers {
			v("trying root nameserver %s", rootNS)
			startTime := time.Now()
			z, err = zone.RootAXFR(rootNS)
			if err != nil {
				v("root nameserver %s failed: %v", rootNS, err)
				continue
			}
			took := time.Since(startTime).Round(time.Millisecond)
			log.Printf("ROOT %s xfr size: %d records in %s \n", rootNS, z.Records, took.String())
			break
		}
		// err is non-nil here only when every root nameserver was tried and the
		// last one failed. Report the cause; the empty-zone check below decides
		// whether the run can continue.
		if err != nil {
			log.Printf("unable to AXFR the root zone from any root nameserver, last error: %v", err)
		}
	} else {
		// zone file is provided
		v("parsing zonefile: %q\n", *zonefile)
		z, err = zone.ParseZoneFile(*zonefile)
		check(err)
	}

	if z.CountNS() == 0 {
		log.Fatal("Got empty zone")
	}

	if *usePSL {
		pslDomains, err := psl.GetDomains()
		check(err)
		for _, domain := range pslDomains {
			z.AddNS(domain, "")
		}
		v("added %d domains from PSL\n", len(pslDomains))
	}

	// create output dir if it does not exist; MkdirAll does nothing when the
	// directory is already present and fails early if the path is not a directory
	if !*dryRun {
		check(os.MkdirAll(*saveDir, os.ModePerm))
	}

	if *verbose {
		z.PrintTree()
	}

	zoneChan := z.GetNameChan()
	var g errgroup.Group

	// start workers
	for i := uint(0); i < *parallel; i++ {
		g.Go(func() error { return worker(z, zoneChan) })
	}

	// Wait returns the first non-nil error reported by any worker.
	err = g.Wait()
	check(err)

	took := time.Since(start).Round(time.Millisecond)
	log.Printf("%d / %d transferred in %s\n", totalXFR, len(z.NS), took.String())
	v("exiting normally\n")
}
// worker receives domain names from c and calls axfrWorker for each one.
// It stops and returns the error as soon as axfrWorker fails, and returns nil
// once c has been closed and fully drained.
func worker(z zone.Zone, c chan string) error {
	for name := range c {
		if err := axfrWorker(z, name); err != nil {
			return err
		}
	}
	return nil
}
// check is a fail-fast helper: a nil error is ignored, any other error is
// passed to log.Fatal, which logs it and terminates the process.
func check(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
// v writes a formatted message through the standard logger when the -verbose
// flag is set and does nothing otherwise. Continuation lines of a multi-line
// message are indented with a tab so they stand out from the first line.
func v(format string, args ...interface{}) {
	if !*verbose {
		return
	}
	// log.Print terminates the entry itself. Trailing newlines are dropped
	// first, otherwise a final "\n" would be rewritten to "\n\t" and produce a
	// stray line containing only a tab.
	msg := strings.TrimRight(fmt.Sprintf(format, args...), "\n")
	log.Print(strings.ReplaceAll(msg, "\n", "\n\t"))
}
// getNameserver returns the "host:port" address of the nameserver to use.
//
// When the -ns flag is set its value is used; port 53 is appended if the value
// carries no port. Otherwise the first server listed in /etc/resolv.conf is
// used together with the port from that file. An error is returned when
// /etc/resolv.conf cannot be read or lists no nameserver.
func getNameserver() (string, error) {
	if len(*ns) == 0 {
		// get root server from local DNS
		conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
		if err != nil {
			return "", err
		}
		// Guard the index below: a resolv.conf without nameserver entries
		// would otherwise cause an index-out-of-range panic.
		if len(conf.Servers) == 0 {
			return "", fmt.Errorf("no nameservers found in /etc/resolv.conf")
		}
		return net.JoinHostPort(conf.Servers[0], conf.Port), nil
	}

	host, port, err := net.SplitHostPort(*ns)
	if err != nil {
		// No usable port in the flag value: treat all of it as the host and
		// default to port 53. Brackets around an IPv6 literal ("[::1]") are
		// removed first because JoinHostPort adds its own.
		host = strings.TrimSuffix(strings.TrimPrefix(*ns, "["), "]")
		return net.JoinHostPort(host, "53"), nil
	}
	return net.JoinHostPort(host, port), nil
}