From 3a1b6a2c81ff214a23b35a4447b5c88ed2655848 Mon Sep 17 00:00:00 2001
From: "xiyu.zk"
Date: Wed, 28 Jan 2026 22:46:36 +0800
Subject: [PATCH] [spark] Add MergeIntoTable alignment methods for v2 write

Columns without an assignment are filled with their default value (or NULL)
only for InsertAction on the v2 write path. UpdateAction always keeps them
as-is, because an UPDATE must never reset columns it does not assign.

---
 .../analysis/AssignmentAlignmentHelper.scala  | 122 +++++++++++-
 .../sql/AssignmentAlignmentHelperTest.scala   | 172 ++++++++++++++++++
 2 files changed, 285 insertions(+), 9 deletions(-)
 create mode 100644 paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/AssignmentAlignmentHelperTest.scala

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/AssignmentAlignmentHelper.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/AssignmentAlignmentHelper.scala
index 86c6881aa4a1..ed03d96bfe56 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/AssignmentAlignmentHelper.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/catalyst/analysis/AssignmentAlignmentHelper.scala
@@ -18,11 +18,13 @@
 
 package org.apache.paimon.spark.catalyst.analysis
 
+import org.apache.paimon.spark.SparkTypeUtils.CURRENT_DEFAULT_COLUMN_METADATA_KEY
 import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper
 
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateNamedStruct, Expression, GetStructField, Literal, NamedExpression}
-import org.apache.spark.sql.catalyst.plans.logical.Assignment
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, InsertStarAction, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction}
 import org.apache.spark.sql.types.StructType
 
 trait AssignmentAlignmentHelper extends SQLConfHelper with ExpressionHelper {
@@ -50,31 +52,101 @@ trait AssignmentAlignmentHelper extends SQLConfHelper with ExpressionHelper {
    */
   protected def generateAlignedExpressions(
       attrs: Seq[Attribute],
-      assignments: Seq[Assignment]): Seq[Expression] = {
+      assignments: Seq[Assignment],
+      v2Write: Boolean = false): Seq[Expression] = {
     val attrUpdates = assignments.map(a => AttrUpdate(toRefSeq(a.key), a.value))
-    recursiveAlignUpdates(attrs, attrUpdates)
+    recursiveAlignUpdates(attrs, attrUpdates, Nil, v2Write)
   }
 
+  /**
+   * Align `assignments` to `attrs`, producing exactly one assignment per attribute.
+   *
+   * @param v2Write
+   *   when true, an attribute without an assignment gets its default value or NULL instead of
+   *   being assigned to itself; this is only meaningful for inserts
+   */
   protected def alignAssignments(
       attrs: Seq[Attribute],
-      assignments: Seq[Assignment]): Seq[Assignment] = {
-    generateAlignedExpressions(attrs, assignments).zip(attrs).map {
+      assignments: Seq[Assignment],
+      v2Write: Boolean = false): Seq[Assignment] = {
+    generateAlignedExpressions(attrs, assignments, v2Write).zip(attrs).map {
       case (expression, field) => Assignment(field, expression)
     }
   }
 
+  /**
+   * Align assignments in a MergeAction based on the target table's output attributes.
+   *   - DeleteAction: returned as-is
+   *   - UpdateAction: aligned so that columns without an assignment keep their current value,
+   *     regardless of `v2Write` (an UPDATE must never reset unassigned columns)
+   *   - InsertAction: aligned for insert; when `v2Write` is true, columns without an assignment
+   *     are filled with their default value or NULL
+   *
+   * @param action
+   *   the merge action to align
+   * @param targetOutput
+   *   the target table's output attributes
+   * @param v2Write
+   *   whether the plan is built for the v2 write path; only affects [[InsertAction]]
+   * @return
+   *   the aligned action, carrying exactly one assignment per target attribute
+   */
+  protected def alignMergeAction(
+      action: MergeAction,
+      targetOutput: Seq[Attribute],
+      v2Write: Boolean): MergeAction = {
+    action match {
+      case d @ DeleteAction(_) => d
+      case u @ UpdateAction(_, assignments) =>
+        // An UPDATE keeps the current value of every unassigned column, so the default-or-NULL
+        // filling of the v2 write path must never be applied to it.
+        u.copy(assignments = alignAssignments(targetOutput, assignments))
+      case i @ InsertAction(_, assignments) =>
+        i.copy(assignments = alignAssignments(targetOutput, assignments, v2Write))
+      case _: UpdateStarAction =>
+        throw new RuntimeException("UpdateStarAction should not be here.")
+      case _: InsertStarAction =>
+        throw new RuntimeException("InsertStarAction should not be here.")
+      case _ =>
+        throw new RuntimeException(s"Can't recognize this action: $action")
+    }
+  }
+
+  /**
+   * Align all MergeActions in a MergeIntoTable based on the target table's output attributes.
+   * Returns a new MergeIntoTable with aligned matchedActions, notMatchedActions, and
+   * notMatchedBySourceActions. `v2Write` is forwarded to [[alignMergeAction]] for every action.
+   */
+  protected def alignMergeIntoTable(
+      m: MergeIntoTable,
+      targetOutput: Seq[Attribute],
+      v2Write: Boolean): MergeIntoTable = {
+    m.copy(
+      matchedActions = m.matchedActions.map(alignMergeAction(_, targetOutput, v2Write)),
+      notMatchedActions = m.notMatchedActions.map(alignMergeAction(_, targetOutput, v2Write)),
+      notMatchedBySourceActions =
+        m.notMatchedBySourceActions.map(alignMergeAction(_, targetOutput, v2Write))
+    )
+  }
+
   private def recursiveAlignUpdates(
       targetAttrs: Seq[NamedExpression],
       updates: Seq[AttrUpdate],
-      namePrefix: Seq[String] = Nil): Seq[Expression] = {
+      namePrefix: Seq[String] = Nil,
+      v2Write: Boolean = false): Seq[Expression] = {
 
     // build aligned updated expression for each target attr
     targetAttrs.map {
       targetAttr =>
         val headMatchedUpdates = updates.filter(u => resolver(u.ref.head, targetAttr.name))
         if (headMatchedUpdates.isEmpty) {
-          // when no matched update, return the attr as is
-          targetAttr
+          if (v2Write) {
+            // For V2Write, use default value or NULL for missing columns
+            getDefaultValueOrNull(targetAttr)
+          } else {
+            // For V1Write, return the attr as is
+            targetAttr
+          }
         } else {
           val exactMatchedUpdate = headMatchedUpdates.find(_.ref.size == 1)
           if (exactMatchedUpdate.isDefined) {
@@ -101,7 +173,11 @@ trait AssignmentAlignmentHelper extends SQLConfHelper with ExpressionHelper {
                 val newUpdates = updates.map(u => u.copy(ref = u.ref.tail))
                 // process StructType's nested fields recursively
                 val updatedFieldExprs =
-                  recursiveAlignUpdates(fieldExprs, newUpdates, namePrefix :+ targetAttr.name)
+                  recursiveAlignUpdates(
+                    fieldExprs,
+                    newUpdates,
+                    namePrefix :+ targetAttr.name,
+                    v2Write)
 
                 // build updated struct expression
                 CreateNamedStruct(fields.zip(updatedFieldExprs).flatMap {
@@ -117,4 +193,32 @@ trait AssignmentAlignmentHelper extends SQLConfHelper with ExpressionHelper {
     }
   }
 
+  /** Get the default value expression for an attribute, or NULL if no default value is defined. */
+  private def getDefaultValueOrNull(attr: NamedExpression): Expression = {
+    attr match {
+      case a: Attribute if a.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+        val defaultValueStr = a.metadata.getString(CURRENT_DEFAULT_COLUMN_METADATA_KEY)
+        parseAndResolveDefaultValue(defaultValueStr, a)
+      case _ =>
+        Literal(null, attr.dataType)
+    }
+  }
+
+  /**
+   * Parse the default value string with the active session's SQL parser and pass the result
+   * through `castIfNeeded` for the attribute's data type. Falls back to a typed NULL literal when
+   * that fails with a non-fatal error.
+   */
+  private def parseAndResolveDefaultValue(defaultValueStr: String, attr: Attribute): Expression = {
+    try {
+      val spark = SparkSession.active
+      val parsed = spark.sessionState.sqlParser.parseExpression(defaultValueStr)
+      castIfNeeded(parsed, attr.dataType)
+    } catch {
+      // Best-effort fallback to NULL; fatal errors (OOM, interruption, ...) still propagate
+      case scala.util.control.NonFatal(_) =>
+        Literal(null, attr.dataType)
+    }
+  }
+
 }
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/AssignmentAlignmentHelperTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/AssignmentAlignmentHelperTest.scala
new file mode 100644
index 000000000000..422c1a581897
--- /dev/null
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/AssignmentAlignmentHelperTest.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+import org.apache.paimon.spark.catalyst.analysis.AssignmentAlignmentHelper
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, MergeIntoTable, UpdateAction}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+/**
+ * Test suite for [[AssignmentAlignmentHelper]] methods:
+ *   - alignMergeAction (with v2Write parameter)
+ */
+class AssignmentAlignmentHelperTest extends PaimonSparkTestBase with AssignmentAlignmentHelper {
+
+  test("alignMergeAction: DeleteAction with v2Write=false should remain unchanged") {
+    val condition = Some(Literal(true))
+    val deleteAction = DeleteAction(condition)
+
+    val targetOutput = Seq(
+      AttributeReference("a", IntegerType)(),
+      AttributeReference("b", IntegerType)(),
+      AttributeReference("c", StringType)()
+    )
+
+    val aligned = alignMergeAction(deleteAction, targetOutput, v2Write = false)
+
+    assert(aligned.isInstanceOf[DeleteAction])
+    assert(aligned.asInstanceOf[DeleteAction].condition == condition)
+  }
+
+  test("alignMergeAction: DeleteAction with v2Write=true should remain unchanged") {
+    val condition = Some(Literal(true))
+    val deleteAction = DeleteAction(condition)
+
+    val targetOutput = Seq(
+      AttributeReference("a", IntegerType)(),
+      AttributeReference("b", IntegerType)(),
+      AttributeReference("c", StringType)()
+    )
+
+    val aligned = alignMergeAction(deleteAction, targetOutput, v2Write = true)
+
+    assert(aligned.isInstanceOf[DeleteAction])
+    assert(aligned.asInstanceOf[DeleteAction].condition == condition)
+  }
+
+  test("alignMergeAction: UpdateAction with v2Write=false should keep missing columns as-is") {
+    val targetA = AttributeReference("a", IntegerType)()
+    val targetB = AttributeReference("b", IntegerType)()
+    val targetC = AttributeReference("c", StringType)()
+    val targetOutput = Seq(targetA, targetB, targetC)
+
+    // Only update column 'a', 'b' and 'c' should be kept as is
+    val assignments = Seq(Assignment(targetA, Literal(100)))
+    val updateAction = UpdateAction(None, assignments)
+
+    val aligned = alignMergeAction(updateAction, targetOutput, v2Write = false)
+
+    assert(aligned.isInstanceOf[UpdateAction])
+    val alignedAssignments = aligned.asInstanceOf[UpdateAction].assignments
+    assert(alignedAssignments.size == 3)
+    // a = 100
+    assert(alignedAssignments(0).key.asInstanceOf[AttributeReference].name == "a")
+    assert(alignedAssignments(0).value.isInstanceOf[Literal])
+    // b = b (unchanged, keeps original attribute)
+    assert(alignedAssignments(1).key.asInstanceOf[AttributeReference].name == "b")
+    assert(alignedAssignments(1).value.isInstanceOf[AttributeReference])
+    // c = c (unchanged, keeps original attribute)
+    assert(alignedAssignments(2).key.asInstanceOf[AttributeReference].name == "c")
+    assert(alignedAssignments(2).value.isInstanceOf[AttributeReference])
+  }
+
+  test("alignMergeAction: UpdateAction with v2Write=true should keep missing columns as-is") {
+    val targetA = AttributeReference("a", IntegerType)()
+    val targetB = AttributeReference("b", IntegerType)()
+    val targetC = AttributeReference("c", StringType)()
+    val targetOutput = Seq(targetA, targetB, targetC)
+
+    // Only update column 'a'; an UPDATE must preserve 'b' and 'c' even on the v2 write path
+    val assignments = Seq(Assignment(targetA, Literal(100)))
+    val updateAction = UpdateAction(None, assignments)
+
+    val aligned = alignMergeAction(updateAction, targetOutput, v2Write = true)
+
+    assert(aligned.isInstanceOf[UpdateAction])
+    val alignedAssignments = aligned.asInstanceOf[UpdateAction].assignments
+    assert(alignedAssignments.size == 3)
+    // a = 100
+    assert(alignedAssignments(0).key.asInstanceOf[AttributeReference].name == "a")
+    assert(alignedAssignments(0).value.isInstanceOf[Literal])
+    assert(alignedAssignments(0).value.asInstanceOf[Literal].value == 100)
+    // b = b (unchanged, must not be reset to NULL)
+    assert(alignedAssignments(1).key.asInstanceOf[AttributeReference].name == "b")
+    assert(alignedAssignments(1).value == targetB)
+    // c = c (unchanged, must not be reset to NULL)
+    assert(alignedAssignments(2).key.asInstanceOf[AttributeReference].name == "c")
+    assert(alignedAssignments(2).value == targetC)
+  }
+
+  test("alignMergeAction: InsertAction with v2Write=false should keep missing columns as-is") {
+    val targetA = AttributeReference("a", IntegerType)()
+    val targetB = AttributeReference("b", IntegerType)()
+    val targetC = AttributeReference("c", StringType)()
+    val targetOutput = Seq(targetA, targetB, targetC)
+
+    val sourceA = AttributeReference("a", IntegerType)()
+    val assignments = Seq(Assignment(targetA, sourceA))
+    val insertAction = InsertAction(None, assignments)
+
+    val aligned = alignMergeAction(insertAction, targetOutput, v2Write = false)
+
+    assert(aligned.isInstanceOf[InsertAction])
+    val alignedAssignments = aligned.asInstanceOf[InsertAction].assignments
+    assert(alignedAssignments.size == 3)
+    // a = source.a
+    assert(alignedAssignments(0).key.asInstanceOf[AttributeReference].name == "a")
+    // b = b (unchanged, keeps original attribute)
+    assert(alignedAssignments(1).key.asInstanceOf[AttributeReference].name == "b")
+    assert(alignedAssignments(1).value.isInstanceOf[AttributeReference])
+    // c = c (unchanged, keeps original attribute)
+    assert(alignedAssignments(2).key.asInstanceOf[AttributeReference].name == "c")
+    assert(alignedAssignments(2).value.isInstanceOf[AttributeReference])
+  }
+
+  test("alignMergeAction: InsertAction with v2Write=true should use NULL for missing columns") {
+    val targetA = AttributeReference("a", IntegerType)()
+    val targetB = AttributeReference("b", IntegerType)()
+    val targetC = AttributeReference("c", StringType)()
+    val targetOutput = Seq(targetA, targetB, targetC)
+
+    // Only insert column 'a', 'b' and 'c' should be NULL
+    val sourceA = AttributeReference("a", IntegerType)()
+    val assignments = Seq(Assignment(targetA, sourceA))
+    val insertAction = InsertAction(None, assignments)
+
+    val aligned = alignMergeAction(insertAction, targetOutput, v2Write = true)
+
+    assert(aligned.isInstanceOf[InsertAction])
+    val alignedAssignments = aligned.asInstanceOf[InsertAction].assignments
+    assert(alignedAssignments.size == 3)
+    // a = source.a
+    assert(alignedAssignments(0).key.asInstanceOf[AttributeReference].name == "a")
+    // b = NULL (v2Write mode)
+    assert(alignedAssignments(1).key.asInstanceOf[AttributeReference].name == "b")
+    assert(alignedAssignments(1).value.isInstanceOf[Literal])
+    assert(alignedAssignments(1).value.asInstanceOf[Literal].value == null)
+    // c = NULL (v2Write mode)
+    assert(alignedAssignments(2).key.asInstanceOf[AttributeReference].name == "c")
+    assert(alignedAssignments(2).value.isInstanceOf[Literal])
+    assert(alignedAssignments(2).value.asInstanceOf[Literal].value == null)
+  }
+}