diff --git a/cmd/golcg/main.go b/cmd/golcg/main.go
@@ -5,16 +5,47 @@ import (
"fmt"
"os"
"strconv"
+ "strings"
"github.com/acidvegas/golcg"
)
const Version = "1.0.0"
+func parseShardArg(shard string) (int, int, error) {
+ if shard == "" {
+ return 1, 1, nil
+ }
+
+ parts := strings.Split(shard, "/")
+ if len(parts) != 2 {
+ return 0, 0, fmt.Errorf("invalid shard format. Expected INDEX/TOTAL, got %s", shard)
+ }
+
+ index, err := strconv.Atoi(parts[0])
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid shard index: %v", err)
+ }
+
+ total, err := strconv.Atoi(parts[1])
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid shard total: %v", err)
+ }
+
+ if index < 1 || index > total {
+ return 0, 0, fmt.Errorf("shard index must be between 1 and total")
+ }
+
+ if total < 1 {
+ return 0, 0, fmt.Errorf("total shards must be greater than 0")
+ }
+
+ return index, total, nil
+}
+
func main() {
cidr := flag.String("cidr", "", "Target IP range in CIDR format")
- shardNum := flag.Int("shard-num", 1, "Shard number (1-based)")
- totalShards := flag.Int("total-shards", 1, "Total number of shards")
+ shard := flag.String("shard", "", "Shard specification in INDEX/TOTAL format (e.g., 1/4)")
seed := flag.Int("seed", 0, "Random seed for LCG")
stateStr := flag.String("state", "", "Resume from specific LCG state")
version := flag.Bool("version", false, "Show version information")
@@ -31,6 +62,12 @@ func main() {
os.Exit(1)
}
+ shardNum, totalShards, err := parseShardArg(*shard)
+ if err != nil {
+ fmt.Printf("Error: %v\n", err)
+ os.Exit(1)
+ }
+
var state *uint32
if *stateStr != "" {
stateVal, err := strconv.ParseUint(*stateStr, 10, 32)
@@ -42,7 +79,7 @@ func main() {
state = &stateUint32
}
- stream, err := golcg.IPStream(*cidr, *shardNum, *totalShards, *seed, state)
+ stream, err := golcg.IPStream(*cidr, shardNum, totalShards, *seed, state)
if err != nil {
fmt.Printf("Error: %v\n", err)
os.Exit(1)
diff --git a/golcg.go b/golcg.go
@@ -20,7 +20,7 @@ type LCG struct {
func NewLCG(seed int, m uint32) *LCG {
return &LCG{
- M: m,
+ M: 1<<32 - 1,
A: 1664525,
C: 1013904223,
Current: uint32(seed),
@@ -46,17 +46,22 @@ func NewIPRange(cidr string) (*IPRange, error) {
start := ipToUint32(network.IP)
ones, bits := network.Mask.Size()
hostBits := uint(bits - ones)
- broadcast := start | (1<<hostBits - 1)
- total := broadcast - start + 1
+
+ var total uint32
+ if hostBits == 32 {
+ total = 0
+ } else {
+ total = 1 << hostBits
+ }
return &IPRange{
Start: start,
- Total: uint32(total),
+ Total: total,
}, nil
}
func (r *IPRange) GetIPAtIndex(index uint32) (string, error) {
- if index >= r.Total {
+ if r.Total > 0 && index >= r.Total {
return "", errors.New("IP index out of range")
}
@@ -79,7 +84,7 @@ func uint32ToIP(n uint32) net.IP {
}
func SaveState(seed int, cidr string, shard int, total int, lcgCurrent uint32) error {
- fileName := fmt.Sprintf("pylcg_%d_%s_%d_%d.state", seed, strings.Replace(cidr, "/", "_", -1), shard, total)
+ fileName := fmt.Sprintf("golcg_%d_%s_%d_%d.state", seed, strings.Replace(cidr, "/", "_", -1), shard, total)
stateFile := filepath.Join(os.TempDir(), fileName)
return os.WriteFile(stateFile, []byte(fmt.Sprintf("%d", lcgCurrent)), 0644)
@@ -103,19 +108,33 @@ func IPStream(cidr string, shardNum, totalShards, seed int, state *uint32) (<-ch
lcg.Current = *state
}
- shardSize := ipRange.Total / uint32(totalShards)
-
- if uint32(shardIndex) < (ipRange.Total % uint32(totalShards)) {
- shardSize++
+ var shardSize uint32
+ if ipRange.Total == 0 {
+ shardSize = (1<<32 - 1) / uint32(totalShards)
+ if uint32(shardIndex) < uint32(totalShards-1) {
+ shardSize++
+ }
+ } else {
+ shardSize = ipRange.Total / uint32(totalShards)
+ if uint32(shardIndex) < ipRange.Total%uint32(totalShards) {
+ shardSize++
+ }
}
- out := make(chan string)
+ out := make(chan string, 1000)
go func() {
defer close(out)
remaining := shardSize
for remaining > 0 {
- index := lcg.Next() % ipRange.Total
+ next := lcg.Next()
+ var index uint32
+ if ipRange.Total > 0 {
+ index = next % ipRange.Total
+ } else {
+ index = next
+ }
+
if totalShards == 1 || index%uint32(totalShards) == uint32(shardIndex) {
ip, err := ipRange.GetIPAtIndex(index)
if err != nil {
| |