Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

package org.apache.paimon.spark.catalyst.analysis

import org.apache.paimon.spark.SparkTypeUtils.CURRENT_DEFAULT_COLUMN_METADATA_KEY
import org.apache.paimon.spark.catalyst.analysis.expressions.ExpressionHelper

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateNamedStruct, Expression, GetStructField, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, InsertStarAction, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction}
import org.apache.spark.sql.types.StructType

trait AssignmentAlignmentHelper extends SQLConfHelper with ExpressionHelper {
Expand Down Expand Up @@ -50,31 +52,81 @@ trait AssignmentAlignmentHelper extends SQLConfHelper with ExpressionHelper {
*/
protected def generateAlignedExpressions(
    attrs: Seq[Attribute],
    assignments: Seq[Assignment],
    v2Write: Boolean = false): Seq[Expression] = {
  // Each assignment becomes an update made of the key's reference sequence and its value.
  val attrUpdates = assignments.map(a => AttrUpdate(toRefSeq(a.key), a.value))
  // Alignment starts with an empty name prefix; `v2Write` is forwarded unchanged.
  recursiveAlignUpdates(attrs, attrUpdates, Nil, v2Write)
}

/**
 * Aligns the given assignments to `attrs` and returns one assignment per attribute.
 *
 * @param attrs attributes to align against; the result pairs each of them, in order, with an
 *              expression produced by `generateAlignedExpressions`
 * @param assignments the assignments to align
 * @param v2Write forwarded unchanged to `generateAlignedExpressions`
 * @return assignments whose keys are the elements of `attrs`
 */
protected def alignAssignments(
    attrs: Seq[Attribute],
    assignments: Seq[Assignment],
    v2Write: Boolean = false): Seq[Assignment] = {
  // Pair every aligned expression with the attribute at the same position.
  generateAlignedExpressions(attrs, assignments, v2Write).zip(attrs).map {
    case (expression, field) => Assignment(field, expression)
  }
}

/**
 * Rewrites a single MergeAction so that its assignments line up with the target table's output
 * attributes.
 *   - DeleteAction: handed back untouched
 *   - UpdateAction / InsertAction: assignments are replaced by their aligned counterparts
 *   - UpdateStarAction / InsertStarAction / anything else: rejected with a RuntimeException
 */
protected def alignMergeAction(
    action: MergeAction,
    targetOutput: Seq[Attribute],
    v2Write: Boolean): MergeAction = {
  // Shared by the update and insert branches below.
  def realign(assignments: Seq[Assignment]): Seq[Assignment] =
    alignAssignments(targetOutput, assignments, v2Write)

  action match {
    case delete: DeleteAction => delete
    case update: UpdateAction => update.copy(assignments = realign(update.assignments))
    case insert: InsertAction => insert.copy(assignments = realign(insert.assignments))
    case _: UpdateStarAction =>
      throw new RuntimeException("UpdateStarAction should not be here.")
    case _: InsertStarAction =>
      throw new RuntimeException("InsertStarAction should not be here.")
    case _ =>
      throw new RuntimeException(s"Can't recognize this action: $action")
  }
}

/**
 * Produces a copy of the given MergeIntoTable in which every action of `matchedActions`,
 * `notMatchedActions` and `notMatchedBySourceActions` has been passed through
 * `alignMergeAction` with the supplied target output attributes and `v2Write` flag.
 */
protected def alignMergeIntoTable(
    m: MergeIntoTable,
    targetOutput: Seq[Attribute],
    v2Write: Boolean): MergeIntoTable = {
  // One mapping function reused for all three action lists.
  val alignAll: Seq[MergeAction] => Seq[MergeAction] =
    actions => actions.map(action => alignMergeAction(action, targetOutput, v2Write))

  m.copy(
    matchedActions = alignAll(m.matchedActions),
    notMatchedActions = alignAll(m.notMatchedActions),
    notMatchedBySourceActions = alignAll(m.notMatchedBySourceActions)
  )
}

private def recursiveAlignUpdates(
targetAttrs: Seq[NamedExpression],
updates: Seq[AttrUpdate],
namePrefix: Seq[String] = Nil): Seq[Expression] = {
namePrefix: Seq[String] = Nil,
v2Write: Boolean = false): Seq[Expression] = {

// build aligned updated expression for each target attr
targetAttrs.map {
targetAttr =>
val headMatchedUpdates = updates.filter(u => resolver(u.ref.head, targetAttr.name))
if (headMatchedUpdates.isEmpty) {
// when no matched update, return the attr as is
targetAttr
if (v2Write) {
// For V2Write, use default value or NULL for missing columns
getDefaultValueOrNull(targetAttr)
} else {
// For V1Write, return the attr as is
targetAttr
}
} else {
val exactMatchedUpdate = headMatchedUpdates.find(_.ref.size == 1)
if (exactMatchedUpdate.isDefined) {
Expand All @@ -101,7 +153,11 @@ trait AssignmentAlignmentHelper extends SQLConfHelper with ExpressionHelper {
val newUpdates = updates.map(u => u.copy(ref = u.ref.tail))
// process StructType's nested fields recursively
val updatedFieldExprs =
recursiveAlignUpdates(fieldExprs, newUpdates, namePrefix :+ targetAttr.name)
recursiveAlignUpdates(
fieldExprs,
newUpdates,
namePrefix :+ targetAttr.name,
v2Write)

// build updated struct expression
CreateNamedStruct(fields.zip(updatedFieldExprs).flatMap {
Expand All @@ -117,4 +173,28 @@ trait AssignmentAlignmentHelper extends SQLConfHelper with ExpressionHelper {
}
}

/**
 * Returns the default value expression of an attribute whose metadata carries
 * CURRENT_DEFAULT_COLUMN_METADATA_KEY, or a NULL literal of the attribute's data type otherwise.
 */
private def getDefaultValueOrNull(attr: NamedExpression): Expression = {
  // Only Attributes are inspected for a default; any other NamedExpression yields None.
  val declaredDefault = attr match {
    case a: Attribute if a.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
      val defaultText = a.metadata.getString(CURRENT_DEFAULT_COLUMN_METADATA_KEY)
      Some(parseAndResolveDefaultValue(defaultText, a))
    case _ =>
      None
  }
  declaredDefault.getOrElse(Literal(null, attr.dataType))
}

/**
 * Turns the textual default value of a column into an expression of the column's data type.
 *
 * The text is parsed with the active session's SQL parser and the result is passed through
 * `castIfNeeded` together with the attribute's data type. NOTE(review): despite the method name,
 * nothing here runs analysis on the parsed expression; presumably a later analyzer pass resolves
 * it -- confirm.
 *
 * @param defaultValueStr SQL text of the default value
 * @param attr attribute whose data type is handed to `castIfNeeded`
 * @return the expression returned by `castIfNeeded`, or a NULL literal of the attribute's data
 *         type when obtaining the session, parsing or casting fails with a non-fatal error
 */
private def parseAndResolveDefaultValue(defaultValueStr: String, attr: Attribute): Expression = {
  import scala.util.control.NonFatal

  try {
    val parser = SparkSession.active.sessionState.sqlParser
    castIfNeeded(parser.parseExpression(defaultValueStr), attr.dataType)
  } catch {
    // Best-effort fallback to NULL, restricted to non-fatal failures: a blanket
    // `case _: Exception` would also swallow InterruptedException and hide thread interruption.
    case NonFatal(_) =>
      Literal(null, attr.dataType)
  }
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark.sql

import org.apache.paimon.spark.PaimonSparkTestBase
import org.apache.paimon.spark.catalyst.analysis.AssignmentAlignmentHelper

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, MergeIntoTable, UpdateAction}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Unit tests for [[AssignmentAlignmentHelper]]: `alignMergeAction` is exercised with both values
 * of the `v2Write` flag for delete, update and insert actions.
 */
class AssignmentAlignmentHelperTest extends PaimonSparkTestBase with AssignmentAlignmentHelper {

  /** Fresh target columns `a: Int`, `b: Int`, `c: String`, in that order. */
  private def newTargetColumns(): Seq[AttributeReference] = Seq(
    AttributeReference("a", IntegerType)(),
    AttributeReference("b", IntegerType)(),
    AttributeReference("c", StringType)()
  )

  /** Aligns a conditional DeleteAction and checks that a DeleteAction with the same condition comes back. */
  private def assertDeleteUntouched(v2Write: Boolean): Unit = {
    val condition = Some(Literal(true))
    alignMergeAction(DeleteAction(condition), newTargetColumns(), v2Write) match {
      case delete: DeleteAction => assert(delete.condition == condition)
      case other => fail(s"Expected a DeleteAction but got: $other")
    }
  }

  /** Aligns `UPDATE SET a = 100` against the three target columns and returns the result. */
  private def alignUpdateOfA(v2Write: Boolean): Seq[Assignment] = {
    val targetOutput = newTargetColumns()
    val updateAction = UpdateAction(None, Seq(Assignment(targetOutput.head, Literal(100))))
    alignMergeAction(updateAction, targetOutput, v2Write) match {
      case aligned: UpdateAction => aligned.assignments
      case other => fail(s"Expected an UpdateAction but got: $other")
    }
  }

  /** Aligns `INSERT (a) VALUES (source.a)` against the three target columns and returns the result. */
  private def alignInsertOfA(v2Write: Boolean): Seq[Assignment] = {
    val targetOutput = newTargetColumns()
    val sourceA = AttributeReference("a", IntegerType)()
    val insertAction = InsertAction(None, Seq(Assignment(targetOutput.head, sourceA)))
    alignMergeAction(insertAction, targetOutput, v2Write) match {
      case aligned: InsertAction => aligned.assignments
      case other => fail(s"Expected an InsertAction but got: $other")
    }
  }

  /** Name of the attribute an assignment writes to. */
  private def keyName(assignment: Assignment): String = assignment.key match {
    case ref: AttributeReference => ref.name
    case other => fail(s"Expected an AttributeReference key but got: $other")
  }

  /** Checks that the assigned value is a Literal and hands it back. */
  private def literalValueOf(assignment: Assignment): Literal = assignment.value match {
    case literal: Literal => literal
    case other => fail(s"Expected a Literal value but got: $other")
  }

  /** Checks that the column is assigned an attribute reference. */
  private def assertKeptAsAttribute(assignment: Assignment, column: String): Unit = {
    assert(keyName(assignment) == column)
    assert(assignment.value.isInstanceOf[AttributeReference])
  }

  /** Checks that the column is assigned a NULL literal. */
  private def assertSetToNull(assignment: Assignment, column: String): Unit = {
    assert(keyName(assignment) == column)
    assert(literalValueOf(assignment).value == null)
  }

  test("alignMergeAction: DeleteAction with v2Write=false should remain unchanged") {
    assertDeleteUntouched(v2Write = false)
  }

  test("alignMergeAction: DeleteAction with v2Write=true should remain unchanged") {
    assertDeleteUntouched(v2Write = true)
  }

  test("alignMergeAction: UpdateAction with v2Write=false should keep missing columns as-is") {
    // Only column 'a' is updated; 'b' and 'c' should be kept as is
    val alignedAssignments = alignUpdateOfA(v2Write = false)
    assert(alignedAssignments.size == 3)

    // a = 100
    assert(keyName(alignedAssignments(0)) == "a")
    literalValueOf(alignedAssignments(0))
    // b = b and c = c (unchanged, keep original attribute)
    assertKeptAsAttribute(alignedAssignments(1), "b")
    assertKeptAsAttribute(alignedAssignments(2), "c")
  }

  test("alignMergeAction: UpdateAction with v2Write=true should use NULL for missing columns") {
    // Only column 'a' is updated; 'b' and 'c' should be NULL
    val alignedAssignments = alignUpdateOfA(v2Write = true)
    assert(alignedAssignments.size == 3)

    // a = 100
    assert(keyName(alignedAssignments(0)) == "a")
    assert(literalValueOf(alignedAssignments(0)).value == 100)
    // b = NULL and c = NULL (v2Write mode)
    assertSetToNull(alignedAssignments(1), "b")
    assertSetToNull(alignedAssignments(2), "c")
  }

  test("alignMergeAction: InsertAction with v2Write=false should keep missing columns as-is") {
    val alignedAssignments = alignInsertOfA(v2Write = false)
    assert(alignedAssignments.size == 3)

    // a = source.a
    assert(keyName(alignedAssignments(0)) == "a")
    // b = b and c = c (unchanged, keep original attribute)
    assertKeptAsAttribute(alignedAssignments(1), "b")
    assertKeptAsAttribute(alignedAssignments(2), "c")
  }

  test("alignMergeAction: InsertAction with v2Write=true should use NULL for missing columns") {
    // Only column 'a' is inserted; 'b' and 'c' should be NULL
    val alignedAssignments = alignInsertOfA(v2Write = true)
    assert(alignedAssignments.size == 3)

    // a = source.a
    assert(keyName(alignedAssignments(0)) == "a")
    // b = NULL and c = NULL (v2Write mode)
    assertSetToNull(alignedAssignments(1), "b")
    assertSetToNull(alignedAssignments(2), "c")
  }
}