diff --git a/massrdns.go b/massrdns.go
@@ -15,11 +15,14 @@ import (
)
var dnsServers []string
+var failureCounts = make(map[string]int)
+
+func loadDNSServersFromFile(filePath string) ([]string, error) {
+ var servers []string
-func loadDNSServersFromFile(filePath string) error {
file, err := os.Open(filePath)
if err != nil {
- return err
+ return nil, err
}
defer file.Close()
@@ -27,32 +30,32 @@ func loadDNSServersFromFile(filePath string) error {
for scanner.Scan() {
server := scanner.Text()
- // Check if the server contains a port
if strings.Contains(server, ":") {
host, port, err := net.SplitHostPort(server)
if err != nil {
- return fmt.Errorf("invalid IP:port format for %s", server)
+ return nil, fmt.Errorf("invalid IP:port format for %s", server)
}
if net.ParseIP(host) == nil {
- return fmt.Errorf("invalid IP address in %s", server)
+ return nil, fmt.Errorf("invalid IP address in %s", server)
}
if _, err := strconv.Atoi(port); err != nil {
- return fmt.Errorf("invalid port in %s", server)
+ return nil, fmt.Errorf("invalid port in %s", server)
}
} else {
if net.ParseIP(server) == nil {
- return fmt.Errorf("invalid IP address %s", server)
+ return nil, fmt.Errorf("invalid IP address %s", server)
}
- server += ":53" // Default to port 53 if not specified
+ server += ":53"
}
- dnsServers = append(dnsServers, server)
+ servers = append(servers, server)
}
- return scanner.Err()
+ return servers, scanner.Err()
}
-func reverseDNSLookup(ip string, server string) string {
- ctx := context.Background()
+func reverseDNSLookup(ip string, server string) (string, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
resolver := &net.Resolver{
PreferGo: true,
@@ -64,24 +67,40 @@ func reverseDNSLookup(ip string, server string) string {
names, err := resolver.LookupAddr(ctx, ip)
if err != nil {
- return fmt.Sprintf("%s | %s | Error: %s", time.Now().Format("03:04:05 PM"), server, err)
+ if isNetworkError(err) {
+ return "", err
+ }
+ return fmt.Sprintf("%s | %s | Error: %s", time.Now().Format("03:04:05 PM"), server, err), nil
}
if len(names) == 0 {
- return fmt.Sprintf("%s | %s | No PTR records", time.Now().Format("03:04:05 PM"), server)
+ return fmt.Sprintf("%s | %s | No PTR records", time.Now().Format("03:04:05 PM"), server), nil
}
- return fmt.Sprintf("%s | %s | %s", time.Now().Format("03:04:05 PM"), server, names[0])
+ return fmt.Sprintf("%s | %s | %s", time.Now().Format("03:04:05 PM"), server, names[0]), nil
}
-func worker(cidr *net.IPNet, resultsChan chan string) {
- for ip := make(net.IP, len(cidr.IP)); copy(ip, cidr.IP) != 0; incrementIP(ip) {
- if !cidr.Contains(ip) {
- break
+func isNetworkError(err error) bool {
+ errorString := err.Error()
+ return strings.Contains(errorString, "timeout") || strings.Contains(errorString, "connection refused")
+}
+
+func pickRandomServer(servers []string, triedServers map[string]bool) string {
+ for _, i := range rand.Perm(len(servers)) {
+ if !triedServers[servers[i]] {
+ return servers[i]
}
- randomServer := dnsServers[rand.Intn(len(dnsServers))]
- result := reverseDNSLookup(ip.String(), randomServer)
- resultsChan <- result
}
+ return ""
+}
+
+func removeFromList(servers []string, server string) []string {
+ var newList []string
+ for _, s := range servers {
+ if s != server {
+ newList = append(newList, s)
+ }
+ }
+ return newList
}
func splitCIDR(cidr string, parts int) ([]*net.IPNet, error) {
@@ -91,6 +110,12 @@ func splitCIDR(cidr string, parts int) ([]*net.IPNet, error) {
}
maskSize, _ := ipNet.Mask.Size()
+
+ maxParts := 1 << uint(32-maskSize)
+ if parts > maxParts {
+ parts = maxParts
+ }
+
newMaskSize := maskSize
for ; (1 << uint(newMaskSize-maskSize)) < parts; newMaskSize++ {
if newMaskSize > 32 {
@@ -110,6 +135,48 @@ func splitCIDR(cidr string, parts int) ([]*net.IPNet, error) {
return subnets, nil
}
+func worker(cidr *net.IPNet, resultsChan chan string) {
+ for ip := make(net.IP, len(cidr.IP)); copy(ip, cidr.IP) != 0; incrementIP(ip) {
+ if !cidr.Contains(ip) {
+ break
+ }
+
+ triedServers := make(map[string]bool)
+ retries := 10
+ success := false
+
+ for retries > 0 {
+ randomServer := pickRandomServer(dnsServers, triedServers)
+ if randomServer == "" {
+ break
+ }
+
+ result, err := reverseDNSLookup(ip.String(), randomServer)
+
+ // Check for network errors
+ if err != nil && isNetworkError(err) {
+ failureCounts[randomServer]++
+ if failureCounts[randomServer] > 10 {
+ dnsServers = removeFromList(dnsServers, randomServer)
+ delete(failureCounts, randomServer)
+ }
+
+ triedServers[randomServer] = true
+ retries--
+ continue
+ } else if err == nil {
+ resultsChan <- result
+ success = true
+ break
+ }
+ }
+
+ if !success {
+ resultsChan <- fmt.Sprintf("%s | %s | Max retries reached", time.Now().Format("03:04:05 PM"), ip)
+ }
+ }
+}
+
func main() {
var cidr string
var concurrency int
@@ -125,7 +192,9 @@ func main() {
os.Exit(1)
}
- if err := loadDNSServersFromFile(dnsFile); err != nil {
+ var err error
+ dnsServers, err = loadDNSServersFromFile(dnsFile)
+ if err != nil {
fmt.Printf("Error reading DNS servers from file %s: %s\n", dnsFile, err)
os.Exit(1)
}
@@ -143,21 +212,24 @@ func main() {
os.Exit(1)
}
- // Create a channel to feed CIDR blocks to workers
+ if len(subnets) < concurrency {
+ concurrency = len(subnets) // Limit concurrency to number of subnets
+ }
+
cidrChan := make(chan *net.IPNet, len(subnets))
for _, subnet := range subnets {
cidrChan <- subnet
}
- close(cidrChan) // Close it, so workers can detect when there's no more work
+ close(cidrChan)
- resultsChan := make(chan string, concurrency*2) // Increased buffer size for results
+ resultsChan := make(chan string, concurrency*2)
var wg sync.WaitGroup
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
- for subnet := range cidrChan { // Keep working until there's no more work
+ for subnet := range cidrChan {
worker(subnet, resultsChan)
}
}()
|