Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pkg/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ func (g *gatewayModule) SetNamedProxy(wlID string, config zos.GatewayNameProxy)
},
}

if err := g.setupRouting(ctx, wlID, fqdn, gatewayTLSConfig, config.GatewayBase); err != nil {
if err := g.setupRouting(ctx, wlID, fqdn, gatewayTLSConfig, config.GatewayBase, cfg.IPv4.IP, cfg.IPv6.IP); err != nil {
return "", err
}

Expand Down Expand Up @@ -622,14 +622,14 @@ func (g *gatewayModule) SetFQDNProxy(wlID string, config zos.GatewayFQDNProxy) e
},
}

return g.setupRouting(ctx, wlID, config.FQDN, gatewayTLSConfig, config.GatewayBase)
return g.setupRouting(ctx, wlID, config.FQDN, gatewayTLSConfig, config.GatewayBase, cfg.IPv4.IP, cfg.IPv6.IP)
}

func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn string, tlsConfig TlsConfig, config zos.GatewayBase) error {
func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn string, tlsConfig TlsConfig, config zos.GatewayBase, nodeIPs ...net.IP) error {
g.domainLock.Lock()
defer g.domainLock.Unlock()

if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough); err != nil {
if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough, nodeIPs...); err != nil {
return err
}

Expand Down
22 changes: 18 additions & 4 deletions pkg/gateway_light/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,13 @@ func (g *gatewayModule) SetNamedProxy(wlID string, config zos.GatewayNameProxy)
return "", errors.New("node doesn't support name proxy (doesn't have a domain)")
}

// Get public config for node IP validation
netStub := stubs.NewNetworkerLightStub(g.cl)
pubConfig, err := netStub.LoadPublicConfig(ctx)
if err != nil {
return "", errors.Wrap(err, "failed to load public config")
}

if err := g.validateNameContract(config.Name, twinID); err != nil {
return "", errors.Wrap(err, "failed to verify name contract")
}
Expand All @@ -599,7 +606,7 @@ func (g *gatewayModule) SetNamedProxy(wlID string, config zos.GatewayNameProxy)
},
}

if err := g.setupRouting(ctx, wlID, fqdn, gatewayTLSConfig, config.GatewayBase); err != nil {
if err := g.setupRouting(ctx, wlID, fqdn, gatewayTLSConfig, config.GatewayBase, pubConfig.IPv4.IP, pubConfig.IPv6.IP); err != nil {
return "", err
}

Expand All @@ -618,6 +625,13 @@ func (g *gatewayModule) SetFQDNProxy(wlID string, config zos.GatewayFQDNProxy) e
return err
}

// Get public config for node IP validation
netStub := stubs.NewNetworkerLightStub(g.cl)
pubConfig, err := netStub.LoadPublicConfig(ctx)
if err != nil {
return errors.Wrap(err, "failed to load public config")
}

if domain != "" && strings.HasSuffix(config.FQDN, domain) {
return errors.New("can't create a fqdn workload with a subdomain of the gateway's managed domain")
}
Expand All @@ -633,14 +647,14 @@ func (g *gatewayModule) SetFQDNProxy(wlID string, config zos.GatewayFQDNProxy) e
},
}

return g.setupRouting(ctx, wlID, config.FQDN, gatewayTLSConfig, config.GatewayBase)
return g.setupRouting(ctx, wlID, config.FQDN, gatewayTLSConfig, config.GatewayBase, pubConfig.IPv4.IP, pubConfig.IPv6.IP)
}

func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn string, tlsConfig TlsConfig, config zos.GatewayBase) error {
func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn string, tlsConfig TlsConfig, config zos.GatewayBase, nodeIPs ...net.IP) error {
g.domainLock.Lock()
defer g.domainLock.Unlock()

if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough); err != nil {
if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough, nodeIPs...); err != nil {
return err
}

Expand Down
40 changes: 38 additions & 2 deletions pkg/gridtypes/zos/gw.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math"
"net"
"net/url"
"slices"
"strconv"

"github.com/hashicorp/go-multierror"
Expand Down Expand Up @@ -45,14 +46,49 @@ func (b Backend) Valid(tlsPassthrough bool) error {
return nil
}

func ValidateBackends(backends []Backend, tlsPassthrough bool) error {
func ValidateBackends(backends []Backend, tlsPassthrough bool, nodeIPs ...net.IP) error {
var errs error
for _, backend := range backends {
if err := backend.Valid(tlsPassthrough); err != nil {
errs = multierror.Append(errs, errors.Wrapf(err, "failed to validate backend '%s'", backend))
}
}
return errs
if errs != nil {
return errs
}

// Check that backends don't point to the node's own public IPs (prevents infinite loops)
for _, backend := range backends {
backendIP, err := backend.ExtractIP()
if err != nil {
return errors.Wrapf(err, "failed to extract IP from backend '%s'", backend)
}
if slices.ContainsFunc(nodeIPs, backendIP.Equal) {
return fmt.Errorf("backend %s points to the node's own public IP address", backend)
}
}
return nil
}

// ExtractIP extracts the IP address from a backend string.
func (b Backend) ExtractIP() (net.IP, error) {
// Try ip:port format first
if ip, _, err := asIpPort(string(b)); err == nil {
return ip, nil
}

// Try URL format
u, err := url.Parse(string(b))
if err != nil {
return nil, fmt.Errorf("failed to parse backend: %w", err)
}

ip := net.ParseIP(u.Hostname())
if ip == nil {
return nil, fmt.Errorf("invalid ip address in backend: %s", u.Hostname())
}

return ip, nil
}

func asIpPort(a string) (ip net.IP, port uint16, err error) {
Expand Down
Loading