diff --git a/main.go b/main.go index 5696262..9038670 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ type FileConfig struct { DataDir string `yaml:"data_dir"` TLS TLSConfig `yaml:"tls"` Users map[string]string `yaml:"users"` + AuthMethod string `yaml:"auth_method"` // "cleartext" (default) or "md5" RateLimit RateLimitFileConfig `yaml:"rate_limit"` Extensions []string `yaml:"extensions"` DuckLake DuckLakeFileConfig `yaml:"ducklake"` @@ -109,12 +110,13 @@ func main() { fmt.Fprintf(os.Stderr, "Options:\n") flag.PrintDefaults() fmt.Fprintf(os.Stderr, "\nEnvironment variables:\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_CONFIG Path to YAML config file\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_HOST Host to bind to (default: 0.0.0.0)\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_PORT Port to listen on (default: 5432)\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_DATA_DIR Directory for DuckDB files (default: ./data)\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_CERT TLS certificate file (default: ./certs/server.crt)\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_KEY TLS private key file (default: ./certs/server.key)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_CONFIG Path to YAML config file\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_HOST Host to bind to (default: 0.0.0.0)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_PORT Port to listen on (default: 5432)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_DATA_DIR Directory for DuckDB files (default: ./data)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_CERT TLS certificate file (default: ./certs/server.crt)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_KEY TLS private key file (default: ./certs/server.key)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_AUTH_METHOD Auth method: cleartext (default) or md5\n") fmt.Fprintf(os.Stderr, "\nPrecedence: CLI flags > environment variables > config file > defaults\n") } @@ -166,6 +168,9 @@ func main() { if len(fileCfg.Users) > 0 { cfg.Users = fileCfg.Users } + if fileCfg.AuthMethod != "" { + cfg.AuthMethod = server.AuthMethod(fileCfg.AuthMethod) + } // Apply rate limit config if fileCfg.RateLimit.MaxFailedAttempts > 0 { @@ -249,6 +254,9 @@ func main() { if v := os.Getenv("DUCKGRES_KEY"); v != "" { cfg.TLSKeyFile = v } + if v := os.Getenv("DUCKGRES_AUTH_METHOD"); v != "" { + cfg.AuthMethod = server.AuthMethod(v) + } if v := os.Getenv("DUCKGRES_DUCKLAKE_METADATA_STORE"); v != "" { cfg.DuckLake.MetadataStore = v } diff --git a/server/conn.go b/server/conn.go index 32f3b2d..47bc9d4 100644 --- a/server/conn.go +++ b/server/conn.go @@ -3,10 +3,13 @@ package server import ( "bufio" "bytes" + "crypto/md5" + "crypto/rand" "crypto/tls" "database/sql" "encoding/binary" "encoding/csv" + "encoding/hex" "fmt" "io" "log/slog" @@ -482,9 +485,29 @@ func (c *clientConn) handleStartup() error { break } - // Request password - if err := writeAuthCleartextPassword(c.writer); err != nil { - return err + // Get the expected password for this user + expectedPassword, userExists := c.server.cfg.Users[c.username] + + // Determine auth method (default to cleartext for backwards compatibility) + authMethod := c.server.cfg.AuthMethod + if authMethod == "" { + authMethod = AuthCleartext + } + + var salt [4]byte + if authMethod == AuthMD5 { + // Generate random salt for MD5 auth + if _, err := rand.Read(salt[:]); err != nil { + return fmt.Errorf("failed to generate salt: %w", err) + } + if err := writeAuthMD5Password(c.writer, salt); err != nil { + return err + } + } else { + // Request cleartext password + if err := writeAuthCleartextPassword(c.writer); err != nil { + return err + } } if err := c.writer.Flush(); err != nil { return fmt.Errorf("failed to flush writer: %w", err) @@ -504,9 +527,17 @@ func (c *clientConn) handleStartup() error { // Password is null-terminated password := string(bytes.TrimRight(body, "\x00")) - // Validate password - expectedPassword, ok := c.server.cfg.Users[c.username] - if !ok || expectedPassword != password { + // Validate password based on auth method + var authValid bool + if !userExists { + authValid = false + } else if authMethod == AuthMD5 { + authValid = verifyMD5Password(password, expectedPassword, c.username, salt) + } else { + authValid = password == expectedPassword + } + + if !authValid { // Record failed authentication attempt banned := c.server.rateLimiter.RecordFailedAuth(c.conn.RemoteAddr()) if banned { @@ -528,6 +559,26 @@ func (c *clientConn) handleStartup() error { return nil } +// verifyMD5Password verifies an MD5-hashed password response. +// The client computes: "md5" + md5(md5(password + username) + salt) +// where salt is the 4-byte random salt sent by the server. +func verifyMD5Password(clientResponse, password, username string, salt [4]byte) bool { + // Client response should start with "md5" followed by 32 hex chars + if len(clientResponse) != 35 || clientResponse[:3] != "md5" { + return false + } + + // Compute expected hash: md5(md5(password + username) + salt) + inner := md5.Sum([]byte(password + username)) + innerHex := hex.EncodeToString(inner[:]) + + outer := md5.Sum(append([]byte(innerHex), salt[:]...)) + outerHex := hex.EncodeToString(outer[:]) + + expected := "md5" + outerHex + return clientResponse == expected +} + func (c *clientConn) sendInitialParams() { params := map[string]string{ "server_version": "15.0 (Duckgres)", diff --git a/server/protocol.go b/server/protocol.go index 618b522..dc7cc2e 100644 --- a/server/protocol.go +++ b/server/protocol.go @@ -176,6 +176,14 @@ func writeAuthCleartextPassword(w io.Writer) error { return writeMessage(w, msgAuth, data) } +// writeAuthMD5Password requests MD5-hashed password with a 4-byte salt +func writeAuthMD5Password(w io.Writer, salt [4]byte) error { + data := make([]byte, 8) + binary.BigEndian.PutUint32(data, authMD5Pwd) + copy(data[4:], salt[:]) + return writeMessage(w, msgAuth, data) +} + // writeParameterStatus sends a parameter status message func writeParameterStatus(w io.Writer, name, value string) error { data := []byte(name) diff --git a/server/server.go b/server/server.go index acda3f8..0bd01c7 100644 --- a/server/server.go +++ b/server/server.go @@ -55,12 +55,26 @@ func redactConnectionString(connStr string) string { return passwordPattern.ReplaceAllString(connStr, "${1}[REDACTED]") } +// AuthMethod represents the authentication method to use +type AuthMethod string + +const ( + // AuthCleartext uses cleartext password (default, protected by TLS) + AuthCleartext AuthMethod = "cleartext" + // AuthMD5 uses MD5 hashed password (PostgreSQL standard) + AuthMD5 AuthMethod = "md5" +) + type Config struct { Host string Port int DataDir string Users map[string]string // username -> password + // AuthMethod specifies the authentication method. + // Supported values: "cleartext" (default), "md5". + AuthMethod AuthMethod + // TLS configuration (required) TLSCertFile string // Path to TLS certificate file TLSKeyFile string // Path to TLS private key file