Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion backend/query/insert_on_conflict_dml_execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ using googlesql::ResolvedColumnRef;
using googlesql::ResolvedDMLValue;
using googlesql::ResolvedInsertStmt;
using googlesql::ResolvedOnConflictClauseEnums;
using googlesql::ResolvedSubqueryExpr;
} // namespace

absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedColumnRef(
Expand All @@ -88,7 +89,8 @@ absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedColumnRef(
std::pair<const googlesql::Type*, int> column_info =
column_ids_referenced_from_insert_row_.at(node->column().column_id());
auto array_element_column_ref = googlesql::MakeResolvedColumnRef(
struct_column_holder_->type(), *struct_column_holder_, false);
struct_column_holder_->type(), *struct_column_holder_,
node->is_correlated());
// Build RESOLVED_GET_STRUCT_FIELD to extract the column value from the
// source STRUCT column.
auto get_struct_field = googlesql::MakeResolvedGetStructField(
Expand All @@ -98,6 +100,59 @@ absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedColumnRef(
return absl::OkStatus();
}

absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedSubqueryExpr(
const ResolvedSubqueryExpr* node) {
std::vector<std::unique_ptr<ResolvedColumnRef>> parameter_list;
parameter_list.reserve(node->parameter_list_size());
bool added_struct_column_holder = false;
for (const auto& parameter : node->parameter_list()) {
if (column_ids_referenced_from_insert_row_.contains(
parameter->column().column_id())) {
if (!added_struct_column_holder) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing this issue and adding a fix.
Can we pls add a comment here - "Replace the column references from insert row with the insert row value struct"

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// Replace the column references from insert row with the insert row
// value struct.
parameter_list.push_back(googlesql::MakeResolvedColumnRef(
struct_column_holder_->type(), *struct_column_holder_, false));
added_struct_column_holder = true;
}
continue;
}
GOOGLESQL_ASSIGN_OR_RETURN(std::unique_ptr<ResolvedColumnRef> parameter_copy,
ProcessNode(parameter.get()));
parameter_list.push_back(std::move(parameter_copy));
}

GOOGLESQL_ASSIGN_OR_RETURN(std::unique_ptr<googlesql::ResolvedExpr> in_expr,
ProcessNode(node->in_expr()));
GOOGLESQL_ASSIGN_OR_RETURN(std::unique_ptr<googlesql::ResolvedScan> subquery,
ProcessNode(node->subquery()));

auto copy = googlesql::MakeResolvedSubqueryExpr(
node->type(), node->subquery_type(), std::move(parameter_list),
std::move(in_expr), std::move(subquery));
copy->set_type_annotation_map(node->type_annotation_map());
copy->set_in_collation(node->in_collation());

GOOGLESQL_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<googlesql::ResolvedOption>> hint_list,
ProcessNodeList(node->hint_list()));
copy->set_hint_list({std::make_move_iterator(hint_list.begin()),
std::make_move_iterator(hint_list.end())});

const auto parse_location = node->GetParseLocationRangeOrNULL();
if (parse_location != nullptr) {
copy->SetParseLocationRange(*parse_location);
}
const auto operator_keyword_parse_location =
node->GetOperatorKeywordLocationRangeOrNULL();
if (operator_keyword_parse_location != nullptr) {
copy->SetOperatorKeywordLocationRange(*operator_keyword_parse_location);
}

PushNodeToStack(std::move(copy));
return absl::OkStatus();
}

absl::Status InsertOnConflictToInsertOrIgnoreRewriter::VisitResolvedInsertStmt(
const ResolvedInsertStmt* node) {
GOOGLESQL_RETURN_IF_ERROR(CopyVisitResolvedInsertStmt(node));
Expand Down
2 changes: 2 additions & 0 deletions backend/query/insert_on_conflict_dml_execution.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class InsertOnConflictDoUpdateRewriter

absl::Status VisitResolvedColumnRef(
const googlesql::ResolvedColumnRef* node) override;
absl::Status VisitResolvedSubqueryExpr(
const googlesql::ResolvedSubqueryExpr* node) override;

private:
// Map of column id(s) of columns referenced from the insert row (i.e. of the
Expand Down
31 changes: 31 additions & 0 deletions backend/query/query_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,37 @@ TEST_P(QueryEngineTest, InsertOnConflictDoUpdateDml) {
EXPECT_EQ(result.modified_row_count, 2);
}

TEST_P(QueryEngineTest, InsertOnConflictDoUpdateSubqueryCanReferenceExcluded) {
MockRowWriter writer;
EXPECT_CALL(
writer,
Write(Property(
&Mutation::ops,
UnorderedElementsAre(AllOf(
Field(&MutationOp::type, MutationOpType::kUpdate),
Field(&MutationOp::table, "test_table"),
Field(&MutationOp::columns,
std::vector<std::string>{"int64_col", "string_col"}),
Field(&MutationOp::rows,
UnorderedElementsAre(
ValueList{Int64(1), String("one-ten")})))))))
.Times(1)
.WillOnce(Return(absl::OkStatus()));

GOOGLESQL_ASSERT_OK_AND_ASSIGN(
QueryResult result,
query_engine().ExecuteSql(
Query{"INSERT INTO test_table (int64_col, string_col) "
"VALUES(1, 'ten') "
"ON CONFLICT(int64_col) DO UPDATE SET string_col = "
"(SELECT CONCAT(t.string_col, '-', excluded.string_col) "
"FROM test_table t WHERE t.int64_col = 1)"},
QueryContext{schema(), reader(), &writer}));

ASSERT_EQ(result.rows, nullptr);
EXPECT_EQ(result.modified_row_count, 1);
}

TEST_P(QueryEngineTest, InsertOnConflictDoUpdateDmlWithReturning) {
std::string returning =
(GetParam() == POSTGRESQL) ? "RETURNING" : "THEN RETURN WITH ACTION";
Expand Down