diff --git a/backend/query/insert_on_conflict_dml_execution.cc b/backend/query/insert_on_conflict_dml_execution.cc index da250a2c..9fccd6ed 100644 --- a/backend/query/insert_on_conflict_dml_execution.cc +++ b/backend/query/insert_on_conflict_dml_execution.cc @@ -65,6 +65,7 @@ using googlesql::ResolvedColumnRef; using googlesql::ResolvedDMLValue; using googlesql::ResolvedInsertStmt; using googlesql::ResolvedOnConflictClauseEnums; +using googlesql::ResolvedSubqueryExpr; } // namespace absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedColumnRef( @@ -88,7 +89,8 @@ absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedColumnRef( std::pair column_info = column_ids_referenced_from_insert_row_.at(node->column().column_id()); auto array_element_column_ref = googlesql::MakeResolvedColumnRef( - struct_column_holder_->type(), *struct_column_holder_, false); + struct_column_holder_->type(), *struct_column_holder_, + node->is_correlated()); // Build RESOLVED_GET_STRUCT_FIELD to extract the column value from the // source STRUCT column. auto get_struct_field = googlesql::MakeResolvedGetStructField( @@ -98,6 +100,59 @@ absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedColumnRef( return absl::OkStatus(); } +absl::Status InsertOnConflictDoUpdateRewriter::VisitResolvedSubqueryExpr( + const ResolvedSubqueryExpr* node) { + std::vector> parameter_list; + parameter_list.reserve(node->parameter_list_size()); + bool added_struct_column_holder = false; + for (const auto& parameter : node->parameter_list()) { + if (column_ids_referenced_from_insert_row_.contains( + parameter->column().column_id())) { + if (!added_struct_column_holder) { + // Replace the column references from insert row with the insert row + // value struct. + parameter_list.push_back(googlesql::MakeResolvedColumnRef( + struct_column_holder_->type(), *struct_column_holder_, false)); + added_struct_column_holder = true; + } + continue; + } + GOOGLESQL_ASSIGN_OR_RETURN(std::unique_ptr parameter_copy, + ProcessNode(parameter.get())); + parameter_list.push_back(std::move(parameter_copy)); + } + + GOOGLESQL_ASSIGN_OR_RETURN(std::unique_ptr in_expr, + ProcessNode(node->in_expr())); + GOOGLESQL_ASSIGN_OR_RETURN(std::unique_ptr subquery, + ProcessNode(node->subquery())); + + auto copy = googlesql::MakeResolvedSubqueryExpr( + node->type(), node->subquery_type(), std::move(parameter_list), + std::move(in_expr), std::move(subquery)); + copy->set_type_annotation_map(node->type_annotation_map()); + copy->set_in_collation(node->in_collation()); + + GOOGLESQL_ASSIGN_OR_RETURN( + std::vector> hint_list, + ProcessNodeList(node->hint_list())); + copy->set_hint_list({std::make_move_iterator(hint_list.begin()), + std::make_move_iterator(hint_list.end())}); + + const auto parse_location = node->GetParseLocationRangeOrNULL(); + if (parse_location != nullptr) { + copy->SetParseLocationRange(*parse_location); + } + const auto operator_keyword_parse_location = + node->GetOperatorKeywordLocationRangeOrNULL(); + if (operator_keyword_parse_location != nullptr) { + copy->SetOperatorKeywordLocationRange(*operator_keyword_parse_location); + } + + PushNodeToStack(std::move(copy)); + return absl::OkStatus(); +} + absl::Status InsertOnConflictToInsertOrIgnoreRewriter::VisitResolvedInsertStmt( const ResolvedInsertStmt* node) { GOOGLESQL_RETURN_IF_ERROR(CopyVisitResolvedInsertStmt(node)); diff --git a/backend/query/insert_on_conflict_dml_execution.h b/backend/query/insert_on_conflict_dml_execution.h index 47cd1ac1..51e75383 100644 --- a/backend/query/insert_on_conflict_dml_execution.h +++ b/backend/query/insert_on_conflict_dml_execution.h @@ -61,6 +61,8 @@ class InsertOnConflictDoUpdateRewriter absl::Status VisitResolvedColumnRef( const googlesql::ResolvedColumnRef* node) override; + absl::Status VisitResolvedSubqueryExpr( + const googlesql::ResolvedSubqueryExpr* node) override; private: // Map of column id(s) of columns referenced from the insert row (i.e. of the diff --git a/backend/query/query_engine_test.cc b/backend/query/query_engine_test.cc index 082724d9..c90a6213 100644 --- a/backend/query/query_engine_test.cc +++ b/backend/query/query_engine_test.cc @@ -1899,6 +1899,37 @@ TEST_P(QueryEngineTest, InsertOnConflictDoUpdateDml) { EXPECT_EQ(result.modified_row_count, 2); } +TEST_P(QueryEngineTest, InsertOnConflictDoUpdateSubqueryCanReferenceExcluded) { + MockRowWriter writer; + EXPECT_CALL( + writer, + Write(Property( + &Mutation::ops, + UnorderedElementsAre(AllOf( + Field(&MutationOp::type, MutationOpType::kUpdate), + Field(&MutationOp::table, "test_table"), + Field(&MutationOp::columns, + std::vector{"int64_col", "string_col"}), + Field(&MutationOp::rows, + UnorderedElementsAre( + ValueList{Int64(1), String("one-ten")}))))))) + .Times(1) + .WillOnce(Return(absl::OkStatus())); + + GOOGLESQL_ASSERT_OK_AND_ASSIGN( + QueryResult result, + query_engine().ExecuteSql( + Query{"INSERT INTO test_table (int64_col, string_col) " + "VALUES(1, 'ten') " + "ON CONFLICT(int64_col) DO UPDATE SET string_col = " + "(SELECT CONCAT(t.string_col, '-', excluded.string_col) " + "FROM test_table t WHERE t.int64_col = 1)"}, + QueryContext{schema(), reader(), &writer})); + + ASSERT_EQ(result.rows, nullptr); + EXPECT_EQ(result.modified_row_count, 1); +} + TEST_P(QueryEngineTest, InsertOnConflictDoUpdateDmlWithReturning) { std::string returning = (GetParam() == POSTGRESQL) ? "RETURNING" : "THEN RETURN WITH ACTION";