Skip to content
Draft
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@
import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderLimit;
import org.apache.doris.nereids.rules.rewrite.PullUpProjectUnderTopN;
import org.apache.doris.nereids.rules.rewrite.PushCountIntoUnionAll;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOnPkFk;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
import org.apache.doris.nereids.rules.rewrite.PushDownAggWithDistinctThroughJoinOneSide;
Expand Down Expand Up @@ -173,6 +172,7 @@
import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin;
import org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply;
import org.apache.doris.nereids.rules.rewrite.batch.EliminateUselessPlanUnderApply;
import org.apache.doris.nereids.rules.rewrite.eageraggregation.PushDownAggregation;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
Expand Down Expand Up @@ -657,19 +657,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
new MergeAggregate()
)
),
topic("Eager aggregation",
cascadesContext -> cascadesContext.rewritePlanContainsTypes(
LogicalAggregate.class, LogicalJoin.class
),
costBased(topDown(
new PushDownAggWithDistinctThroughJoinOneSide(),
new PushDownAggThroughJoinOneSide(),
new PushDownAggThroughJoin()
)),
costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new)),
topDown(new PushCountIntoUnionAll())
),

// this rule should invoke after infer predicate and push down distinct, and before push down limit
topic("eliminate join according unique or foreign key",
cascadesContext -> cascadesContext.rewritePlanContainsTypes(LogicalJoin.class),
Expand All @@ -686,7 +673,19 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new PushDownAggThroughJoinOnPkFk()),
topDown(new PullUpJoinFromUnionAll())
),
topic("Eager aggregation",
cascadesContext -> cascadesContext.rewritePlanContainsTypes(
LogicalAggregate.class, LogicalJoin.class
),
costBased(topDown(
new PushDownAggWithDistinctThroughJoinOneSide(),
new PushDownAggThroughJoinOneSide()
)),

costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new)),
custom(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN, PushDownAggregation::new),
topDown(new PushCountIntoUnionAll())
),
topic("Limit optimization",
cascadesContext -> cascadesContext.rewritePlanContainsTypes(LogicalLimit.class)
|| cascadesContext.rewritePlanContainsTypes(LogicalTopN.class)
Expand Down Expand Up @@ -936,23 +935,24 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(
}
rewriteJobs.add(
topic("nested column prune",
custom(RuleType.NESTED_COLUMN_PRUNING, NestedColumnPruning::new)
custom(RuleType.NESTED_COLUMN_PRUNING, NestedColumnPruning::new)
)
);
rewriteJobs.addAll(jobs(
topic("rewrite cte sub-tree after sub path push down",
custom(RuleType.CLEAR_CONTEXT_STATUS, ClearContextStatus::new),
custom(RuleType.REWRITE_CTE_CHILDREN,
() -> new RewriteCteChildren(afterPushDownJobs, runCboRules)
)
),
topic("whole plan check",
custom(RuleType.ADJUST_NULLABLE, () -> new AdjustNullable(false))
),
// NullableDependentExpressionRewrite need to be done after nullable fixed
topic("condition function", bottomUp(ImmutableList.of(
new NullableDependentExpressionRewrite())))
));
topic("rewrite cte sub-tree after sub path push down",
custom(RuleType.CLEAR_CONTEXT_STATUS, ClearContextStatus::new),
custom(RuleType.REWRITE_CTE_CHILDREN,
() -> new RewriteCteChildren(afterPushDownJobs, runCboRules)
)
),
topic("whole plan check",
custom(RuleType.ADJUST_NULLABLE, () -> new AdjustNullable(false))
),
// NullableDependentExpressionRewrite need to be done after nullable fixed
topic("condition function", bottomUp(ImmutableList.of(
new NullableDependentExpressionRewrite())))
)
);
return rewriteJobs;
}
));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.pattern;

import java.util.Objects;
import java.util.function.Predicate;

/**
 * A predicate wrapper with a human-readable description.
 * Used in pattern matching to provide better diagnostic messages when a predicate fails.
 *
 * <p>Instances are immutable; thread-safety follows that of the wrapped delegate.
 */
public class DescribedPredicate<T> implements Predicate<T> {
    private final String description;
    private final Predicate<T> delegate;

    /**
     * @param description human-readable text shown in diagnostic messages, must not be null
     * @param delegate the predicate to evaluate, must not be null
     * @throws NullPointerException if description or delegate is null
     */
    public DescribedPredicate(String description, Predicate<T> delegate) {
        // fail fast here instead of NPE-ing later inside test() or producing "null" diagnostics
        this.description = Objects.requireNonNull(description, "description");
        this.delegate = Objects.requireNonNull(delegate, "delegate");
    }

