Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.apache.flink.table.planner.plan.rules.logical;

import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalCalc;
import org.apache.flink.table.planner.plan.utils.FlinkRelUtil;

Expand Down Expand Up @@ -52,6 +53,8 @@ public class FlinkCalcMergeRule extends RelRule<FlinkCalcMergeRule.FlinkCalcMerg
/** Rule instance created from {@code FlinkCalcMergeRuleConfig.DEFAULT}. */
public static final FlinkCalcMergeRule INSTANCE = FlinkCalcMergeRuleConfig.DEFAULT.toRule();
/** Rule instance created from {@code FlinkCalcMergeRuleConfig.STREAM_PHYSICAL}. */
public static final FlinkCalcMergeRule STREAM_PHYSICAL_INSTANCE =
FlinkCalcMergeRuleConfig.STREAM_PHYSICAL.toRule();
/**
 * Rule instance created from {@code FlinkCalcMergeRuleConfig.BATCH_PHYSICAL}, whose operand
 * pattern is a {@code BatchPhysicalCalc} directly on top of another {@code BatchPhysicalCalc}.
 */
public static final FlinkCalcMergeRule BATCH_PHYSICAL_INSTANCE =
FlinkCalcMergeRuleConfig.BATCH_PHYSICAL.toRule();

protected FlinkCalcMergeRule(FlinkCalcMergeRuleConfig config) {
super(config);
Expand Down Expand Up @@ -112,6 +115,19 @@ public interface FlinkCalcMergeRuleConfig extends RelRule.Config {
.withRelBuilderFactory(RelFactories.LOGICAL_BUILDER)
.withDescription("FlinkCalcMergeRule");

/**
 * Config for the batch physical variant of the rule: the operand pattern is a {@code
 * BatchPhysicalCalc} whose input is another {@code BatchPhysicalCalc} (with any inputs).
 *
 * <p>NOTE(review): this config reuses the description string "FlinkCalcMergeRule" and {@code
 * RelFactories.LOGICAL_BUILDER} exactly like the config declared right before it, although it
 * matches physical nodes. Presumably intentional for parity with that config; confirm that the
 * shared description cannot clash when several variants are registered in one planner.
 */
FlinkCalcMergeRule.FlinkCalcMergeRuleConfig BATCH_PHYSICAL =
ImmutableFlinkCalcMergeRule.FlinkCalcMergeRuleConfig.builder()
.build()
.withOperandSupplier(
b0 ->
b0.operand(BatchPhysicalCalc.class)
.inputs(
b1 ->
b1.operand(BatchPhysicalCalc.class)
.anyInputs()))
.withRelBuilderFactory(RelFactories.LOGICAL_BUILDER)
.withDescription("FlinkCalcMergeRule");

@Override
default FlinkCalcMergeRule toRule() {
return new FlinkCalcMergeRule(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ object FlinkBatchRuleSets {

/** RuleSet to do physical optimize for batch */
val PHYSICAL_OPT_RULES: RuleSet = RuleSets.ofList(
FlinkCalcMergeRule.BATCH_PHYSICAL_INSTANCE,
FlinkExpandConversionRule.BATCH_INSTANCE,
// source
BatchPhysicalBoundedStreamScanRule.INSTANCE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,24 @@
*/
package org.apache.flink.table.planner.plan.rules.physical.batch

import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalTableFunctionScan}
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCorrelate
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalCalc, BatchPhysicalCorrelate}
import org.apache.flink.table.planner.plan.utils.PythonUtil

import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.convert.ConverterRule.Config
import org.apache.calcite.rex.RexNode
import org.apache.calcite.rex.{RexNode, RexProgram, RexUtil}
import org.apache.calcite.sql.validate.SqlValidatorUtil

import java.util.Collections

import scala.collection.JavaConverters._

class BatchPhysicalCorrelateRule(config: Config) extends ConverterRule(config) {

Expand All @@ -51,35 +58,72 @@ class BatchPhysicalCorrelateRule(config: Config) extends ConverterRule(config) {
}

override def convert(rel: RelNode): RelNode = {
val join = rel.asInstanceOf[FlinkLogicalCorrelate]
val correlate = rel.asInstanceOf[FlinkLogicalCorrelate]
val cluster = correlate.getCluster
val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL)
val convInput: RelNode = RelOptRule.convert(join.getInput(0), FlinkConventions.BATCH_PHYSICAL)
val right: RelNode = join.getInput(1)
val convInput: RelNode =
RelOptRule.convert(correlate.getInput(0), FlinkConventions.BATCH_PHYSICAL)

def convertToCorrelate(relNode: RelNode, condition: Option[RexNode]): BatchPhysicalCorrelate = {
// matches() guarantees the right side is either a TableFunctionScan, or a single Calc
// whose immediate input is a TableFunctionScan.
@scala.annotation.tailrec
def unwrap(
relNode: RelNode): (FlinkLogicalTableFunctionScan, Option[Seq[RexNode]], Option[RexNode]) =
relNode match {
case rel: RelSubset =>
convertToCorrelate(rel.getRelList.get(0), condition)

case rel: RelSubset => unwrap(rel.getRelList.get(0))
case scan: FlinkLogicalTableFunctionScan => (scan, None, None)
case calc: FlinkLogicalCalc =>
convertToCorrelate(
calc.getInput.asInstanceOf[RelSubset].getOriginal,
if (calc.getProgram.getCondition == null) None
else Some(calc.getProgram.expandLocalRef(calc.getProgram.getCondition))
)

case scan: FlinkLogicalTableFunctionScan =>
new BatchPhysicalCorrelate(
rel.getCluster,
traitSet,
convInput,
scan,
condition,
rel.getRowType,
join.getJoinType)
val scan = calc.getInput
.asInstanceOf[RelSubset]
.getOriginal
.asInstanceOf[FlinkLogicalTableFunctionScan]
val program = calc.getProgram
val condition =
if (program.getCondition == null) None
else Some(program.expandLocalRef(program.getCondition))
val projects =
if (program.projectsOnlyIdentity()) None
else Some(program.getProjectList.asScala.map(program.expandLocalRef).toSeq)
(scan, projects, condition)
}

val (scan, projectsOpt, condition) = unwrap(correlate.getInput(1))

projectsOpt match {
case None =>
new BatchPhysicalCorrelate(
cluster,
traitSet,
convInput,
scan,
condition,
correlate.getRowType,
correlate.getJoinType)
case Some(projects) =>
val innerRowType = SqlValidatorUtil.deriveJoinRowType(
correlate.getLeft.getRowType,
scan.getRowType,
correlate.getJoinType,
cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory],
null,
Collections.emptyList[RelDataTypeField]()
)
val innerCorrelate = new BatchPhysicalCorrelate(
cluster,
traitSet,
convInput,
scan,
condition,
innerRowType,
correlate.getJoinType)
val outerProgram = BatchPhysicalCorrelateRule.buildOuterProgram(
cluster,
correlate.getLeft.getRowType.getFieldCount,
innerRowType,
correlate.getRowType,
projects)
new BatchPhysicalCalc(cluster, traitSet, innerCorrelate, outerProgram, correlate.getRowType)
}
convertToCorrelate(right, None)
}
}

Expand All @@ -90,4 +134,26 @@ object BatchPhysicalCorrelateRule {
FlinkConventions.LOGICAL,
FlinkConventions.BATCH_PHYSICAL,
"BatchPhysicalCorrelateRule"))

/**
 * Creates the [[RexProgram]] of the Calc that is placed on top of the inner correlate.
 *
 * The program has no condition. Its projection list consists of one input reference for each of
 * the first `leftFieldCount` fields of `innerRowType` (the left input, forwarded unchanged),
 * followed by every expression of `rightProjects` after `RexUtil.shift` has been applied to it
 * with `leftFieldCount` as the offset.
 *
 * @param cluster
 *   supplies the RexBuilder used for the input references and the program
 * @param leftFieldCount
 *   number of leading fields of `innerRowType` that are passed through
 * @param innerRowType
 *   input row type of the created program
 * @param outputRowType
 *   output row type of the created program
 * @param rightProjects
 *   right-side projections that are appended after being shifted
 * @return
 *   the program for the outer Calc
 */
def buildOuterProgram(
    cluster: RelOptCluster,
    leftFieldCount: Int,
    innerRowType: RelDataType,
    outputRowType: RelDataType,
    rightProjects: Seq[RexNode]): RexProgram = {
  val rexBuilder = cluster.getRexBuilder
  val innerFields = innerRowType.getFieldList
  val outerProjects = new java.util.ArrayList[RexNode]()

  // Left side: forward each leading field of the inner row as a plain input reference.
  for (idx <- 0 until leftFieldCount) {
    outerProjects.add(rexBuilder.makeInputRef(innerFields.get(idx).getType, idx))
  }

  // Right side: move the projections' input references behind the left fields.
  for (project <- rightProjects) {
    outerProjects.add(RexUtil.shift(project, leftFieldCount))
  }

  RexProgram.create(innerRowType, outerProjects, null, outputRowType, rexBuilder)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,26 @@
package org.apache.flink.table.planner.plan.rules.physical.stream

import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalTableFunctionScan}
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalCorrelate
import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamPhysicalCalc, StreamPhysicalCorrelate}
import org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule.{getMergedCalc, getTableScan}
import org.apache.flink.table.planner.plan.utils.{AsyncUtil, PythonUtil}
import org.apache.flink.table.planner.plan.utils.{AsyncUtil, FlinkRelUtil, PythonUtil}

import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.plan.hep.HepRelVertex
import org.apache.calcite.plan.volcano.RelSubset
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.convert.ConverterRule.Config
import org.apache.calcite.rex.{RexNode, RexProgram, RexProgramBuilder}
import org.apache.calcite.rex.{RexNode, RexProgram, RexUtil}
import org.apache.calcite.sql.validate.SqlValidatorUtil

import java.util.Collections

import scala.collection.JavaConverters._

/** Rule that converts [[FlinkLogicalCorrelate]] to [[StreamPhysicalCorrelate]]. */
class StreamPhysicalCorrelateRule(config: Config) extends ConverterRule(config) {
Expand Down Expand Up @@ -63,40 +70,75 @@ class StreamPhysicalCorrelateRule(config: Config) extends ConverterRule(config)

override def convert(rel: RelNode): RelNode = {
val correlate = rel.asInstanceOf[FlinkLogicalCorrelate]
val cluster = correlate.getCluster
val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
val convInput: RelNode =
RelOptRule.convert(correlate.getInput(0), FlinkConventions.STREAM_PHYSICAL)
val right: RelNode = correlate.getInput(1)

@scala.annotation.tailrec
def convertToCorrelate(
relNode: RelNode,
condition: Option[RexNode]): StreamPhysicalCorrelate = {
def unwrap(relNode: RelNode)
: (FlinkLogicalTableFunctionScan, Option[Seq[RexNode]], Option[RexNode]) = {
relNode match {
case rel: RelSubset =>
convertToCorrelate(rel.getRelList.get(0), condition)

case rel: RelSubset => unwrap(rel.getRelList.get(0))
case calc: FlinkLogicalCalc =>
val tableScan = getTableScan(calc)
val newCalc = getMergedCalc(calc)
convertToCorrelate(
tableScan,
if (newCalc.getProgram.getCondition == null) None
else Some(newCalc.getProgram.expandLocalRef(newCalc.getProgram.getCondition))
)

val program = newCalc.getProgram
val condition =
if (program.getCondition == null) None
else Some(program.expandLocalRef(program.getCondition))
val projects =
if (program.projectsOnlyIdentity()) None
else Some(program.getProjectList.asScala.map(program.expandLocalRef).toSeq)
(tableScan, projects, condition)
case scan: FlinkLogicalTableFunctionScan =>
new StreamPhysicalCorrelate(
rel.getCluster,
traitSet,
convInput,
scan,
condition,
rel.getRowType,
correlate.getJoinType)
(scan, None, None)
}
}
convertToCorrelate(right, None)

val (scan, projectsOpt, condition) = unwrap(right)

projectsOpt match {
case None =>
new StreamPhysicalCorrelate(
cluster,
traitSet,
convInput,
scan,
condition,
correlate.getRowType,
correlate.getJoinType)
case Some(projects) =>
val innerRowType = SqlValidatorUtil.deriveJoinRowType(
correlate.getLeft.getRowType,
scan.getRowType,
correlate.getJoinType,
cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory],
null,
Collections.emptyList[RelDataTypeField]()
)
val innerCorrelate = new StreamPhysicalCorrelate(
cluster,
traitSet,
convInput,
scan,
condition,
innerRowType,
correlate.getJoinType)
val outerProgram = StreamPhysicalCorrelateRule.buildOuterProgram(
cluster,
correlate.getLeft.getRowType.getFieldCount,
innerRowType,
correlate.getRowType,
projects)
new StreamPhysicalCalc(
cluster,
traitSet,
innerCorrelate,
outerProgram,
correlate.getRowType)
}
}

}
Expand All @@ -117,17 +159,7 @@ object StreamPhysicalCorrelateRule {
child match {
case calc1: FlinkLogicalCalc =>
val bottomCalc = getMergedCalc(calc1)
val topCalc = calc
val topProgram: RexProgram = topCalc.getProgram
val mergedProgram: RexProgram = RexProgramBuilder
.mergePrograms(
topCalc.getProgram,
bottomCalc.getProgram,
topCalc.getCluster.getRexBuilder)
assert(mergedProgram.getOutputRowType eq topProgram.getOutputRowType)
topCalc
.copy(topCalc.getTraitSet, bottomCalc.getInput, mergedProgram)
.asInstanceOf[FlinkLogicalCalc]
FlinkRelUtil.merge(calc, bottomCalc).asInstanceOf[FlinkLogicalCalc]
case _ =>
calc
}
Expand All @@ -145,4 +177,22 @@ object StreamPhysicalCorrelateRule {
case _ => throw new TableException("This must be a bug, could not find table scan")
}
}

/**
 * Builds the outer Calc program that sits on top of the inner correlate: passes the left input
 * through unchanged, then appends the right-side projections shifted by the left field count.
 *
 * The created program has no condition; `innerRowType` is its input row type and
 * `outputRowType` its output row type.
 *
 * @param cluster
 *   supplies the RexBuilder used for the input references and the program
 * @param leftFieldCount
 *   number of leading fields of `innerRowType` that are passed through
 * @param innerRowType
 *   input row type of the created program
 * @param outputRowType
 *   output row type of the created program
 * @param rightProjects
 *   right-side projections that are appended after being shifted
 * @return
 *   the program for the outer Calc
 */
def buildOuterProgram(
    cluster: RelOptCluster,
    leftFieldCount: Int,
    innerRowType: RelDataType,
    outputRowType: RelDataType,
    rightProjects: Seq[RexNode]): RexProgram = {
  val rexBuilder = cluster.getRexBuilder
  val outerProjects = new java.util.ArrayList[RexNode]()
  // Holds every field of the inner row type; only the first leftFieldCount are read here.
  val innerFields = innerRowType.getFieldList

  // Left side: one plain input reference per leading field of the inner row.
  (0 until leftFieldCount).foreach {
    idx => outerProjects.add(rexBuilder.makeInputRef(innerFields.get(idx).getType, idx))
  }

  // Right side: shift the projections' input references by the left field count.
  rightProjects.foreach(p => outerProjects.add(RexUtil.shift(p, leftFieldCount)))

  RexProgram.create(innerRowType, outerProjects, null, outputRowType, rexBuilder)
}
}
Loading