Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,27 @@ public interface TableSemantics {
*/
int[] partitionByColumns();

/**
 * Returns the upsert key of the passed table.
 *
 * <p>The key is derived by the planner from primary key constraints and the rewritten relational
 * plan. It uniquely identifies a row within the input changelog and is kept across planner
 * transformations that preserve key semantics (for example filters, or projections that retain
 * the key columns). Table arguments with row semantics and table arguments with set semantics
 * are both covered.
 *
 * <p>This method complements {@link #partitionByColumns()}: because the planner already knows
 * the key from the input table's declaration, a caller does not have to repeat the primary key
 * in {@code PARTITION BY} merely to let a PTF identify rows.
 *
 * <p>The default implementation reports that no upsert key is available.
 *
 * @return An array of 0-based indexes of the upsert key columns. The array is empty if the
 *     planner could not derive an upsert key for the input (e.g., append-only sources without a
 *     declared primary key, or operations that destroyed the key). It is also empty during the
 *     type inference phase, because the upsert key is still unknown at that point.
 */
default int[] upsertKeyColumns() {
    // A fresh empty array per call signals "no derivable upsert key".
    final int[] noUpsertKey = {};
    return noUpsertKey;
}

/**
* Returns information about how the passed table is ordered. Applies only to table arguments
* with set semantics.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ public class TableSemanticsMock implements TableSemantics {
private final SortDirection[] orderByDirections;
private final int timeColumn;
private final ChangelogMode changelogMode;
private final int[] upsertKeyColumns;

public TableSemanticsMock(DataType dataType) {
this(dataType, new int[0], new int[0], -1, null);
this(dataType, new int[0], new int[0], -1, null, new int[0]);
}

public TableSemanticsMock(
Expand All @@ -46,6 +47,16 @@ public TableSemanticsMock(
int[] orderByColumns,
int timeColumn,
@Nullable ChangelogMode changelogMode) {
this(dataType, partitionByColumns, orderByColumns, timeColumn, changelogMode, new int[0]);
}

public TableSemanticsMock(
DataType dataType,
int[] partitionByColumns,
int[] orderByColumns,
int timeColumn,
@Nullable ChangelogMode changelogMode,
int[] upsertKeyColumns) {
this.dataType = dataType;
this.partitionByColumns = partitionByColumns;
this.orderByColumns = orderByColumns;
Expand All @@ -55,6 +66,7 @@ public TableSemanticsMock(
}
this.timeColumn = timeColumn;
this.changelogMode = changelogMode;
this.upsertKeyColumns = upsertKeyColumns;
}

@Override
Expand Down Expand Up @@ -82,6 +94,11 @@ public int timeColumn() {
return timeColumn;
}

@Override
public int[] upsertKeyColumns() {
    // Hands out the array received by the constructor as-is (no defensive copy).
    return this.upsertKeyColumns;
}

@Override
public Optional<ChangelogMode> changelogMode() {
return Optional.ofNullable(changelogMode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,27 +335,29 @@ public boolean hasScalarArgument(String name) {
* scalar arguments through the same coercion path as validation.
*/
public CallContext toCallContext(RexCall call) {
return toCallContext(call, null, null, null);
return toCallContext(call, null, null, null, null);
}

/**
* Variant of {@link #toCallContext(RexCall)} that additionally exposes the call's input time
* columns and changelog modes - needed by the streaming codegen path so PTFs can specialize
* themselves to the exact call.
* columns, changelog modes, and per-input upsert keys - needed by the streaming codegen path so
* PTFs can specialize themselves to the exact call.
*/
public CallContext toCallContext(
RexCall call,
@Nullable List<Integer> inputTimeColumns,
@Nullable List<ChangelogMode> inputChangelogModes,
@Nullable ChangelogMode outputChangelogMode) {
@Nullable ChangelogMode outputChangelogMode,
@Nullable List<int[]> inputUpsertKeys) {
return new OperatorBindingCallContext(
dataTypeFactory,
getDefinition(),
RexCallBinding.create(typeFactory, call, Collections.emptyList()),
call.getType(),
inputTimeColumns,
inputChangelogModes,
outputChangelogMode);
outputChangelogMode,
inputUpsertKeys);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,14 @@ public final class OperatorBindingCallContext extends AbstractSqlCallContext {
private final @Nullable List<Integer> inputTimeColumns;
private final @Nullable List<ChangelogMode> inputChangelogModes;
private final @Nullable ChangelogMode outputChangelogMode;
private final @Nullable List<int[]> inputUpsertKeys;

public OperatorBindingCallContext(
DataTypeFactory dataTypeFactory,
FunctionDefinition definition,
SqlOperatorBinding binding,
RelDataType returnRelDataType) {
this(dataTypeFactory, definition, binding, returnRelDataType, null, null, null);
this(dataTypeFactory, definition, binding, returnRelDataType, null, null, null, null);
}

public OperatorBindingCallContext(
Expand All @@ -80,7 +81,8 @@ public OperatorBindingCallContext(
RelDataType returnRelDataType,
@Nullable List<Integer> inputTimeColumns,
@Nullable List<ChangelogMode> inputChangelogModes,
@Nullable ChangelogMode outputChangelogMode) {
@Nullable ChangelogMode outputChangelogMode,
@Nullable List<int[]> inputUpsertKeys) {
super(
dataTypeFactory,
definition,
Expand Down Expand Up @@ -109,6 +111,7 @@ public int size() {
this.inputTimeColumns = inputTimeColumns;
this.inputChangelogModes = inputChangelogModes;
this.outputChangelogMode = outputChangelogMode;
this.inputUpsertKeys = inputUpsertKeys;
}

@Override
Expand Down Expand Up @@ -173,13 +176,18 @@ public Optional<TableSemantics> getTableSemantics(int pos) {
Optional.ofNullable(inputChangelogModes)
.map(m -> m.get(tableArgCall.getInputIndex()))
.orElse(null);
final int[] upsertKeyColumns =
Optional.ofNullable(inputUpsertKeys)
.map(m -> m.get(tableArgCall.getInputIndex()))
.orElse(new int[0]);
return Optional.of(
OperatorBindingTableSemantics.create(
argumentDataTypes.get(pos),
staticArg,
tableArgCall,
timeColumn,
changelogMode));
changelogMode,
upsertKeyColumns));
}

@Override
Expand Down Expand Up @@ -283,20 +291,23 @@ private static class OperatorBindingTableSemantics implements TableSemantics {
private final SortDirection[] orderByDirections;
private final int timeColumn;
private final @Nullable ChangelogMode changelogMode;
private final int[] upsertKeyColumns;

public static OperatorBindingTableSemantics create(
DataType tableDataType,
StaticArgument staticArg,
RexTableArgCall tableArgCall,
int timeColumn,
@Nullable ChangelogMode changelogMode) {
@Nullable ChangelogMode changelogMode,
int[] upsertKeyColumns) {
return new OperatorBindingTableSemantics(
createDataType(tableDataType, staticArg),
tableArgCall.getPartitionKeys(),
tableArgCall.getOrderKeys(),
RexTableArgCall.toSortDirections(tableArgCall.getSortOrder()),
timeColumn,
changelogMode);
changelogMode,
upsertKeyColumns);
}

private OperatorBindingTableSemantics(
Expand All @@ -305,13 +316,15 @@ private OperatorBindingTableSemantics(
int[] orderByColumns,
SortDirection[] orderByDirections,
int timeColumn,
@Nullable ChangelogMode changelogMode) {
@Nullable ChangelogMode changelogMode,
int[] upsertKeyColumns) {
this.dataType = dataType;
this.partitionByColumns = partitionByColumns;
this.orderByColumns = orderByColumns;
this.orderByDirections = orderByDirections;
this.timeColumn = timeColumn;
this.changelogMode = changelogMode;
this.upsertKeyColumns = upsertKeyColumns;
}

private static DataType createDataType(DataType tableDataType, StaticArgument staticArg) {
Expand Down Expand Up @@ -349,6 +362,11 @@ public int timeColumn() {
return timeColumn;
}

@Override
public int[] upsertKeyColumns() {
    // Exposes the upsert key indexes stored by the constructor; the array is not copied.
    return this.upsertKeyColumns;
}

@Override
public Optional<ChangelogMode> changelogMode() {
return Optional.ofNullable(changelogMode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public class StreamExecProcessTableFunction extends ExecNodeBase<RowData>
public static final String FIELD_NAME_FUNCTION_CALL = "functionCall";
public static final String FIELD_NAME_INPUT_CHANGELOG_MODES = "inputChangelogModes";
public static final String FIELD_NAME_OUTPUT_CHANGELOG_MODE = "outputChangelogMode";
public static final String FIELD_NAME_INPUT_UPSERT_KEYS = "inputUpsertKeys";

@JsonProperty(FIELD_NAME_UID)
private final @Nullable String uid;
Expand All @@ -121,6 +122,9 @@ public class StreamExecProcessTableFunction extends ExecNodeBase<RowData>
@JsonProperty(FIELD_NAME_OUTPUT_CHANGELOG_MODE)
private final ChangelogMode outputChangelogMode;

@JsonProperty(FIELD_NAME_INPUT_UPSERT_KEYS)
private final List<int[]> inputUpsertKeys;

public StreamExecProcessTableFunction(
ReadableConfig tableConfig,
List<InputProperty> inputProperties,
Expand All @@ -129,7 +133,8 @@ public StreamExecProcessTableFunction(
@Nullable String uid,
RexCall invocation,
List<ChangelogMode> inputChangelogModes,
ChangelogMode outputChangelogMode) {
ChangelogMode outputChangelogMode,
List<int[]> inputUpsertKeys) {
this(
ExecNodeContext.newNodeId(),
ExecNodeContext.newContext(StreamExecProcessTableFunction.class),
Expand All @@ -141,7 +146,8 @@ public StreamExecProcessTableFunction(
uid,
invocation,
inputChangelogModes,
outputChangelogMode);
outputChangelogMode,
inputUpsertKeys);
}

@JsonCreator
Expand All @@ -155,7 +161,8 @@ public StreamExecProcessTableFunction(
@JsonProperty(FIELD_NAME_UID) @Nullable String uid,
@JsonProperty(FIELD_NAME_FUNCTION_CALL) RexNode invocation,
@JsonProperty(FIELD_NAME_INPUT_CHANGELOG_MODES) List<ChangelogMode> inputChangelogModes,
@JsonProperty(FIELD_NAME_OUTPUT_CHANGELOG_MODE) ChangelogMode outputChangelogMode) {
@JsonProperty(FIELD_NAME_OUTPUT_CHANGELOG_MODE) ChangelogMode outputChangelogMode,
@JsonProperty(FIELD_NAME_INPUT_UPSERT_KEYS) @Nullable List<int[]> inputUpsertKeys) {
super(id, context, persistedConfig, inputProperties, outputType, description);
this.uid = uid;
// Mirror the FlinkLogicalTableFunctionScan converter for the compiled-plan restore path:
Expand All @@ -164,6 +171,14 @@ public StreamExecProcessTableFunction(
this.invocation = BridgingSqlFunction.resolveCallTraits((RexCall) invocation);
this.inputChangelogModes = inputChangelogModes;
this.outputChangelogMode = outputChangelogMode;
// Older compiled plans (pre-FLINK-39735) did not persist this field. Default to per-input
// empty arrays so the runtime sees the same behavior as before (no derivable upsert key).
this.inputUpsertKeys =
inputUpsertKeys != null
? inputUpsertKeys
: IntStream.range(0, inputChangelogModes.size())
.mapToObj(i -> new int[0])
.collect(Collectors.toList());
}

public @Nullable String getUid() {
Expand Down Expand Up @@ -202,7 +217,12 @@ protected Transformation<RowData> translateToPlanInternal(
final RexCall udfCall = StreamPhysicalProcessTableFunction.toUdfCall(invocation);
final GeneratedRunnerResult generated =
ProcessTableRunnerGenerator.generate(
ctx, udfCall, inputTimeColumns, inputChangelogModes, outputChangelogMode);
ctx,
udfCall,
inputTimeColumns,
inputChangelogModes,
outputChangelogMode,
inputUpsertKeys);
final GeneratedProcessTableRunner generatedRunner = generated.runner();
final LinkedHashMap<String, StateInfo> stateInfos = generated.stateInfos();

Expand Down Expand Up @@ -309,6 +329,7 @@ private RuntimeTableSemantics createRuntimeTableSemantics(
}

final int timeColumn = inputTimeColumns.get(tableArgCall.getInputIndex());
final int[] upsertKeyColumns = inputUpsertKeys.get(tableArgCall.getInputIndex());

return new RuntimeTableSemantics(
tableArg.getName(),
Expand All @@ -320,7 +341,8 @@ private RuntimeTableSemantics createRuntimeTableSemantics(
consumedChangelogMode,
tableArg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH),
tableArg.is(StaticArgumentTrait.SET_SEMANTIC_TABLE),
timeColumn);
timeColumn,
upsertKeyColumns);
}

private Transformation<RowData> createKeyedTransformation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.calcite.RexTableArgCall;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecProcessTableFunction;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
import org.apache.flink.table.planner.plan.utils.ChangelogPlanUtils;
import org.apache.flink.table.planner.plan.utils.UpsertKeyUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.types.inference.StaticArgument;
Expand Down Expand Up @@ -165,6 +167,7 @@ public ExecNode<?> translateToExecNode() {
verifyTimeAttributes(getInputs(), call, inputChangelogModes, outputChangelogMode);
final List<Ord<StaticArgument>> providedInputArgs = getProvidedInputArgs(call);
verifyPassThroughColumnsForUpdates(providedInputArgs, outputChangelogMode);
final List<int[]> inputUpsertKeys = deriveInputUpsertKeys(getInputs());
return new StreamExecProcessTableFunction(
unwrapTableConfig(this),
getInputs().stream().map(i -> InputProperty.DEFAULT).collect(Collectors.toList()),
Expand All @@ -173,7 +176,26 @@ public ExecNode<?> translateToExecNode() {
uid,
call,
inputChangelogModes,
outputChangelogMode);
outputChangelogMode,
inputUpsertKeys);
}

/**
 * Computes one upsert key per input, in the order of the given inputs.
 *
 * <p>For every input, the upsert key candidates are requested from the metadata query of the
 * input's cluster and collapsed to a single candidate via {@link UpsertKeyUtil#smallestKey}.
 * Inputs without a derivable upsert key (append-only sources without a declared primary key, or
 * operations that destroyed the key) contribute an empty array.
 *
 * <p>The derived keys surface as {@link
 * org.apache.flink.table.functions.TableSemantics#upsertKeyColumns()}, so PTFs can identify rows
 * without requiring callers to repeat the key via PARTITION BY.
 *
 * @param inputs the input nodes to derive upsert keys for
 * @return a list with exactly one entry per input; an entry is an empty array if no key exists
 */
private static List<int[]> deriveInputUpsertKeys(List<RelNode> inputs) {
    return inputs.stream()
            .map(
                    rel -> {
                        final FlinkRelMetadataQuery metadataQuery =
                                FlinkRelMetadataQuery.reuseOrCreate(
                                        rel.getCluster().getMetadataQuery());
                        // NOTE(review): confirm that smallestKey tolerates a null candidate
                        // set, in case the metadata query cannot provide upsert keys.
                        final Set<ImmutableBitSet> candidates =
                                metadataQuery.getUpsertKeys(rel);
                        return UpsertKeyUtil.smallestKey(candidates).orElse(new int[0]);
                    })
            .collect(Collectors.toCollection(() -> new ArrayList<int[]>(inputs.size())));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ object ProcessTableRunnerGenerator {
udfCall: RexCall,
inputTimeColumns: java.util.List[Integer],
inputChangelogModes: java.util.List[ChangelogMode],
outputChangelogMode: ChangelogMode): GeneratedRunnerResult = {
outputChangelogMode: ChangelogMode,
inputUpsertKeys: java.util.List[Array[Int]]): GeneratedRunnerResult = {
val function: BridgingSqlFunction = udfCall.getOperator.asInstanceOf[BridgingSqlFunction]
val definition: FunctionDefinition = function.getDefinition
val dataTypeFactory = function.getDataTypeFactory
Expand All @@ -77,7 +78,12 @@ object ProcessTableRunnerGenerator {
// Thus, functions can reconfigure themselves for the exact use case.
// Including updating their state layout.
val callContext =
function.toCallContext(udfCall, inputTimeColumns, inputChangelogModes, outputChangelogMode)
function.toCallContext(
udfCall,
inputTimeColumns,
inputChangelogModes,
outputChangelogMode,
inputUpsertKeys)

// Create the final UDF for runtime
val udf = UserDefinedFunctionHelper.createSpecializedFunction(
Expand Down
Loading