    /** Static factory, equivalent to the constructor. */
    public static <T> DescribedPredicate<T> of(String description, Predicate<T> delegate) {
        return new DescribedPredicate<>(description, delegate);
    }

    @Override
    public boolean test(T t) {
        return delegate.test(t);
    }

    /** Returns the human-readable description used in diagnostic messages. */
    public String getDescription() {
        return description;
    }

    @Override
    public String toString() {
        return description;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,76 @@ public boolean matchPredicates(TYPE root) {
return true;
}

/**
 * Diagnostic version of matchPlanTree. Returns null if match succeeds,
 * or a human-readable failure reason string if match fails.
 *
 * <p>Checks, in order: the root node type, the child count (unless the last
 * child pattern is MULTI), then either this node's predicates alone (for
 * ANY/MULTI and sub-tree patterns) or children recursively plus predicates.
 *
 * @param plan the plan to match against
 * @param path the current path in the tree (e.g. "root/child[0]/child[1]")
 * @return null if matched, or a diagnostic message describing where and why the match failed
 */
public String matchPlanTreeDiagnostic(Plan plan, String path) {
    if (!matchRoot(plan)) {
        return path + ": expected node type " + planType + " but got " + plan.getType();
    }
    int childPatternNum = arity();
    // an arity mismatch is tolerated when the last child pattern is MULTI
    // (presumably a variadic pattern that absorbs trailing children — mirrors matchPlanTree)
    if (childPatternNum != plan.arity() && childPatternNum > 0 && child(childPatternNum - 1) != MULTI) {
        return path + ": expected " + childPatternNum + " children but got " + plan.arity()
                + " on node " + plan.getType();
    }
    switch (patternType) {
        case ANY:
        case MULTI:
            // ANY/MULTI patterns do not constrain children; only this node's predicates apply
            return matchPredicatesDiagnostic((TYPE) plan, path);
        default:
    }
    if (this instanceof SubTreePattern) {
        // sub-tree patterns also skip per-child matching; only predicates are checked here
        return matchPredicatesDiagnostic((TYPE) plan, path);
    }
    return matchChildrenAndSelfPredicatesDiagnostic(plan, childPatternNum, path);
}

/**
 * Recursively checks every child plan against its corresponding child pattern,
 * then evaluates this node's own predicates.
 * Returns the first diagnostic failure encountered, or null when everything matches.
 */
private String matchChildrenAndSelfPredicatesDiagnostic(Plan plan, int childPatternNum, String path) {
    List<Plan> children = plan.children();
    int childCount = children.size();
    for (int index = 0; index < childCount; index++) {
        // clamp to the last child pattern so trailing children reuse it (MULTI-style tail)
        Pattern childPattern = child(Math.min(index, childPatternNum - 1));
        String currentPath = path + "/" + plan.getType() + ".child[" + index + "]";
        String failure = childPattern.matchPlanTreeDiagnostic(children.get(index), currentPath);
        if (failure != null) {
            return failure;
        }
    }
    // every child matched; finally check this node's own predicates
    return matchPredicatesDiagnostic((TYPE) plan, path);
}

/**
 * Diagnostic version of matchPredicates. Returns null if all predicates pass,
 * or a message describing which predicate failed or threw an exception.
 *
 * @param root the matched node to test the predicates against
 * @param path the current path in the tree, used as the message prefix
 * @return null on success, otherwise a diagnostic message
 */
public String matchPredicatesDiagnostic(TYPE root, String path) {
    for (int i = 0; i < predicates.size(); i++) {
        Predicate<TYPE> predicate = predicates.get(i);
        try {
            if (!predicate.test(root)) {
                return path + " (" + root.getType() + "): " + describePredicate(predicate, i)
                        + " failed. [predicate " + (i + 1) + "/" + predicates.size() + "]";
            }
        } catch (Throwable t) {
            // a throwing predicate is reported as a diagnostic rather than propagated,
            // so producing the failure explanation never aborts with an exception itself
            return path + " (" + root.getType() + "): " + describePredicate(predicate, i)
                    + " threw " + t.getClass().getSimpleName() + ": " + t.getMessage()
                    + " [predicate " + (i + 1) + "/" + predicates.size() + "]";
        }
    }
    return null;
}

/** Human-readable label for a predicate: its description when available, else its 1-based position. */
private String describePredicate(Predicate<TYPE> predicate, int index) {
    return predicate instanceof DescribedPredicate
            ? ((DescribedPredicate<TYPE>) predicate).getDescription()
            : "predicate #" + (index + 1);
}

@Override
public Pattern<? extends Plan> withChildren(
List<Pattern<? extends Plan>> children) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public PatternDescriptor<INPUT_TYPE> when(Predicate<INPUT_TYPE> predicate) {
return new PatternDescriptor<>(pattern.withPredicates(predicates), defaultPromise);
}

/**
 * Adds a predicate together with a human-readable description; the description
 * is surfaced in diagnostic messages when the predicate rejects a plan.
 */
public PatternDescriptor<INPUT_TYPE> when(String description, Predicate<INPUT_TYPE> predicate) {
    Predicate<INPUT_TYPE> described = DescribedPredicate.of(description, predicate);
    return when(described);
}

/** Negated form of {@code when}: the pattern matches only when the predicate is false. */
public PatternDescriptor<INPUT_TYPE> whenNot(Predicate<INPUT_TYPE> predicate) {
    Predicate<INPUT_TYPE> inverted = predicate.negate();
    return when(inverted);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,33 +130,34 @@ public List<Rule> buildRules() {
.toRule(RuleType.NORMALIZE_AGGREGATE));
}

/**
* The LogicalAggregate node may contain window agg functions and usual agg functions
* we call window agg functions as window-agg and usual agg functions as trivial-agg for short
* This rule simplify LogicalAggregate node by:
* 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
* 2. create a new LogicalAggregate with normalized group by exprs and trivial-aggs
* 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
* Push down exprs:
* 1. all group by exprs
* 2. child contains subquery expr in trivial-agg
* 3. child contains window expr in trivial-agg
* 4. all input slots of trivial-agg
* 5. expr(including subquery) in distinct trivial-agg
* Normalize LogicalAggregate's output.
* 1. normalize group by exprs by outputs of bottom LogicalProject
* 2. normalize trivial-aggs by outputs of bottom LogicalProject
* 3. build normalized agg outputs
* Pull up exprs:
* normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
* 1. simple slots
* 2. aliases
* a. alias with no aggs child
* b. alias with trivial-agg child
* c. alias with window-agg
*/
@SuppressWarnings("checkstyle:UnusedLocalVariable")
private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having,
public LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having,
CascadesContext ctx) {
// The LogicalAggregate node may contain window agg functions and usual agg functions
// we call window agg functions as window-agg and usual agg functions as trivial-agg for short
// This rule simplify LogicalAggregate node by:
// 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
// 2. create a new LogicalAggregate with normalized group by exprs and trivial-aggs
// 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
// Push down exprs:
// 1. all group by exprs
// 2. child contains subquery expr in trivial-agg
// 3. child contains window expr in trivial-agg
// 4. all input slots of trivial-agg
// 5. expr(including subquery) in distinct trivial-agg
// Normalize LogicalAggregate's output.
// 1. normalize group by exprs by outputs of bottom LogicalProject
// 2. normalize trivial-aggs by outputs of bottom LogicalProject
// 3. build normalized agg outputs
// Pull up exprs:
// normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
// 1. simple slots
// 2. aliases
// a. alias with no aggs child
// b. alias with trivial-agg child
// c. alias with window-agg

// Push down exprs:
// collect group by exprs
Set<Expression> groupingByExprs = Utils.fastToImmutableSet(aggregate.getGroupByExpressions());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
Expand Down Expand Up @@ -52,15 +50,14 @@
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.LinkedHashMap;
import java.util.List;
Expand All @@ -74,9 +71,6 @@
* So, we need add a rule to adjust all expression's nullable attribute after rewrite.
*/
public class AdjustNullable extends DefaultPlanRewriter<Map<ExprId, Slot>> implements CustomRewriter {

private static final Logger LOG = LogManager.getLogger(AdjustNullable.class);

private final boolean isAnalyzedPhase;

public AdjustNullable(boolean isAnalyzedPhase) {
Expand Down Expand Up @@ -485,14 +479,9 @@ private static Expression doUpdateExpression(AtomicBoolean changed, Expression i
// repeat may check fail.
if (!slotReference.nullable() && newSlotReference.nullable()
&& check && ConnectContext.get() != null) {
if (ConnectContext.get().getSessionVariable().feDebug) {
throw new AnalysisException("AdjustNullable convert slot " + slotReference
+ " from not-nullable to nullable. You can disable check by set fe_debug = false.");
} else {
LOG.warn("adjust nullable convert slot '" + slotReference
+ "' from not-nullable to nullable for query "
+ DebugUtil.printId(ConnectContext.get().queryId()));
}
SessionVariable.throwAnalysisExceptionWhenFeDebug("AdjustNullable convert slot "
+ slotReference
+ " from not-nullable to nullable. You can disable check by set fe_debug = false.");
}
return newSlotReference;
} else {
Expand Down
Loading