diff --git a/go.mod b/go.mod index 7d853cf..a9a84e6 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.5 require ( github.com/antlr4-go/antlr/v4 v4.13.1 - github.com/bytebase/parser v0.0.0-20260121030202-698704919f24 + github.com/bytebase/parser v0.0.0-20260130090605-effef73942d9 github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.11.1 github.com/testcontainers/testcontainers-go v0.40.0 diff --git a/go.sum b/go.sum index 72fbf63..75686f2 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/bytebase/antlr/v4 v4.0.0-20240827034948-8c385f108920 h1:IfmPt5o5R70NKtOrs+QHOoCgViYZelZysGxVBvV4ybA= github.com/bytebase/antlr/v4 v4.0.0-20240827034948-8c385f108920/go.mod h1:ykhjIPiv0IWpu3OGXCHdz2eUSe8UNGGD6baqjs8jSuU= -github.com/bytebase/parser v0.0.0-20260121030202-698704919f24 h1:oonTO26orUa4bYk/hQjALiYO1zII+Kzpjg75OnC3VtU= -github.com/bytebase/parser v0.0.0-20260121030202-698704919f24/go.mod h1:jeak/EfutSOAuWKvrFIT2IZunhWprM7oTFBRgZ9RCxo= +github.com/bytebase/parser v0.0.0-20260130090605-effef73942d9 h1:q5MnVPWlV/p3MPe60SysVoUaEnyeS6+OOOk0F3DLLK8= +github.com/bytebase/parser v0.0.0-20260130090605-effef73942d9/go.mod h1:jeak/EfutSOAuWKvrFIT2IZunhWprM7oTFBRgZ9RCxo= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= diff --git a/internal/translator/visitor.go b/internal/translator/visitor.go index ee9ea9d..4795e28 100644 --- a/internal/translator/visitor.go +++ b/internal/translator/visitor.go @@ -123,28 +123,28 @@ func (v *visitor) visitMethodChain(ctx mongodb.IMethodChainContext) { if !ok { return } - for _, methodCall := range mc.AllMethodCall() { - v.visitMethodCall(methodCall) + + if mc.CollectionMethodCall() != nil { + v.visitCollectionMethodCall(mc.CollectionMethodCall()) + if v.err != nil { + return + } + } + + for _, cursorCall := range mc.AllCursorMethodCall() { + v.visitCursorMethodCall(cursorCall) if v.err != nil { return } } } -func (v *visitor) visitMethodCall(ctx mongodb.IMethodCallContext) { - mc, ok := ctx.(*mongodb.MethodCallContext) +func (v *visitor) visitCollectionMethodCall(ctx mongodb.ICollectionMethodCallContext) { + mc, ok := ctx.(*mongodb.CollectionMethodCallContext) if !ok { return } - // Determine method context for registry lookup - getMethodContext := func() string { - if v.operation.OpType == types.OpFind || v.operation.OpType == types.OpFindOne { - return "cursor" - } - return "collection" - } - switch { // Supported read operations case mc.FindMethod() != nil: @@ -168,32 +168,13 @@ func (v *visitor) visitMethodCall(ctx mongodb.IMethodCallContext) { case mc.GetIndexesMethod() != nil: v.operation.OpType = types.OpGetIndexes - // Supported cursor modifiers - case mc.SortMethod() != nil: - v.extractSort(mc.SortMethod()) - case mc.LimitMethod() != nil: - v.extractLimit(mc.LimitMethod()) - case mc.SkipMethod() != nil: - v.extractSkip(mc.SkipMethod()) - case mc.ProjectionMethod() != nil: - v.extractProjection(mc.ProjectionMethod()) - case mc.HintMethod() != nil: - v.extractHint(mc.HintMethod()) - case mc.MaxMethod() != nil: - v.extractMax(mc.MaxMethod()) - case mc.MinMethod() != nil: - v.extractMin(mc.MinMethod()) - - // Supported M2 write operations + // Supported write operations case mc.InsertOneMethod() != nil: v.operation.OpType = types.OpInsertOne v.extractInsertOneArgs(mc.InsertOneMethod()) - case mc.InsertManyMethod() != nil: v.operation.OpType = types.OpInsertMany v.extractInsertManyArgs(mc.InsertManyMethod()) - - // Supported M2 write operations - updateOne case mc.UpdateOneMethod() != nil: v.operation.OpType = types.OpUpdateOne v.extractUpdateOneArgs(mc.UpdateOneMethod()) @@ -219,12 +200,12 @@ func (v *visitor) visitMethodCall(ctx mongodb.IMethodCallContext) { v.operation.OpType = types.OpFindOneAndDelete v.extractFindOneAndDeleteArgs(mc.FindOneAndDeleteMethod()) - // Supported M3 index operations + // Supported index operations case mc.CreateIndexMethod() != nil: v.operation.OpType = types.OpCreateIndex v.extractCreateIndexArgs(mc.CreateIndexMethod()) case mc.CreateIndexesMethod() != nil: - v.handleUnsupportedMethod("collection", "createIndexes") // Lower ROI, keep as planned + v.handleUnsupportedMethod("collection", "createIndexes") case mc.DropIndexMethod() != nil: v.operation.OpType = types.OpDropIndex v.extractDropIndexArgs(mc.DropIndexMethod()) @@ -232,14 +213,14 @@ func (v *visitor) visitMethodCall(ctx mongodb.IMethodCallContext) { v.operation.OpType = types.OpDropIndexes v.extractDropIndexesArgs(mc.DropIndexesMethod()) - // Supported M3 collection management + // Supported collection management case mc.DropMethod() != nil: v.operation.OpType = types.OpDrop case mc.RenameCollectionMethod() != nil: v.operation.OpType = types.OpRenameCollection v.extractRenameCollectionArgs(mc.RenameCollectionMethod()) - // Planned M3 stats operations - return PlannedOperationError for fallback + // Planned stats operations case mc.StatsMethod() != nil: v.handleUnsupportedMethod("collection", "stats") case mc.StorageSizeMethod() != nil: @@ -257,34 +238,45 @@ func (v *visitor) visitMethodCall(ctx mongodb.IMethodCallContext) { case mc.LatencyStatsMethod() != nil: v.handleUnsupportedMethod("collection", "latencyStats") - // Generic method fallback - all methods going through genericMethod are unsupported - case mc.GenericMethod() != nil: - gmCtx, ok := mc.GenericMethod().(*mongodb.GenericMethodContext) - if !ok { - return + default: + methodName := extractMethodNameFromText(mc.GetText()) + if methodName != "" { + v.handleUnsupportedMethod("collection", methodName) } - methodName := gmCtx.Identifier().GetText() - v.handleUnsupportedMethod(getMethodContext(), methodName) + } +} - // Default: all other methods not explicitly handled - // These go to handleUnsupportedMethod which returns UnsupportedOperationError - // since they're not in the planned registry +func (v *visitor) visitCursorMethodCall(ctx mongodb.ICursorMethodCallContext) { + mc, ok := ctx.(*mongodb.CursorMethodCallContext) + if !ok { + return + } + + switch { + case mc.SortMethod() != nil: + v.extractSort(mc.SortMethod()) + case mc.LimitMethod() != nil: + v.extractLimit(mc.LimitMethod()) + case mc.SkipMethod() != nil: + v.extractSkip(mc.SkipMethod()) + case mc.ProjectionMethod() != nil: + v.extractProjection(mc.ProjectionMethod()) + case mc.HintMethod() != nil: + v.extractHint(mc.HintMethod()) + case mc.MaxMethod() != nil: + v.extractMax(mc.MaxMethod()) + case mc.MinMethod() != nil: + v.extractMin(mc.MinMethod()) default: - // Extract method name from the parse tree for error message - methodName := v.extractMethodName(mc) + methodName := extractMethodNameFromText(mc.GetText()) if methodName != "" { - v.handleUnsupportedMethod(getMethodContext(), methodName) + v.handleUnsupportedMethod("cursor", methodName) } } } -// extractMethodName extracts the method name from a MethodCallContext for error messages. -func (v *visitor) extractMethodName(mc *mongodb.MethodCallContext) string { - // Try to get method name from various method contexts - // The parser creates specific method contexts for known methods - // For unknown methods, they go through GenericMethod which is handled separately - text := mc.GetText() - // Extract method name before the opening parenthesis +// extractMethodNameFromText extracts the method name from a parse tree text before the opening parenthesis. +func extractMethodNameFromText(text string) string { if idx := strings.Index(text, "("); idx > 0 { return text[:idx] }