-
Notifications
You must be signed in to change notification settings - Fork 341
feat: add automatic DNS rebinding protection for localhost servers #760
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2280,3 +2280,138 @@ data: {"jsonrpc":"2.0","method":"test2","params":{}} | |||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // TestStreamableLocalhostProtection verifies that DNS rebinding protection | ||||||
| // is automatically enabled for localhost servers. | ||||||
| func TestStreamableLocalhostProtection(t *testing.T) { | ||||||
| server := NewServer(testImpl, nil) | ||||||
|
|
||||||
| tests := []struct { | ||||||
| name string | ||||||
| listenAddr string // Address to listen on | ||||||
| hostHeader string // Host header in request | ||||||
| disableProt bool // DisableLocalhostProtection setting | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
| wantStatus int | ||||||
| }{ | ||||||
| // Auto-enabled for localhost listeners (127.0.0.1) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extending the test cases to multiple lines and including field names is more readable. I would suggest to do it. |
||||||
| {"127.0.0.1 accepts 127.0.0.1", "127.0.0.1:0", "127.0.0.1:1234", false, http.StatusOK}, | ||||||
| {"127.0.0.1 accepts localhost", "127.0.0.1:0", "localhost:1234", false, http.StatusOK}, | ||||||
| {"127.0.0.1 rejects evil.com", "127.0.0.1:0", "evil.com", false, http.StatusForbidden}, | ||||||
| {"127.0.0.1 rejects evil.com:80", "127.0.0.1:0", "evil.com:80", false, http.StatusForbidden}, | ||||||
| {"127.0.0.1 rejects localhost.evil.com", "127.0.0.1:0", "localhost.evil.com", false, http.StatusForbidden}, | ||||||
|
|
||||||
| // When listening on 0.0.0.0, requests arriving via localhost are still protected | ||||||
| // because LocalAddrContextKey returns the actual connection's local address. | ||||||
| // This is actually more secure - DNS rebinding attacks target localhost regardless | ||||||
| // of the listener configuration. | ||||||
| {"0.0.0.0 via localhost rejects evil.com", "0.0.0.0:0", "evil.com", false, http.StatusForbidden}, | ||||||
|
|
||||||
| // Explicit disable | ||||||
| {"disabled accepts evil.com", "127.0.0.1:0", "evil.com", true, http.StatusOK}, | ||||||
| } | ||||||
|
|
||||||
| for _, tt := range tests { | ||||||
| t.Run(tt.name, func(t *testing.T) { | ||||||
| opts := &StreamableHTTPOptions{ | ||||||
| Stateless: true, // Simpler for testing | ||||||
| DisableLocalhostProtection: tt.disableProt, | ||||||
| } | ||||||
| handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, opts) | ||||||
|
|
||||||
| // Create a listener on the specified address to control LocalAddrContextKey | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| listener, err := net.Listen("tcp", tt.listenAddr) | ||||||
| if err != nil { | ||||||
| t.Fatalf("failed to listen on %s: %v", tt.listenAddr, err) | ||||||
| } | ||||||
| defer listener.Close() | ||||||
|
|
||||||
| // Start server in background | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| srv := &http.Server{Handler: handler} | ||||||
| go srv.Serve(listener) | ||||||
| defer srv.Close() | ||||||
|
|
||||||
| // Make request with custom Host header | ||||||
| req, err := http.NewRequest("POST", fmt.Sprintf("http://%s", listener.Addr().String()), strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}`)) | ||||||
| if err != nil { | ||||||
| t.Fatal(err) | ||||||
| } | ||||||
| req.Host = tt.hostHeader | ||||||
| req.Header.Set("Content-Type", "application/json") | ||||||
| req.Header.Set("Accept", "application/json, text/event-stream") | ||||||
|
|
||||||
| resp, err := http.DefaultClient.Do(req) | ||||||
| if err != nil { | ||||||
| t.Fatal(err) | ||||||
| } | ||||||
| defer resp.Body.Close() | ||||||
|
|
||||||
| if got := resp.StatusCode; got != tt.wantStatus { | ||||||
| t.Errorf("status code: got %d, want %d", got, tt.wantStatus) | ||||||
| } | ||||||
| }) | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // TestIsLocalhostAddr tests the isLocalhostAddr helper function. | ||||||
| func TestIsLocalhostAddr(t *testing.T) { | ||||||
| tests := []struct { | ||||||
| addr string | ||||||
| want bool | ||||||
| }{ | ||||||
| {"127.0.0.1:3000", true}, | ||||||
| {"127.0.0.1:0", true}, | ||||||
| {"localhost:3000", true}, | ||||||
| {"[::1]:3000", true}, | ||||||
| {"0.0.0.0:3000", false}, | ||||||
| {"192.168.1.1:3000", false}, | ||||||
| {"example.com:3000", false}, | ||||||
| } | ||||||
|
|
||||||
| for _, tt := range tests { | ||||||
| t.Run(tt.addr, func(t *testing.T) { | ||||||
| addr, err := net.ResolveTCPAddr("tcp", tt.addr) | ||||||
| if err != nil { | ||||||
| // For hostname-based addresses, use a mock | ||||||
| if strings.HasPrefix(tt.addr, "localhost") { | ||||||
| addr = &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 3000} | ||||||
| } else if strings.HasPrefix(tt.addr, "example.com") { | ||||||
| addr = &net.TCPAddr{IP: net.ParseIP("93.184.216.34"), Port: 3000} | ||||||
| } else { | ||||||
| t.Fatalf("failed to resolve %s: %v", tt.addr, err) | ||||||
| } | ||||||
| } | ||||||
| if got := isLocalhostAddr(addr); got != tt.want { | ||||||
| t.Errorf("isLocalhostAddr(%q) = %v, want %v", tt.addr, got, tt.want) | ||||||
| } | ||||||
| }) | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // TestIsLocalhostHost tests the isLocalhostHost helper function. | ||||||
| func TestIsLocalhostHost(t *testing.T) { | ||||||
| tests := []struct { | ||||||
| host string | ||||||
| want bool | ||||||
| }{ | ||||||
| {"localhost", true}, | ||||||
| {"localhost:3000", true}, | ||||||
| {"127.0.0.1", true}, | ||||||
| {"127.0.0.1:3000", true}, | ||||||
| {"[::1]", true}, | ||||||
| {"[::1]:3000", true}, | ||||||
| {"::1", true}, | ||||||
| {"", false}, | ||||||
| {"evil.com", false}, | ||||||
| {"evil.com:80", false}, | ||||||
| {"localhost.evil.com", false}, | ||||||
| {"127.0.0.1.evil.com", false}, | ||||||
| } | ||||||
|
|
||||||
| for _, tt := range tests { | ||||||
| t.Run(tt.host, func(t *testing.T) { | ||||||
| if got := isLocalhostHost(tt.host); got != tt.want { | ||||||
| t.Errorf("isLocalhostHost(%q) = %v, want %v", tt.host, got, tt.want) | ||||||
| } | ||||||
| }) | ||||||
| } | ||||||
| } | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this change below I think we could only have a single implementation of the
isLocalhostfunction, since they are largely the same. It would be called withlocalAddr.String()andreq.Host.