diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml index d73979c..a9dad23 100644 --- a/.github/workflows/go-test.yml +++ b/.github/workflows/go-test.yml @@ -24,5 +24,5 @@ jobs: run: go build - name: Run all tests - run: go test -v -tags=router_test,integration,e2e ./... + run: go run gotest.tools/gotestsum@latest -- -tags=router_test,integration,e2e ./... diff --git a/cmd/inspect.go b/cmd/inspect.go index ac54e5c..35fe6ba 100644 --- a/cmd/inspect.go +++ b/cmd/inspect.go @@ -1,6 +1,8 @@ package cmd import ( + "fmt" + "github.com/encodeous/nylon/core" "github.com/spf13/cobra" ) @@ -11,16 +13,16 @@ var inspectCmd = &cobra.Command{ Short: "Inspects the current state of nylon", Run: func(cmd *cobra.Command, args []string) { if len(args) != 1 { - println("Usage: nylon inspect ") + fmt.Println("Usage: nylon inspect ") return } itf := args[0] result, err := core.IPCGet(itf) if err != nil { - println("Error:", err.Error()) + fmt.Println("Error:", err.Error()) return } - println(result) + fmt.Print(result) }, GroupID: "ny", } diff --git a/core/ipc.go b/core/ipc.go index 092bda1..cec2dac 100644 --- a/core/ipc.go +++ b/core/ipc.go @@ -17,6 +17,7 @@ func IPCGet(itf string) (string, error) { if err != nil { return "", err } + defer conn.Close() rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) _, err = rw.WriteString("get=nylon\n") @@ -37,7 +38,7 @@ func IPCGet(itf string) (string, error) { if err != nil && err != io.EOF { return "", err } - return res, nil + return strings.TrimSuffix(res, "\x00"), nil } func HandleNylonIPCGet(s *state.State, rw *bufio.ReadWriter) error { @@ -57,6 +58,16 @@ func HandleNylonIPCGet(s *state.State, rw *bufio.ReadWriter) error { met = n.BestEndpoint().Metric() } sb.WriteString(fmt.Sprintf(" Metric: %d\n", met)) + sb.WriteString(fmt.Sprintf(" Endpoints:\n")) + for _, ep := range n.Eps { + nep := ep.AsNylonEndpoint() + ap, err := nep.DynEP.Get() + if err != nil { + sb.WriteString(fmt.Sprintf(" - %s (unresolved)\n", nep.DynEP.Value)) + } else { + sb.WriteString(fmt.Sprintf(" - %s (resolved: %s) active=%v metric=%d\n", nep.DynEP.Value, ap.String(), nep.IsActive(), nep.Metric())) + } + } sb.WriteString(fmt.Sprintf(" Published Routes:\n")) rt := make([]string, 0) if len(n.Routes) == 0 { @@ -119,11 +130,15 @@ func HandleNylonIPCGet(s *state.State, rw *bufio.ReadWriter) error { slices.Sort(rt) sb.WriteString(strings.Join(rt, "\n") + "\n") - sb.WriteRune(0) _, err = rw.WriteString(sb.String()) if err != nil { return err } + err = rw.WriteByte(0) + if err != nil { + return err + } + return rw.Flush() default: return fmt.Errorf("unknown command %s", cmd) } diff --git a/core/nylon.go b/core/nylon.go index be4866f..25b4dfc 100644 --- a/core/nylon.go +++ b/core/nylon.go @@ -1,7 +1,6 @@ package core import ( - "context" "net" "net/netip" "time" @@ -28,26 +27,7 @@ func (n *Nylon) Init(s *state.State) error { s.Log.Debug("init nylon") - if len(s.DnsResolvers) != 0 { - s.Log.Debug("setting custom DNS resolvers", "resolvers", s.DnsResolvers) - net.DefaultResolver = &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - d := net.Dialer{ - Timeout: time.Second * 10, - } - var err error - var conn net.Conn - for _, resolver := range s.DnsResolvers { - conn, err = d.DialContext(ctx, network, resolver) - if err == nil { - return conn, nil - } - } - return conn, err - }, - } - } + state.SetResolvers(s.DnsResolvers) // add neighbours for _, peer := range s.GetPeers(s.Id) { @@ -61,7 +41,7 @@ func (n *Nylon) Init(s *state.State) error { } cfg := s.GetRouter(peer) for _, ep := range cfg.Endpoints { - stNeigh.Eps = append(stNeigh.Eps, state.NewEndpoint(ep, peer, false, nil)) + stNeigh.Eps = append(stNeigh.Eps, state.NewEndpoint(ep, false, nil)) } s.Neighbours = append(s.Neighbours, stNeigh) @@ -85,6 +65,22 @@ func (n *Nylon) Init(s *state.State) error { s.Env.RepeatTask(func(s *state.State) error { return n.probeLinks(s, true) }, state.ProbeDelay) + s.Env.RepeatTask(func(s *state.State) error { + // refresh dynamic endpoints + for _, neigh := range s.Neighbours { + for _, ep := range neigh.Eps { + if nep, ok := ep.(*state.NylonEndpoint); ok { + go func() { + _, err := nep.DynEP.Refresh() + if err != nil { + s.Log.Debug("failed to resolve endpoint", "ep", nep.DynEP.Value, "err", err.Error()) + } + }() + } + } + } + return nil + }, state.EndpointResolveDelay) s.Env.RepeatTask(func(s *state.State) error { return n.probeLinks(s, false) }, state.ProbeRecoveryDelay) diff --git a/core/nylon_distribution.go b/core/nylon_distribution.go index e657cb5..9da7a13 100644 --- a/core/nylon_distribution.go +++ b/core/nylon_distribution.go @@ -1,9 +1,11 @@ package core import ( + "context" "errors" "fmt" "io" + "net" "net/http" "net/url" "os" @@ -27,7 +29,29 @@ func FetchConfig(repoStr string, key state.NyPublicKey) (*state.CentralCfg, erro } cfgBody = file } else if repo.Scheme == "http" || repo.Scheme == "https" { - res, err := http.Get(repo.String()) + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network string, addr string) (conn net.Conn, err error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + addrs, err := state.ResolveName(ctx, host) + if err != nil { + return nil, err + } + for _, ip := range addrs { + var dialer net.Dialer + conn, err = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) + if err == nil { + break + } + } + return + }, + }, + } + res, err := client.Get(repo.String()) if err != nil { return nil, fmt.Errorf("failed to fetch %s: %w", repo.String(), err) } diff --git a/core/nylon_endpoints.go b/core/nylon_endpoints.go index d845be9..1c26e2a 100644 --- a/core/nylon_endpoints.go +++ b/core/nylon_endpoints.go @@ -16,7 +16,7 @@ type EpPing struct { TimeSent time.Time } -func (n *Nylon) Probe(ep *state.NylonEndpoint) error { +func (n *Nylon) Probe(node state.NodeId, ep *state.NylonEndpoint) error { token := rand.Uint64() ping := &protocol.Ny{ Type: &protocol.Ny_ProbeOp{ @@ -26,8 +26,12 @@ func (n *Nylon) Probe(ep *state.NylonEndpoint) error { }, }, } - peer := n.Device.LookupPeer(device.NoisePublicKey(n.env.GetNode(ep.Node()).PubKey)) - err := n.SendNylon(ping, ep.GetWgEndpoint(n.Device), peer) + peer := n.Device.LookupPeer(device.NoisePublicKey(n.env.GetNode(node).PubKey)) + nep, err := ep.GetWgEndpoint(n.Device) + if err != nil { + return err + } + err = n.SendNylon(ping, nep, peer) if err != nil { return err } @@ -76,7 +80,8 @@ func handleProbePing(s *state.State, node state.NodeId, ep conn.Endpoint) { for _, neigh := range s.Neighbours { for _, dep := range neigh.Eps { dep := dep.AsNylonEndpoint() - if dep.Ep == ep.DstIPPort() && neigh.Id == node { + ap, err := dep.DynEP.Get() + if err == nil && ap == ep.DstIPPort() && neigh.Id == node { // we have a link // refresh wireguard ep @@ -88,7 +93,7 @@ func handleProbePing(s *state.State, node state.NodeId, ep conn.Endpoint) { dep.Renew() if state.DBG_log_probe { - s.Log.Debug("probe from", "addr", dep.Ep) + s.Log.Debug("probe from", "addr", ap.String()) } return } @@ -97,7 +102,7 @@ func handleProbePing(s *state.State, node state.NodeId, ep conn.Endpoint) { // create a new link if we dont have a link for _, neigh := range s.Neighbours { if neigh.Id == node { - newEp := state.NewEndpoint(ep.DstIPPort(), neigh.Id, true, ep) + newEp := state.NewEndpoint(state.NewDynamicEndpoint(ep.DstIPPort().String()), true, ep) newEp.Renew() neigh.Eps = append(neigh.Eps, newEp) // push route update to improve convergence time @@ -114,7 +119,8 @@ func handleProbePong(s *state.State, node state.NodeId, token uint64, ep conn.En for _, neigh := range s.Neighbours { for _, dpLink := range neigh.Eps { dpLink := dpLink.AsNylonEndpoint() - if dpLink.Ep == ep.DstIPPort() && neigh.Id == node { + ap, err := dpLink.DynEP.Get() + if err == nil && ap == ep.DstIPPort() && neigh.Id == node { linkHealth, ok := n.PingBuf.GetAndDelete(token) if ok { health := linkHealth.Value() @@ -145,7 +151,7 @@ func (n *Nylon) probeLinks(s *state.State, active bool) error { for _, ep := range neigh.Eps { if ep.IsActive() == active { go func() { - err := n.Probe(ep.AsNylonEndpoint()) + err := n.Probe(neigh.Id, ep.AsNylonEndpoint()) if err != nil { s.Log.Debug("probe failed", "err", err.Error()) } @@ -169,18 +175,23 @@ func (n *Nylon) probeNew(s *state.State) error { cfg := s.GetRouter(peer) // assumption: we don't need to connect to the same endpoint again within the scope of the same node for _, ep := range cfg.Endpoints { - if !ep.IsValid() { + ap, err := ep.Get() + if err != nil { continue } idx := slices.IndexFunc(neigh.Eps, func(link state.Endpoint) bool { - return !link.IsRemote() && link.AsNylonEndpoint().Ep == ep + lap, err := link.AsNylonEndpoint().DynEP.Get() + if err != nil { + return false + } + return !link.IsRemote() && lap == ap }) if idx == -1 { // add the link to the neighbour - dpl := state.NewEndpoint(ep, peer, false, nil) + dpl := state.NewEndpoint(ep, false, nil) neigh.Eps = append(neigh.Eps, dpl) go func() { - err := n.Probe(dpl) + err := n.Probe(peer, dpl) if err != nil { //s.Log.Debug("discovery probe failed", "err", err.Error()) } diff --git a/core/nylon_gc.go b/core/nylon_gc.go index 0691bf4..0662330 100644 --- a/core/nylon_gc.go +++ b/core/nylon_gc.go @@ -11,11 +11,14 @@ func nylonGc(s *state.State) error { n := 0 for _, x := range neigh.Eps { x := x.AsNylonEndpoint() + if !x.IsActive() { + x.DynEP.Clear() + } if x.IsAlive() { neigh.Eps[n] = x n++ } else { - s.Log.Debug("removed dead endpoint", "ep", x.Ep, "to", x.Node()) + s.Log.Debug("removed dead endpoint", "ep", x.DynEP.String(), "to", neigh.Id) } } neigh.Eps = neigh.Eps[:n] diff --git a/core/nylon_wireguard.go b/core/nylon_wireguard.go index 676c372..de95983 100644 --- a/core/nylon_wireguard.go +++ b/core/nylon_wireguard.go @@ -69,7 +69,11 @@ listen_port=%d rcfg := s.GetRouter(peer) endpoints := make([]conn.Endpoint, 0) for _, nep := range rcfg.Endpoints { - endpoint, err := n.Device.Bind().ParseEndpoint(nep.String()) + ap, err := nep.Get() + if err != nil { + continue + } + endpoint, err := n.Device.Bind().ParseEndpoint(ap.String()) if err != nil { return err } @@ -166,16 +170,24 @@ func UpdateWireGuard(s *state.State) error { return cmp.Compare(a.Metric(), b.Metric()) }) for _, ep := range links { - eps = append(eps, ep.AsNylonEndpoint().GetWgEndpoint(n.Device)) + nep, err := ep.AsNylonEndpoint().GetWgEndpoint(n.Device) + if err != nil { + continue + } + eps = append(eps, nep) } } // add endpoint if it is not in the list for _, ep := range pcfg.Endpoints { + ap, err := ep.Get() + if err != nil { + continue + } if !slices.ContainsFunc(eps, func(endpoint conn.Endpoint) bool { - return endpoint.DstIPPort() == ep + return endpoint.DstIPPort() == ap }) { - endpoint, err := n.Device.Bind().ParseEndpoint(ep.String()) + endpoint, err := n.Device.Bind().ParseEndpoint(ap.String()) if err != nil { return err } diff --git a/core/router.go b/core/router.go index 26dd1f0..5422236 100644 --- a/core/router.go +++ b/core/router.go @@ -349,7 +349,7 @@ func flushIO(s *state.State) error { continue } if best != nil && best.IsActive() { - peer := n.Device.LookupPeer(device.NoisePublicKey(n.env.GetNode(best.Node()).PubKey)) + peer := n.Device.LookupPeer(device.NoisePublicKey(n.env.GetNode(neigh.Id).PubKey)) for { bundle := &protocol.TransportBundle{} tLength := 0 diff --git a/core/router_harness.go b/core/router_harness.go index 95a2a9c..726fbd8 100644 --- a/core/router_harness.go +++ b/core/router_harness.go @@ -48,7 +48,7 @@ func (m MockEndpoint) IsActive() bool { } func (m MockEndpoint) AsNylonEndpoint() *state.NylonEndpoint { - panic("MockEndpoint is not a NylonEndpoint") + return nil } func NewMockEndpoint(node state.NodeId, metric uint32) *MockEndpoint { @@ -207,7 +207,7 @@ func AddLink(r *state.RouterState, ep *MockEndpoint) *MockEndpoint { return nil } -func RemoveLink(r *state.RouterState, ep state.Endpoint) { +func RemoveLink(r *state.RouterState, ep *MockEndpoint) { for _, n := range r.Neighbours { if n.Id == ep.Node() { for i, e := range n.Eps { diff --git a/e2e/distribution_test.go b/e2e/distribution_test.go index 306606b..93bf9cc 100644 --- a/e2e/distribution_test.go +++ b/e2e/distribution_test.go @@ -4,7 +4,6 @@ package e2e import ( "context" - "net/netip" "os" "path/filepath" "testing" @@ -18,8 +17,6 @@ import ( func TestDistribution(t *testing.T) { h := NewHarness(t) - // Cleanup is handled by Harness via t.Cleanup - ctx := context.Background() // 1. Setup Keys @@ -27,10 +24,7 @@ func TestDistribution(t *testing.T) { pubKey := privKey.Pubkey() // 2. Prepare Directories - runDir := filepath.Join(h.RootDir, "e2e", "runs", t.Name()) - if err := os.MkdirAll(runDir, 0755); err != nil { - t.Fatal(err) - } + runDir := h.SetupTestDir() // 3. Prepare Initial Bundle (v1) distCfg := &state.DistributionCfg{ @@ -40,21 +34,14 @@ func TestDistribution(t *testing.T) { nodeKey := state.GenerateKey() nodeId := "node-1" + nodeIP := GetIP(h.Subnet, 10) cfg1 := state.CentralCfg{ Timestamp: 1, Dist: distCfg, Routers: []state.RouterCfg{ - { - NodeCfg: state.NodeCfg{ - Id: state.NodeId(nodeId), - PubKey: nodeKey.Pubkey(), - }, - Endpoints: []netip.AddrPort{}, - }, + SimpleRouter(nodeId, nodeKey.Pubkey(), "10.0.0.1", ""), }, - Clients: []state.ClientCfg{}, - Graph: []string{}, } cfg1Bytes, err := yaml.Marshal(cfg1) @@ -65,6 +52,14 @@ func TestDistribution(t *testing.T) { if err != nil { t.Fatal(err) } + + // Ensure cfg1 has the same timestamp as bundle1 to prevent immediate update + unbundled1, err := state.UnbundleConfig(bundle1Str, pubKey) + if err != nil { + t.Fatal(err) + } + cfg1.Timestamp = unbundled1.Timestamp + bundle1Path := filepath.Join(runDir, "bundle1") if err := os.WriteFile(bundle1Path, []byte(bundle1Str), 0644); err != nil { t.Fatal(err) @@ -99,35 +94,23 @@ func TestDistribution(t *testing.T) { } // 5. Start Nylon Node - // Write central.yaml (v1) to disk for initial startup - centralConfigPath := filepath.Join(runDir, "central.yaml") - if err := os.WriteFile(centralConfigPath, cfg1Bytes, 0644); err != nil { - t.Fatal(err) - } + centralConfigPath := h.WriteConfig(runDir, "central.yaml", cfg1) // Write node.yaml - nodeCfg := state.LocalCfg{ - Id: state.NodeId(nodeId), - Key: nodeKey, - Port: 51820, - // We can omit Dist here since we are providing central.yaml, - // but providing it doesn't hurt. - Dist: &state.LocalDistributionCfg{ - Key: pubKey, - Url: "http://repo:80/bundle", - }, - } - nodeCfgBytes, err := yaml.Marshal(nodeCfg) - if err != nil { - t.Fatal(err) - } - nodeConfigPath := filepath.Join(runDir, "node.yaml") - if err := os.WriteFile(nodeConfigPath, nodeCfgBytes, 0644); err != nil { - t.Fatal(err) + nodeCfg := SimpleLocal(nodeId, nodeKey) + nodeCfg.Dist = &state.LocalDistributionCfg{ + Key: pubKey, + Url: "http://repo:80/bundle", } + nodeConfigPath := h.WriteConfig(runDir, "node.yaml", nodeCfg) t.Log("Starting Nylon Node...") - h.StartNode(nodeId, "", centralConfigPath, nodeConfigPath) + h.StartNodes(NodeSpec{ + Name: nodeId, + IP: nodeIP, + CentralConfigPath: centralConfigPath, + NodeConfigPath: nodeConfigPath, + }) // Wait for start h.WaitForLog(nodeId, "Nylon has been initialized") @@ -137,7 +120,7 @@ func TestDistribution(t *testing.T) { t.Log("Preparing Bundle 2...") // Wait a bit to ensure timestamp is different if using UnixNano time.Sleep(1 * time.Second) - + cfg2 := cfg1 // BundleConfig will overwrite this timestamp anyway cfg2Bytes, err := yaml.Marshal(cfg2) @@ -169,12 +152,12 @@ func TestDistribution(t *testing.T) { // 7. Verify Update t.Log("Waiting for update detection...") h.WaitForLog(nodeId, "Found a new config update in repo") - + t.Log("Waiting for restart...") h.WaitForLog(nodeId, "Restarting Nylon...") // Allow some time for the restart to complete and write the file - time.Sleep(5 * time.Second) + h.WaitForLog(nodeId, "Nylon has been initialized.") t.Log("Verifying config version on node...") stdout, _, err := h.Exec(nodeId, []string{"cat", "/app/config/central.yaml"}) diff --git a/e2e/harness.go b/e2e/harness.go index 26f18e4..47292cd 100644 --- a/e2e/harness.go +++ b/e2e/harness.go @@ -180,7 +180,7 @@ func (h *Harness) StartNode(name string, ip string, centralConfigPath, nodeConfi Env: map[string]string{ "NYLON_LOG_LEVEL": "debug", }, - WaitingFor: wait.ForLog("Nylon has been initialized").WithStartupTimeout(15 * time.Second), + WaitingFor: wait.ForLog("Nylon has been initialized").WithStartupTimeout(30 * time.Second), HostConfigModifier: func(hostConfig *container.HostConfig) { hostConfig.Privileged = true hostConfig.CapAdd = []string{"NET_ADMIN"} @@ -370,6 +370,83 @@ func (h *Harness) PrintLogs(nodeName string) { h.t.Logf("Logs for %s:\n%s", nodeName, buf.String()) } +func (h *Harness) CopyFile(nodeName string, hostPath string, containerPath string) { + h.mu.Lock() + c, ok := h.Nodes[nodeName] + h.mu.Unlock() + if !ok { + h.t.Fatalf("node %s not found", nodeName) + } + err := c.CopyFileToContainer(h.ctx, hostPath, containerPath, 0644) + if err != nil { + h.t.Fatalf("failed to copy file to container %s: %v", nodeName, err) + } +} + +func (h *Harness) StartDNS(name string, ip string, corefile string, zones map[string]string) testcontainers.Container { + h.t.Logf("Starting DNS server %s at %s", name, ip) + + tempDir := h.SetupTestDir() + dnsDir := filepath.Join(tempDir, "dns") + os.MkdirAll(dnsDir, 0755) + + corefilePath := filepath.Join(dnsDir, "Corefile") + os.WriteFile(corefilePath, []byte(corefile), 0644) + + files := []testcontainers.ContainerFile{ + { + HostFilePath: corefilePath, + ContainerFilePath: "/etc/coredns/Corefile", + FileMode: 0644, + }, + } + + for zoneName, zoneContent := range zones { + zonePath := filepath.Join(dnsDir, zoneName) + os.WriteFile(zonePath, []byte(zoneContent), 0644) + files = append(files, testcontainers.ContainerFile{ + HostFilePath: zonePath, + ContainerFilePath: "/etc/coredns/" + zoneName, + FileMode: 0644, + }) + } + + req := testcontainers.ContainerRequest{ + Image: "coredns/coredns:latest", + Networks: []string{h.Network.Name}, + NetworkAliases: map[string][]string{ + h.Network.Name: {name}, + }, + Cmd: []string{"-conf", "/etc/coredns/Corefile"}, + Files: files, + WaitingFor: wait.ForListeningPort("53/udp"), + EndpointSettingsModifier: func(m map[string]*network.EndpointSettings) { + if ip != "" { + if s, ok := m[h.Network.Name]; ok { + s.IPAMConfig = &network.EndpointIPAMConfig{ + IPv4Address: ip, + } + } + } + }, + Name: h.t.Name() + "-" + name, + } + + container, err := testcontainers.GenericContainer(h.ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + h.t.Fatalf("failed to start coredns container %s: %v", name, err) + } + + h.mu.Lock() + h.Nodes[name] = container + h.mu.Unlock() + + return container +} + // SetupTestDir creates a directory for the current test run func (h *Harness) SetupTestDir() string { dir := filepath.Join(h.RootDir, "e2e", "runs", h.t.Name()) @@ -406,8 +483,8 @@ func SimpleRouter(id string, pubKey state.NyPublicKey, nylonIP string, endpointI }, } if endpointIP != "" { - cfg.Endpoints = []netip.AddrPort{ - netip.MustParseAddrPort(fmt.Sprintf("%s:57175", endpointIP)), + cfg.Endpoints = []*state.DynamicEndpoint{ + state.NewDynamicEndpoint(fmt.Sprintf("%s:57175", endpointIP)), } } return cfg diff --git a/e2e/passive_roaming_test.go b/e2e/passive_roaming_test.go index 812677b..43c73f2 100644 --- a/e2e/passive_roaming_test.go +++ b/e2e/passive_roaming_test.go @@ -46,7 +46,7 @@ func TestPassiveRoaming(t *testing.T) { PubKey: pubKeys["node-1"], Addresses: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, }, - Endpoints: []netip.AddrPort{netip.AddrPortFrom(netip.MustParseAddr(ip1), 51820)}, + Endpoints: []*state.DynamicEndpoint{state.NewDynamicEndpoint(fmt.Sprintf("%s:51820", ip1))}, }, { NodeCfg: state.NodeCfg{ @@ -54,7 +54,7 @@ func TestPassiveRoaming(t *testing.T) { PubKey: pubKeys["node-2"], Addresses: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, }, - Endpoints: []netip.AddrPort{netip.AddrPortFrom(netip.MustParseAddr(ip2), 51820)}, + Endpoints: []*state.DynamicEndpoint{state.NewDynamicEndpoint(fmt.Sprintf("%s:51820", ip2))}, }, { NodeCfg: state.NodeCfg{ @@ -62,7 +62,7 @@ func TestPassiveRoaming(t *testing.T) { PubKey: pubKeys["node-3"], Addresses: []netip.Addr{netip.MustParseAddr("10.0.0.3")}, }, - Endpoints: []netip.AddrPort{netip.AddrPortFrom(netip.MustParseAddr(ip3), 51820)}, + Endpoints: []*state.DynamicEndpoint{state.NewDynamicEndpoint(fmt.Sprintf("%s:51820", ip3))}, }, }, Clients: []state.ClientCfg{ diff --git a/e2e/resolution_test.go b/e2e/resolution_test.go new file mode 100644 index 0000000..decf349 --- /dev/null +++ b/e2e/resolution_test.go @@ -0,0 +1,267 @@ +//go:build e2e + +package e2e + +import ( + "fmt" + "net/netip" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/encodeous/nylon/state" +) + +func TestEndpointResolution(t *testing.T) { + h := NewHarness(t) + + dnsIP := GetIP(h.Subnet, 100) + node1IP := GetIP(h.Subnet, 2) + node2IP := GetIP(h.Subnet, 3) + + // example.com -> node1IP + // srv.example.com -> SRV _nylon._udp.srv.example.com -> 57175 node2.example.com + // node2.example.com -> node2IP + corefile := ` +. { + file /etc/coredns/example.com.db example.com + log + errors +} +` + zoneFile := fmt.Sprintf(` +example.com. 0 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2017042745 7200 3600 1209600 0 +example.com. 0 IN A %s +node2.example.com. 0 IN A %s +_nylon._udp.srv.example.com. 0 IN SRV 10 10 57175 node2.example.com. +`, node1IP, node2IP) + h.StartDNS("dns", dnsIP, corefile, map[string]string{"example.com.db": zoneFile}) + + key1 := state.GenerateKey() + key2 := state.GenerateKey() + + centralCfg := state.CentralCfg{ + Routers: []state.RouterCfg{ + { + NodeCfg: state.NodeCfg{ + Id: "node-1", + PubKey: key1.Pubkey(), + Addresses: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, + }, + // Node 1's endpoint is a hostname + Endpoints: []*state.DynamicEndpoint{ + state.NewDynamicEndpoint("example.com"), + }, + }, + { + NodeCfg: state.NodeCfg{ + Id: "node-2", + PubKey: key2.Pubkey(), + Addresses: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + }, + // Node 2's endpoint is an SRV record + Endpoints: []*state.DynamicEndpoint{ + state.NewDynamicEndpoint("srv.example.com"), + }, + }, + }, + Graph: []string{"node-1, node-2"}, + } + + testDir := h.SetupTestDir() + centralPath := h.WriteConfig(testDir, "central.yaml", centralCfg) + + node1Cfg := SimpleLocal("node-1", key1) + node1Cfg.DnsResolvers = []string{dnsIP + ":53"} + node1Path := h.WriteConfig(testDir, "node1.yaml", node1Cfg) + + node2Cfg := SimpleLocal("node-2", key2) + node2Cfg.DnsResolvers = []string{dnsIP + ":53"} + node2Path := h.WriteConfig(testDir, "node2.yaml", node2Cfg) + + h.StartNode("node-1", node1IP, centralPath, node1Path) + h.StartNode("node-2", node2IP, centralPath, node2Path) + + h.WaitForLog("node-1", "Nylon has been initialized") + h.WaitForLog("node-2", "Nylon has been initialized") + + verify := func(node string, expectedPattern string) { + timeout := time.After(30 * time.Second) + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-timeout: + t.Fatalf("timed out waiting for resolution pattern %q on node %s", expectedPattern, node) + case <-ticker.C: + stdout, _, err := h.Exec(node, []string{"nylon", "inspect", "nylon0"}) + if err != nil { + continue + } + if strings.Contains(stdout, expectedPattern) { + return + } + } + } + } + + // node-1 should resolve node-2 (srv.example.com) to node2IP:57175 + verify("node-1", fmt.Sprintf("srv.example.com (resolved: %s:57175)", node2IP)) + + // node-2 should resolve node-1 (example.com) to node1IP:57175 + verify("node-2", fmt.Sprintf("example.com (resolved: %s:%d)", node1IP, state.DefaultPort)) +} + +func TestDynamicResolution(t *testing.T) { + h := NewHarness(t) + + dnsIP := GetIP(h.Subnet, 100) + node1IP := GetIP(h.Subnet, 2) + node2IP_A := GetIP(h.Subnet, 3) + node2IP_B := GetIP(h.Subnet, 4) + + // Initial DNS setup + corefile := ` +. { + file /etc/coredns/example.com.db example.com { + reload 2s + } + log + errors +} +` + zoneFileA := fmt.Sprintf(` +example.com. 0 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2017042745 7200 3600 1209600 0 +node2.example.com. 0 IN A %s +`, node2IP_A) + h.StartDNS("dns", dnsIP, corefile, map[string]string{"example.com.db": zoneFileA}) + + key1 := state.GenerateKey() + key2 := state.GenerateKey() + + centralCfg := state.CentralCfg{ + Routers: []state.RouterCfg{ + { + NodeCfg: state.NodeCfg{ + Id: "node-1", + PubKey: key1.Pubkey(), + Addresses: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, + }, + }, + { + NodeCfg: state.NodeCfg{ + Id: "node-2", + PubKey: key2.Pubkey(), + Addresses: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + }, + // Node 2's endpoint is a hostname + Endpoints: []*state.DynamicEndpoint{ + state.NewDynamicEndpoint("node2.example.com"), + }, + }, + }, + Graph: []string{"node-1, node-2"}, + } + + testDir := h.SetupTestDir() + centralPath := h.WriteConfig(testDir, "central.yaml", centralCfg) + + node1Cfg := SimpleLocal("node-1", key1) + node1Cfg.DnsResolvers = []string{dnsIP + ":53"} + node1Path := h.WriteConfig(testDir, "node1.yaml", node1Cfg) + + node2Cfg := SimpleLocal("node-2", key2) + node2Cfg.DnsResolvers = []string{dnsIP + ":53"} + node2Path := h.WriteConfig(testDir, "node2.yaml", node2Cfg) + + h.StartNode("node-1", node1IP, centralPath, node1Path) + h.StartNode("node-2", node2IP_A, centralPath, node2Path) + + h.WaitForLog("node-1", "Nylon has been initialized") + h.WaitForLog("node-2", "Nylon has been initialized") + + verify := func(node string, expectedPattern string) { + timeout := time.After(60 * time.Second) + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-timeout: + // Print logs of node-1 to help debugging + h.PrintLogs("node-1") + stdout, _, _ := h.Exec(node, []string{"nylon", "inspect", "nylon0"}) + t.Fatalf("timed out waiting for resolution pattern %q on node %s. Current inspect:\n%s", expectedPattern, node, stdout) + case <-ticker.C: + stdout, _, err := h.Exec(node, []string{"nylon", "inspect", "nylon0"}) + if err != nil { + continue + } + if strings.Contains(stdout, expectedPattern) { + return + } + } + } + } + + // Verify initial connection + verify("node-1", fmt.Sprintf("node2.example.com (resolved: %s:57175)", node2IP_A)) + + h.WaitForLog("node-1", "installing new route prefix=10.0.0.2") + h.WaitForLog("node-2", "installing new route prefix=10.0.0.1") + + // Ping from node-1 to node-2 + _, _, err := h.Exec("node-1", []string{"ping", "-c", "3", "10.0.0.2"}) + if err != nil { + t.Fatalf("initial ping failed: %v", err) + } + + // Update DNS record to point to node2IP_B + zoneFileB := fmt.Sprintf(` +example.com. 0 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2017042746 7200 3600 1209600 0 +node2.example.com. 0 IN A %s +`, node2IP_B) + + zonePath := filepath.Join(testDir, "example.com.db.new") + if err := os.WriteFile(zonePath, []byte(zoneFileB), 0644); err != nil { + t.Fatalf("failed to write new zone file: %v", err) + } + h.CopyFile("dns", zonePath, "/etc/coredns/example.com.db") + + // Stop old node-2 + h.mu.Lock() + node2Container := h.Nodes["node-2"] + h.mu.Unlock() + err = node2Container.Terminate(h.ctx) + if err != nil { + t.Logf("failed to terminate node-2: %v", err) + } + + // Start new node-2 at node2IP_B + h.StartNode("node-2-new", node2IP_B, centralPath, node2Path) + h.WaitForLog("node-2-new", "Nylon has been initialized") + + // Wait for Nylon on node-1 to re-resolve. + verify("node-1", fmt.Sprintf("node2.example.com (resolved: %s:57175)", node2IP_B)) + + h.WaitForLog("node-2-new", "installing new route prefix=10.0.0.1/32") + + // Ping from node-1 to node-2 (at new IP) + var lastErr error + for i := 0; i < 15; i++ { + _, _, lastErr = h.Exec("node-1", []string{"ping", "-c", "1", "-W", "1", "10.0.0.2"}) + if lastErr == nil { + break + } + t.Logf("Ping attempt %d failed: %v", i+1, lastErr) + time.Sleep(2 * time.Second) + } + if lastErr != nil { + stdout, _, _ := h.Exec("node-1", []string{"nylon", "inspect", "nylon0"}) + t.Logf("Node 1 inspect:\n%s", stdout) + stdout2, _, _ := h.Exec("node-2-new", []string{"nylon", "inspect", "nylon0"}) + t.Logf("Node 2-new inspect:\n%s", stdout2) + t.Fatalf("ping after DNS change failed after retries: %v", lastErr) + } +} diff --git a/example/sample-central.yaml b/example/sample-central.yaml index b473b72..09d2cfe 100644 --- a/example/sample-central.yaml +++ b/example/sample-central.yaml @@ -30,6 +30,7 @@ routers: addresses: [10.0.0.2, 10.1.0.2] # you can advertise multiple addresses endpoints: - '192.168.1.1:57175' + - 'nylon.example.org' # nylon will re-resolve this domain frequently to check for updates. - id: eve pubkey: 2mXTTD+FYdtJm/v1vSHz8qimvCucjW9vY+nLYacXJFE= addresses: [10.0.0.3] diff --git a/integration/harness.go b/integration/harness.go index 394a65f..55901bb 100644 --- a/integration/harness.go +++ b/integration/harness.go @@ -167,6 +167,7 @@ func (v *VirtualHarness) Start() chan error { nodes := len(v.Central.Routers) vn.virtTun = make([]*tuntest.ChannelTUN, nodes) vn.binds = make([]conn.Bind, nodes) + vn.readyCond = sync.NewCond(&sync.Mutex{}) // pick the first endpoint specified for each node vn.EpOutMapping = func(curNode state.NodeId, to bindtest.ChannelEndpoint2) bindtest.ChannelEndpoint2 { for k, x := range v.Endpoints { @@ -178,7 +179,7 @@ func (v *VirtualHarness) Start() chan error { } for e, n := range v.Endpoints { idx := v.IndexOf(n) - v.Central.Routers[idx].Endpoints = append(v.Central.Routers[idx].Endpoints, netip.MustParseAddrPort(e)) + v.Central.Routers[idx].Endpoints = append(v.Central.Routers[idx].Endpoints, state.NewDynamicEndpoint(e)) } startDelay := 0 * time.Millisecond for idx, rt := range v.Central.Routers { @@ -222,6 +223,7 @@ func (v *VirtualHarness) Start() chan error { return errChan } } + v.Net.Ready() return errChan } @@ -253,6 +255,23 @@ type InMemoryNetwork struct { SelfHandler PacketFilter // packet filter for handling packets destined for the current node TransitHandler PacketFilter // packet filter for handling packets passing through the current node EpOutMapping OutMapping + ready bool + readyCond *sync.Cond +} + +func (i *InMemoryNetwork) WaitForReady() { + i.readyCond.L.Lock() + defer i.readyCond.L.Unlock() + for !i.ready { + i.readyCond.Wait() + } +} + +func (i *InMemoryNetwork) Ready() { + i.Lock() + defer i.Unlock() + i.ready = true + i.readyCond.Broadcast() } func (i *InMemoryNetwork) virtualRouteTable(node state.NodeId, src, dst netip.Addr, data []byte, pkt []byte) bool { @@ -328,6 +347,7 @@ func (i *InMemoryNetwork) Bind(node state.NodeId) conn.Bind { pktBuf[0] = make([]byte, device.MaxMessageSize) lenBuf := make([]int, bufSize) epBuf := make([]conn.Endpoint, bufSize) + i.WaitForReady() for { for _, recv := range open { n, err := recv(pktBuf, lenBuf, epBuf) @@ -365,6 +385,7 @@ func (i *InMemoryNetwork) Tun(node state.NodeId) tun.Device { i.virtTun[numId] = bt go func() { + i.WaitForReady() for { select { case <-i.cfg.Context.Done(): diff --git a/state/config.go b/state/config.go index a17fc3a..fe81d9a 100644 --- a/state/config.go +++ b/state/config.go @@ -21,7 +21,7 @@ type NodeCfg struct { // RouterCfg represents a central representation of a node that can route type RouterCfg struct { NodeCfg `yaml:",inline"` - Endpoints []netip.AddrPort + Endpoints []*DynamicEndpoint } type ClientCfg struct { NodeCfg `yaml:",inline"` diff --git a/state/constants.go b/state/constants.go index 42fce72..0c33db0 100644 --- a/state/constants.go +++ b/state/constants.go @@ -41,4 +41,10 @@ var ( // healthcheck defaults HealthCheckDelay = time.Second * 15 HealthCheckMaxFailures = 3 + + // default port + DefaultPort = 57175 + + EndpointResolveExpiry = time.Minute * 1 + EndpointResolveDelay = time.Second * 15 ) diff --git a/state/dns.go b/state/dns.go new file mode 100644 index 0000000..d785c0d --- /dev/null +++ b/state/dns.go @@ -0,0 +1,58 @@ +package state + +import ( + "context" + "fmt" + "net" + "net/netip" + "strings" + "time" +) + +// SetResolvers configures the global default resolver +func SetResolvers(resolvers []string) { + if len(resolvers) != 0 { + net.DefaultResolver = &net.Resolver{ + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{Timeout: time.Second * 10} + var lastErr error + for _, r := range resolvers { + conn, err := d.DialContext(ctx, network, r) + if err == nil { + return conn, nil + } + lastErr = err + } + return nil, lastErr + }, + } + } +} + +// ResolveName resolves a hostname to a list of IP addresses +func ResolveName(ctx context.Context, host string) ([]netip.Addr, error) { + ips, err := net.DefaultResolver.LookupHost(ctx, host) + if err != nil { + return nil, err + } + var addrs []netip.Addr + for _, ipStr := range ips { + if addr, err := netip.ParseAddr(ipStr); err == nil { + addrs = append(addrs, addr) + } + } + return addrs, nil +} + +// ResolveSRV resolves an SRV record using the default resolver +func ResolveSRV(ctx context.Context, service, proto, name string) (string, uint16, error) { + _, addrs, err := net.DefaultResolver.LookupSRV(ctx, service, proto, name) + if err != nil { + return "", 0, err + } + if len(addrs) == 0 { + return "", 0, fmt.Errorf("no SRV records found") + } + // Return the first SRV target and port + return strings.TrimSuffix(addrs[0].Target, "."), addrs[0].Port, nil +} diff --git a/state/endpoint.go b/state/endpoint.go index 73cea12..167805a 100644 --- a/state/endpoint.go +++ b/state/endpoint.go @@ -1,10 +1,14 @@ package state import ( + "context" "fmt" "math" + "net" "net/netip" "slices" + "strconv" + "sync" "time" "github.com/encodeous/nylon/polyamide/conn" @@ -12,7 +16,6 @@ import ( ) type Endpoint interface { - Node() NodeId UpdatePing(ping time.Duration) Metric() uint32 IsRemote() bool @@ -20,8 +23,136 @@ type Endpoint interface { AsNylonEndpoint() *NylonEndpoint } +/* + DynamicEndpoint represents either an ip:port or a dns name. This may be resolved to a different address at any time + + Examples: + - nylon.example.com -> resolves to :57175 (DefaultPort) + - nylon2.example.com:12345 -> resolves to :12345 + - SRV record: _nylon._udp.example.com. port: 8000 target: nylon3.example.com -> resolves to :8000 +*/ +type DynamicEndpoint struct { + Value string + lastValue netip.AddrPort + lastUpdate time.Time + rw *sync.RWMutex +} + +func NewDynamicEndpoint(value string) *DynamicEndpoint { + return &DynamicEndpoint{ + Value: value, + rw: &sync.RWMutex{}, + } +} + +func (ep *DynamicEndpoint) Parse() (host string, port uint16, err error) { + // Try to parse as AddrPort directly first to handle cases like [::1]:port correctly + if ap, err := netip.ParseAddrPort(ep.Value); err == nil { + return ap.Addr().String(), ap.Port(), nil + } + + h, portStr, err := net.SplitHostPort(ep.Value) + if err != nil { + // No port specified? + // TODO: more robust validation + return ep.Value, uint16(DefaultPort), nil + } else { + p, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return "", 0, fmt.Errorf("invalid port: %w", err) + } + return h, uint16(p), nil + } +} + +func (ep *DynamicEndpoint) Refresh() (netip.AddrPort, error) { + // 1. Try to parse as AddrPort directly + if ap, err := netip.ParseAddrPort(ep.Value); err == nil { + return ap, nil + } + + ep.rw.RLock() + // if this endpoint is down, we will refresh every EndpointResolveDelay + if time.Now().Sub(ep.lastUpdate) < EndpointResolveExpiry && ep.lastValue != (netip.AddrPort{}) { + ep.rw.RUnlock() + return ep.lastValue, nil + } + ep.rw.RUnlock() + + host, port, err := ep.Parse() + if err != nil { + return netip.AddrPort{}, err + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + // 2. Try SRV lookup + target, srvPort, err := ResolveSRV(ctx, "nylon", "udp", host) + if err == nil { + addrs, err := ResolveName(ctx, target) + if err == nil && len(addrs) > 0 { + ep.rw.Lock() + defer ep.rw.Unlock() + ep.lastUpdate = time.Now() + ep.lastValue = netip.AddrPortFrom(addrs[0], srvPort) + return ep.lastValue, nil + } + } + + // 3. Normal A/AAAA lookup + addrs, err := ResolveName(ctx, host) + if err != nil { + return netip.AddrPort{}, err + } + if len(addrs) == 0 { + return netip.AddrPort{}, fmt.Errorf("no addresses found for %s", host) + } + + ep.rw.Lock() + defer ep.rw.Unlock() + ep.lastUpdate = time.Now() + ep.lastValue = netip.AddrPortFrom(addrs[0], port) + return ep.lastValue, nil +} + +func (ep *DynamicEndpoint) Get() (netip.AddrPort, error) { + if ap, err := netip.ParseAddrPort(ep.Value); err == nil { + return ap, nil + } + ep.rw.RLock() + defer ep.rw.RUnlock() + if ep.lastValue != (netip.AddrPort{}) { + return ep.lastValue, nil + } + return netip.AddrPort{}, fmt.Errorf("endpoint not resolved") +} + +func (ep *DynamicEndpoint) Clear() { + ep.rw.Lock() + defer ep.rw.Unlock() + ep.lastUpdate = time.Time{} +} + +func (ep *DynamicEndpoint) String() string { + return ep.Value +} + +func (ep *DynamicEndpoint) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + ep.Value = s + ep.rw = &sync.RWMutex{} + return nil +} + +func (ep *DynamicEndpoint) MarshalYAML() (interface{}, error) { + return ep.Value, nil +} + type NylonEndpoint struct { - node NodeId history []time.Duration histSort []time.Duration dirty bool @@ -30,22 +161,27 @@ type NylonEndpoint struct { expRTT float64 remoteInit bool WgEndpoint conn.Endpoint - Ep netip.AddrPort + DynEP *DynamicEndpoint } func (ep *NylonEndpoint) AsNylonEndpoint() *NylonEndpoint { return ep } -func (ep *NylonEndpoint) GetWgEndpoint(device *device.Device) conn.Endpoint { - if ep.WgEndpoint == nil || ep.WgEndpoint.DstToString() != ep.Ep.String() { - wgEp, err := device.Bind().ParseEndpoint(ep.Ep.String()) +func (ep *NylonEndpoint) GetWgEndpoint(device *device.Device) (conn.Endpoint, error) { + ap, err := ep.DynEP.Get() + if err != nil { + return nil, err + } + + if ep.WgEndpoint == nil || ep.WgEndpoint.DstIPPort() != ap { + wgEp, err := device.Bind().ParseEndpoint(ap.String()) if err != nil { - panic(fmt.Sprintf("Failed to parse endpoint: %s, %v", ep.Ep.String(), err)) + return nil, fmt.Errorf("failed to parse endpoint: %s, %v", ap.String(), err) } ep.WgEndpoint = wgEp } - return ep.WgEndpoint + return ep.WgEndpoint, nil } func (n *Neighbour) BestEndpoint() Endpoint { @@ -59,10 +195,6 @@ func (n *Neighbour) BestEndpoint() Endpoint { return best } -func (u *NylonEndpoint) Node() NodeId { - return u.node -} - func (u *NylonEndpoint) IsActive() bool { return time.Now().Sub(u.lastHeardBack) <= LinkDeadThreshold } @@ -80,13 +212,12 @@ func (u *NylonEndpoint) IsAlive() bool { return u.IsActive() || !u.remoteInit // we never gc endpoints that we have in our config } -func NewEndpoint(endpoint netip.AddrPort, node NodeId, remoteInit bool, wgEndpoint conn.Endpoint) *NylonEndpoint { +func NewEndpoint(endpoint *DynamicEndpoint, remoteInit bool, wgEndpoint conn.Endpoint) *NylonEndpoint { return &NylonEndpoint{ remoteInit: remoteInit, WgEndpoint: wgEndpoint, - Ep: endpoint, + DynEP: endpoint, history: make([]time.Duration, 0), - node: node, expRTT: math.Inf(1), } } diff --git a/state/endpoint_test.go b/state/endpoint_test.go index a45b452..b8b2a61 100644 --- a/state/endpoint_test.go +++ b/state/endpoint_test.go @@ -3,7 +3,6 @@ package state import ( "math" "math/rand/v2" - "net/netip" "testing" "time" @@ -53,7 +52,7 @@ type DataSource struct { func runTests(t *testing.T, ping func(i int) float64, dura time.Duration, fn string) (DataSource, DataSource) { t.Helper() - dep := NewEndpoint(netip.AddrPort{}, "dummy", false, nil) + dep := NewEndpoint(NewDynamicEndpoint("127.0.0.1:0"), false, nil) truth := DataSource{ Name: "Truth", @@ -206,3 +205,68 @@ func TestEndpointNormal(t *testing.T) { // once per minute is acceptable assert.Less(t, len(distinctValues), int(time.Hour*2/time.Minute)) } + +func TestDynamicEndpoint_Parse(t *testing.T) { + tests := []struct { + name string + input string + expectedHost string + expectedPort uint16 + wantErr bool + }{ + { + name: "IPv4 with port", + input: "127.0.0.1:12345", + expectedHost: "127.0.0.1", + expectedPort: 12345, + }, + { + name: "IPv6 with port", + input: "[::1]:12345", + expectedHost: "::1", + expectedPort: 12345, + }, + { + name: "Hostname with port", + input: "example.com:54321", + expectedHost: "example.com", + expectedPort: 54321, + }, + { + name: "Hostname default port", + input: "nylon.example.com", + expectedHost: "nylon.example.com", + expectedPort: uint16(DefaultPort), + }, + { + name: "IPv4 default port", + input: "192.168.1.1", + expectedHost: "192.168.1.1", + expectedPort: uint16(DefaultPort), + }, + { + name: "Invalid port", + input: "example.com:abc", + wantErr: true, + }, + { + name: "Not a URL", + input: "http://example.com/nylon", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ep := NewDynamicEndpoint(tt.input) + host, port, err := ep.Parse() + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedHost, host) + assert.Equal(t, tt.expectedPort, port) + } + }) + } +} diff --git a/state/scheduler.go b/state/scheduler.go index 4d114e7..0f0fae7 100644 --- a/state/scheduler.go +++ b/state/scheduler.go @@ -32,11 +32,14 @@ func (e *Env) ScheduleTask(fun func(*State) error, delay time.Duration) { } func (e *Env) repeatedTask(fun func(*State) error, delay time.Duration) { + // run immediately + e.Dispatch(fun) + ticker := time.NewTicker(delay) for e.Context.Err() == nil { select { case <-e.Context.Done(): return - case <-time.After(delay): + case <-ticker.C: e.Dispatch(fun) } } diff --git a/state/utils_test.go b/state/utils_test.go index bf4869b..d8f68a6 100644 --- a/state/utils_test.go +++ b/state/utils_test.go @@ -62,8 +62,8 @@ func SampleNetwork(t *testing.T, numClients, numRouters int, fullyConnected bool }, }, }, - Endpoints: []netip.AddrPort{ - netip.MustParseAddrPort(fmt.Sprintf("192.168.0.%d:25565", idx)), + Endpoints: []*DynamicEndpoint{ + NewDynamicEndpoint(fmt.Sprintf("192.168.0.%d:25565", idx)), }, }) }