From aed26f47b88f243d70b69eb32e710b22834c484b Mon Sep 17 00:00:00 2001 From: eric Date: Mon, 2 Feb 2026 18:45:46 -0800 Subject: [PATCH] Fix additional transpiler issues found during audit Comprehensive audit revealed additional issues similar to #126: 1. InformationSchemaTransform missing ColumnRef handling - Column refs like columns.column_name not rewritten when table renamed - Added Node_ColumnRef case to rewrite qualified column references 2. PgCatalogTransform missing ColumnRef handling (same as #126) - Included here for completeness as this branch is based on main 3. TypeCastTransform missing type mappings for common PostgreSQL types - decimal, boolean, date, uuid, bit -> keep same, strip pg_catalog prefix - json, jsonb -> json, strip pg_catalog prefix - xml, varbit, name -> convert to varchar - oid -> convert to integer Co-Authored-By: Claude Opus 4.5 --- transpiler/transform/information_schema.go | 14 ++ transpiler/transform/pgcatalog.go | 14 ++ transpiler/transform/typecast.go | 22 ++- transpiler/transpiler_test.go | 211 +++++++++++++++++++++ 4 files changed, 260 insertions(+), 1 deletion(-) diff --git a/transpiler/transform/information_schema.go b/transpiler/transform/information_schema.go index 22f4d19..68d8348 100644 --- a/transpiler/transform/information_schema.go +++ b/transpiler/transform/information_schema.go @@ -208,6 +208,20 @@ func (t *InformationSchemaTransform) walkAndTransform(node *pg_query.Node, chang t.walkAndTransform(item, changed) } } + + case *pg_query.Node_ColumnRef: + // Rewrite qualified column references to match renamed tables + if n.ColumnRef != nil && len(n.ColumnRef.Fields) >= 2 { + if first := n.ColumnRef.Fields[0]; first != nil { + if str := first.GetString_(); str != nil { + tableName := strings.ToLower(str.Sval) + if newName, ok := t.ViewMappings[tableName]; ok { + str.Sval = newName + *changed = true + } + } + } + } } } diff --git a/transpiler/transform/pgcatalog.go b/transpiler/transform/pgcatalog.go index 00aff75..4fa2569 100644 --- a/transpiler/transform/pgcatalog.go 
+++ b/transpiler/transform/pgcatalog.go @@ -218,6 +218,20 @@ func (t *PgCatalogTransform) walkAndTransform(node *pg_query.Node, changed *bool t.walkAndTransform(n.TypeCast.Arg, changed) } + case *pg_query.Node_ColumnRef: + // Rewrite qualified column references to match renamed tables + if n.ColumnRef != nil && len(n.ColumnRef.Fields) >= 2 { + if first := n.ColumnRef.Fields[0]; first != nil { + if str := first.GetString_(); str != nil { + tableName := strings.ToLower(str.Sval) + if newName, ok := t.ViewMappings[tableName]; ok { + str.Sval = newName + *changed = true + } + } + } + } + case *pg_query.Node_AExpr: // Expression: check for OPERATOR(pg_catalog.~) pattern if n.AExpr != nil { diff --git a/transpiler/transform/typecast.go b/transpiler/transform/typecast.go index abfd284..9c75f6c 100644 --- a/transpiler/transform/typecast.go +++ b/transpiler/transform/typecast.go @@ -18,6 +18,7 @@ type TypeCastTransform struct { func NewTypeCastTransform() *TypeCastTransform { return &TypeCastTransform{ TypeMappings: map[string]string{ + // PostgreSQL reg* types -> varchar "regtype": "varchar", "regnamespace": "varchar", "regproc": "varchar", @@ -26,8 +27,27 @@ func NewTypeCastTransform() *TypeCastTransform { "regprocedure": "varchar", "regconfig": "varchar", "regdictionary": "varchar", - "text": "varchar", // Note: regclass is handled specially - converted to subquery lookup + + // String types + "text": "varchar", + "name": "varchar", // PostgreSQL internal name type + + // JSON types + "json": "json", + "jsonb": "json", + + // Types that just need pg_catalog prefix stripped + "decimal": "decimal", + "boolean": "boolean", + "date": "date", + "uuid": "uuid", + "bit": "bit", + + // Types that need conversion + "xml": "varchar", + "varbit": "varchar", + "oid": "integer", }, } } diff --git a/transpiler/transpiler_test.go b/transpiler/transpiler_test.go index e0762a8..a089e25 100644 --- a/transpiler/transpiler_test.go +++ b/transpiler/transpiler_test.go @@ -1809,6 +1809,217 @@ 
func TestConvertAlterTableToAlterView(t *testing.T) { } } +func TestTranspile_PgCatalog_ColumnRefRewrite(t *testing.T) { + // Bug: When pg_class is rewritten to pg_class_full, column references like + // pg_class.oid should also be rewritten to pg_class_full.oid + // This is needed for queries generated by psql and other PostgreSQL clients + // that describe tables using pg_attribute JOIN pg_class ON pg_class.oid = attrelid + tests := []struct { + name string + input string + contains string + excludes string + }{ + { + name: "column ref pg_class.oid should be rewritten", + input: "SELECT a.attname FROM pg_catalog.pg_attribute a JOIN pg_catalog.pg_class c ON pg_class.oid = a.attrelid", + contains: "pg_class_full.oid", + excludes: "pg_class.oid", + }, + { + name: "column ref with alias pg_class.relname should be rewritten", + input: "SELECT pg_class.relname FROM pg_catalog.pg_class", + contains: "pg_class_full.relname", + excludes: "pg_class.relname", + }, + { + name: "multiple column refs should all be rewritten", + input: "SELECT pg_class.oid, pg_class.relname, pg_class.relkind FROM pg_catalog.pg_class WHERE pg_class.relkind = 'r'", + contains: "pg_class_full.oid", + excludes: "pg_class.oid", + }, + { + name: "column ref in WHERE clause", + input: "SELECT * FROM pg_catalog.pg_class WHERE pg_class.relnamespace = 2200", + contains: "pg_class_full.relnamespace", + excludes: "pg_class.relnamespace", + }, + { + name: "column ref in JOIN condition", + input: "SELECT * FROM pg_catalog.pg_attribute JOIN pg_catalog.pg_class ON pg_class.oid = pg_attribute.attrelid", + contains: "pg_class_full.oid", + excludes: " pg_class.oid", // space prefix to avoid matching pg_class_full + }, + { + name: "unqualified table with qualified column ref", + input: "SELECT pg_class.relname FROM pg_class", + contains: "pg_class_full.relname", + excludes: " pg_class.relname", // space prefix to avoid matching pg_class_full + }, + } + + tr := New(DefaultConfig()) + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + result, err := tr.Transpile(tt.input) + if err != nil { + t.Fatalf("Transpile(%q) error: %v", tt.input, err) + } + if tt.contains != "" && !strings.Contains(result.SQL, tt.contains) { + t.Errorf("Transpile(%q) = %q, should contain %q", tt.input, result.SQL, tt.contains) + } + if tt.excludes != "" && strings.Contains(result.SQL, tt.excludes) { + t.Errorf("Transpile(%q) = %q, should NOT contain %q", tt.input, result.SQL, tt.excludes) + } + }) + } +} + +func TestTranspile_InformationSchema_ColumnRefRewrite(t *testing.T) { + // Bug: When information_schema.columns is rewritten to information_schema_columns_compat, + // column references like columns.column_name should also be rewritten + tests := []struct { + name string + input string + contains string + excludes string + }{ + { + name: "column ref columns.column_name should be rewritten", + input: "SELECT columns.column_name FROM information_schema.columns", + contains: "information_schema_columns_compat.column_name", + excludes: " columns.column_name", + }, + { + name: "column ref in WHERE clause", + input: "SELECT * FROM information_schema.columns WHERE columns.table_name = 'users'", + contains: "information_schema_columns_compat.table_name", + excludes: " columns.table_name", + }, + { + name: "column ref tables.table_name should be rewritten", + input: "SELECT tables.table_name FROM information_schema.tables", + contains: "information_schema_tables_compat.table_name", + excludes: " tables.table_name", + }, + { + name: "column ref in JOIN condition", + input: "SELECT * FROM information_schema.columns c JOIN information_schema.tables t ON columns.table_name = tables.table_name", + contains: "information_schema_columns_compat.table_name", + excludes: " columns.table_name", + }, + } + + tr := New(DefaultConfig()) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tr.Transpile(tt.input) + if err != nil { + t.Fatalf("Transpile(%q) error: %v", 
tt.input, err) + } + if tt.contains != "" && !strings.Contains(result.SQL, tt.contains) { + t.Errorf("Transpile(%q) = %q, should contain %q", tt.input, result.SQL, tt.contains) + } + if tt.excludes != "" && strings.Contains(result.SQL, tt.excludes) { + t.Errorf("Transpile(%q) = %q, should NOT contain %q", tt.input, result.SQL, tt.excludes) + } + }) + } +} + +func TestTranspile_TypeCast_JsonType(t *testing.T) { + // Bug: ::pg_catalog.json should have pg_catalog. stripped, becoming just ::json + // DuckDB doesn't understand pg_catalog.json type qualifier + // Note: pg_query's deparser adds pg_catalog. prefix to certain types during deparsing, + // so we need to strip it in all cases to produce DuckDB-compatible output. + tests := []struct { + name string + input string + contains string + excludes string + }{ + { + name: "pg_catalog.json cast should strip prefix", + input: "SELECT data::pg_catalog.json FROM t", + excludes: "pg_catalog.json", + }, + { + name: "pg_catalog.jsonb cast should strip prefix and convert to json", + input: "SELECT data::pg_catalog.jsonb FROM t", + excludes: "pg_catalog.jsonb", + }, + { + name: "unqualified json cast should not get pg_catalog prefix in output", + input: "SELECT data::json FROM t", + excludes: "pg_catalog.json", // pg_query adds pg_catalog prefix, but we should strip it + }, + { + name: "json cast in complex expression", + input: "SELECT custom_subscriber_attributes::pg_catalog.json AS attrs FROM subscribers", + excludes: "pg_catalog.json", + }, + { + name: "multiple json casts", + input: "SELECT a::pg_catalog.json, b::pg_catalog.jsonb FROM t", + excludes: "pg_catalog.json", + }, + } + + tr := New(DefaultConfig()) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tr.Transpile(tt.input) + if err != nil { + t.Fatalf("Transpile(%q) error: %v", tt.input, err) + } + lowerSQL := strings.ToLower(result.SQL) + if tt.contains != "" && !strings.Contains(lowerSQL, strings.ToLower(tt.contains)) { + 
t.Errorf("Transpile(%q) = %q, should contain %q", tt.input, result.SQL, tt.contains) + } + if tt.excludes != "" && strings.Contains(lowerSQL, strings.ToLower(tt.excludes)) { + t.Errorf("Transpile(%q) = %q, should NOT contain %q", tt.input, result.SQL, tt.excludes) + } + }) + } +} + +func TestTranspile_TypeCast_PgCatalogPrefix(t *testing.T) { + // Test that various types with pg_catalog. prefix are handled correctly + tests := []struct { + name string + input string + excludes string // should NOT contain pg_catalog.typename + }{ + {"decimal", "SELECT x::pg_catalog.decimal FROM t", "pg_catalog.decimal"}, + {"boolean", "SELECT x::pg_catalog.boolean FROM t", "pg_catalog.boolean"}, + {"date", "SELECT x::pg_catalog.date FROM t", "pg_catalog.date"}, + {"uuid", "SELECT x::pg_catalog.uuid FROM t", "pg_catalog.uuid"}, + {"xml", "SELECT x::pg_catalog.xml FROM t", "pg_catalog.xml"}, + {"bit", "SELECT x::pg_catalog.bit FROM t", "pg_catalog.bit"}, + {"varbit", "SELECT x::pg_catalog.varbit FROM t", "pg_catalog.varbit"}, + {"oid", "SELECT x::pg_catalog.oid FROM t", "pg_catalog.oid"}, + {"name", "SELECT x::pg_catalog.name FROM t", "pg_catalog.name"}, + } + + tr := New(DefaultConfig()) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tr.Transpile(tt.input) + if err != nil { + t.Fatalf("Transpile(%q) error: %v", tt.input, err) + } + lowerSQL := strings.ToLower(result.SQL) + if strings.Contains(lowerSQL, strings.ToLower(tt.excludes)) { + t.Errorf("Transpile(%q) = %q, should NOT contain %q", tt.input, result.SQL, tt.excludes) + } + }) + } +} + func TestTranspile_FallbackParamCount(t *testing.T) { // Test that when pg_query fails to parse DuckDB-specific syntax, // the transpiler still correctly counts $N parameter placeholders