Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion backend/query/insert_on_conflict_dml_execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ using googlesql::ResolvedColumnRef;
using googlesql::ResolvedDMLValue;
using googlesql::ResolvedInsertStmt;
using googlesql::ResolvedOnConflictClauseEnums;
using googlesql::ResolvedSubqueryExpr;
} // namespace

absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedColumnRef(
Expand All @@ -88,7 +89,8 @@ absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedColumnRef(
std::pair<const googlesql::Type*, int> column_info =
column_ids_referenced_from_insert_row_.at(node->column().column_id());
auto array_element_column_ref = googlesql::MakeResolvedColumnRef(
struct_column_holder_->type(), *struct_column_holder_, false);
struct_column_holder_->type(), *struct_column_holder_,
node->is_correlated());
// Build RESOLVED_GET_STRUCT_FIELD to extract the column value from the
// source STRUCT column.
auto get_struct_field = googlesql::MakeResolvedGetStructField(
Expand All @@ -98,6 +100,57 @@ absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedColumnRef(
return absl::OkStatus();
}

absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedSubqueryExpr(
const ResolvedSubqueryExpr* node) {
std::vector<std::unique_ptr<ResolvedColumnRef>> parameter_list;
parameter_list.reserve(node->parameter_list_size());
bool added_struct_column_holder = false;
for (const auto& parameter : node->parameter_list()) {
if (column_ids_referenced_from_insert_row_.contains(
parameter->column().column_id())) {
if (!added_struct_column_holder) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing this issue and adding a fix.
Can we pls add a comment here - "Replace the column references from insert row with the insert row value struct"

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

parameter_list.push_back(googlesql::MakeResolvedColumnRef(
struct_column_holder_->type(), *struct_column_holder_, false));
added_struct_column_holder = true;
}
continue;
}
GOOGLESQL_ASSIGN_OR_RETURN(std::unique_ptr<ResolvedColumnRef> parameter_copy,
ProcessNode(parameter.get()));
parameter_list.push_back(std::move(parameter_copy));
}

GOOGLESQL_ASSIGN_OR_RETURN(std::unique_ptr<googlesql::ResolvedExpr> in_expr,
ProcessNode(node->in_expr()));
GOOGLESQL_ASSIGN_OR_RETURN(std::unique_ptr<googlesql::ResolvedScan> subquery,
ProcessNode(node->subquery()));

auto copy = googlesql::MakeResolvedSubqueryExpr(
node->type(), node->subquery_type(), std::move(parameter_list),
std::move(in_expr), std::move(subquery));
copy->set_type_annotation_map(node->type_annotation_map());
copy->set_in_collation(node->in_collation());

GOOGLESQL_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<googlesql::ResolvedOption>> hint_list,
ProcessNodeList(node->hint_list()));
copy->set_hint_list({std::make_move_iterator(hint_list.begin()),
std::make_move_iterator(hint_list.end())});

const auto parse_location = node->GetParseLocationRangeOrNULL();
if (parse_location != nullptr) {
copy->SetParseLocationRange(*parse_location);
}
const auto operator_keyword_parse_location =
node->GetOperatorKeywordLocationRangeOrNULL();
if (operator_keyword_parse_location != nullptr) {
copy->SetOperatorKeywordLocationRange(*operator_keyword_parse_location);
}

PushNodeToStack(std::move(copy));
return absl::OkStatus();
}

absl::Status InsertOnConflictToInsertOrIgnoreRewriter::VisitResolvedInsertStmt(
const ResolvedInsertStmt* node) {
GOOGLESQL_RETURN_IF_ERROR(CopyVisitResolvedInsertStmt(node));
Expand Down
2 changes: 2 additions & 0 deletions backend/query/insert_on_conflict_dml_execution.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class InsertOnConflictDoUpdateRewriter

absl::Status VisitResolvedColumnRef(
const googlesql::ResolvedColumnRef* node) override;
absl::Status VisitResolvedSubqueryExpr(
const googlesql::ResolvedSubqueryExpr* node) override;

private:
// Map of column id(s) of columns referenced from the insert row (i.e. of the
Expand Down
33 changes: 33 additions & 0 deletions backend/query/query_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1899,6 +1899,39 @@ TEST_P(QueryEngineTest, InsertOnConflictDoUpdateDml) {
EXPECT_EQ(result.modified_row_count, 2);
}

TEST_P(QueryEngineTest, InsertOnConflictDoUpdateSubqueryCanReferenceExcluded) {
if (GetParam() == POSTGRESQL) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should work for PG as well. Do we need this GTEST_SKIP?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

GTEST_SKIP() << "GoogleSQL-specific regression test";
}

MockRowWriter writer;
EXPECT_CALL(
writer,
Write(Property(
&Mutation::ops,
UnorderedElementsAre(AllOf(
Field(&MutationOp::type, MutationOpType::kUpdate),
Field(&MutationOp::table, "test_table"),
Field(&MutationOp::columns,
std::vector<std::string>{"int64_col", "string_col"}),
Field(&MutationOp::rows,
UnorderedElementsAre(ValueList{Int64(1), String("ten")})))))))
.Times(1)
.WillOnce(Return(absl::OkStatus()));

GOOGLESQL_ASSERT_OK_AND_ASSIGN(
QueryResult result,
query_engine().ExecuteSql(
Query{"INSERT INTO test_table (int64_col, string_col) "
"VALUES(1, 'ten') "
"ON CONFLICT(int64_col) DO UPDATE SET string_col = "
"(SELECT excluded.string_col)"},
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe good to replace this with a subquery that does a table scan as well:
(SELECT CONCAT(t.string_col, '-', excluded.string_col) from test_table t WHERE t.int64_col = 1)

This will change the output string_col to one-ten since the row with int64_col:1 already exists in this test.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

QueryContext{schema(), reader(), &writer}));

ASSERT_EQ(result.rows, nullptr);
EXPECT_EQ(result.modified_row_count, 1);
}

TEST_P(QueryEngineTest, InsertOnConflictDoUpdateDmlWithReturning) {
std::string returning =
(GetParam() == POSTGRESQL) ? "RETURNING" : "THEN RETURN WITH ACTION";
Expand Down