Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,30 @@
*/
package org.apache.beam.sdk.io.snowflake;

import static org.apache.beam.sdk.io.TextIO.readFiles;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;

import com.google.auto.value.AutoValue;
import com.opencsv.CSVParser;
import com.opencsv.CSVParserBuilder;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.sql.SQLException;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.zip.GZIPInputStream;
import javax.annotation.Nullable;
import javax.sql.DataSource;
import net.snowflake.client.api.datasource.SnowflakeDataSource;
Expand All @@ -43,11 +50,13 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.io.WriteFilesResult;
import org.apache.beam.sdk.io.fs.MatchResult;
import org.apache.beam.sdk.io.fs.MoveOptions;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema;
Expand All @@ -59,6 +68,7 @@
import org.apache.beam.sdk.io.snowflake.services.SnowflakeServices;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeServicesImpl;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeStreamingServiceConfig;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.ValueProvider;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Create;
Expand All @@ -67,7 +77,6 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reify;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.Values;
Expand Down Expand Up @@ -424,31 +433,26 @@ public Read<T> withQuotationMark(ValueProvider<String> quotationMark) {
public PCollection<T> expand(PBegin input) {
checkArguments();

PCollection<Void> emptyCollection = input.apply(Create.of((Void) null));
String tmpDirName = makeTmpDirName();
PCollection<T> output =
emptyCollection
.apply(
ParDo.of(
new CopyIntoStageFn(
getDataSourceProviderFn(),
getQuery(),
getTable(),
getStorageIntegrationName(),
getStagingBucketName(),
tmpDirName,
getSnowflakeServices(),
getQuotationMark())))
.apply(Reshuffle.viaRandomKey())
.apply(FileIO.matchAll())
Copy link
Copy Markdown
Contributor

@Abacn Abacn Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dataflow can do it because FileIO.matchAll() has a ReShuffle present by default:

if (getOutputParallelization()) {
return res.apply(Reshuffle.viaRandomKey());

This introduces a fusion break and the downstream can be parallelized.

It sounds like a Flink runner bug that significantly impacted its performance. Preferably, a fix should be done in the Flink runner to bring back what Reshuffle is intended to do (a fusion break). Have you tried not setting useDataStreamForBatch (only available for Flink 1.x)?

This PR essentially rewrites the SnowflakeIO bounded read and would need a closer eye (I could help with generic Java, but I have less experience with the Snowflake connector).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had issues without useDataStreamForBatch. I'll try again on the updated version without it just to be sure. Also we have not tried Flink 2 yet because the support is quite recent.

Before the change, it seems Flink does apply the reshuffle, but since it reduced the number of workers to 1, it's reshuffling within a single worker. That doesn't change anything, because the input (the list of files) is tiny.

.apply(FileIO.readMatches())
.apply(readFiles())
.apply(ParDo.of(new MapCsvToStringArrayFn(getQuotationMark())))
.apply(ParDo.of(new MapStringArrayToUserDataFn<>(getCsvMapper())));

SnowflakeBoundedSource<T> source =
new SnowflakeBoundedSource<>(
getDataSourceProviderFn(),
getQuery(),
getTable(),
getStorageIntegrationName(),
getStagingBucketName(),
tmpDirName,
getSnowflakeServices(),
getQuotationMark(),
getCsvMapper(),
getCoder());

PCollection<T> output = input.apply(org.apache.beam.sdk.io.Read.from(source));
output.setCoder(getCoder());

emptyCollection
input
.apply(Create.of((Void) null))
.apply(Wait.on(output))
.apply(ParDo.of(new CleanTmpFilesFromGcsFn(getStagingBucketName(), tmpDirName)));
return output;
Expand Down Expand Up @@ -483,103 +487,212 @@ private String makeTmpDirName() {
);
}

private static class CopyIntoStageFn extends DoFn<Object, String> {
/**
* A {@link BoundedSource} that reads from Snowflake by running COPY INTO to stage CSV files,
* then splitting into one sub-source per file.
*/
private static class SnowflakeBoundedSource<T> extends BoundedSource<T> {
private static final Logger LOG = LoggerFactory.getLogger(SnowflakeBoundedSource.class);

private final SerializableFunction<Void, DataSource> dataSourceProviderFn;
private final ValueProvider<String> query;
private final ValueProvider<String> database;
private final ValueProvider<String> schema;
private final ValueProvider<String> table;
private final @Nullable ValueProvider<String> query;
private final @Nullable ValueProvider<String> table;
private final ValueProvider<String> storageIntegrationName;
private final ValueProvider<String> stagingBucketDir;
private final ValueProvider<String> stagingBucketName;
private final String tmpDirName;
private final SnowflakeServices snowflakeServices;
private final ValueProvider<String> quotationMark;
private final CsvMapper<T> csvMapper;
private final Coder<T> coder;

// Non-null only for child sources (one per staged file)
private final @Nullable String filePath;
private final long fileSize;

private CopyIntoStageFn(
SnowflakeBoundedSource(
SerializableFunction<Void, DataSource> dataSourceProviderFn,
ValueProvider<String> query,
ValueProvider<String> table,
@Nullable ValueProvider<String> query,
@Nullable ValueProvider<String> table,
ValueProvider<String> storageIntegrationName,
ValueProvider<String> stagingBucketDir,
ValueProvider<String> stagingBucketName,
String tmpDirName,
SnowflakeServices snowflakeServices,
ValueProvider<String> quotationMark) {
ValueProvider<String> quotationMark,
CsvMapper<T> csvMapper,
Coder<T> coder) {
this(
dataSourceProviderFn,
query,
table,
storageIntegrationName,
stagingBucketName,
tmpDirName,
snowflakeServices,
quotationMark,
csvMapper,
coder,
null,
0);
}

private SnowflakeBoundedSource(
SerializableFunction<Void, DataSource> dataSourceProviderFn,
@Nullable ValueProvider<String> query,
@Nullable ValueProvider<String> table,
ValueProvider<String> storageIntegrationName,
ValueProvider<String> stagingBucketName,
String tmpDirName,
SnowflakeServices snowflakeServices,
ValueProvider<String> quotationMark,
CsvMapper<T> csvMapper,
Coder<T> coder,
@Nullable String filePath,
long fileSize) {
this.dataSourceProviderFn = dataSourceProviderFn;
this.query = query;
this.table = table;
this.storageIntegrationName = storageIntegrationName;
this.stagingBucketName = stagingBucketName;
this.tmpDirName = tmpDirName;
this.snowflakeServices = snowflakeServices;
this.quotationMark = quotationMark;
this.stagingBucketDir = stagingBucketDir;
this.tmpDirName = tmpDirName;
DataSourceProviderFromDataSourceConfiguration
dataSourceProviderFromDataSourceConfiguration =
(DataSourceProviderFromDataSourceConfiguration) this.dataSourceProviderFn;
DataSourceConfiguration config = dataSourceProviderFromDataSourceConfiguration.getConfig();

this.database = config.getDatabase();
this.schema = config.getSchema();
this.csvMapper = csvMapper;
this.coder = coder;
this.filePath = filePath;
this.fileSize = fileSize;
}

@ProcessElement
public void processElement(ProcessContext context) throws Exception {
String databaseValue = getValueOrNull(this.database);
String schemaValue = getValueOrNull(this.schema);
String tableValue = getValueOrNull(this.table);
String queryValue = getValueOrNull(this.query);
@Override
public List<? extends BoundedSource<T>> split(
long desiredBundleSizeBytes, PipelineOptions options) throws Exception {
if (filePath != null) {
return Collections.singletonList(this);
}

String stagingBucketRunDir =
String.format(
"%s/%s/run_%s/",
stagingBucketDir.get(), tmpDirName, UUID.randomUUID().toString().subSequence(0, 8));
stagingBucketName.get(),
tmpDirName,
UUID.randomUUID().toString().subSequence(0, 8));

SnowflakeBatchServiceConfig config =
DataSourceProviderFromDataSourceConfiguration dsProvider =
(DataSourceProviderFromDataSourceConfiguration) dataSourceProviderFn;
DataSourceConfiguration config = dsProvider.getConfig();

SnowflakeBatchServiceConfig batchConfig =
new SnowflakeBatchServiceConfig(
dataSourceProviderFn,
databaseValue,
schemaValue,
tableValue,
queryValue,
getValueOrNull(config.getDatabase()),
getValueOrNull(config.getSchema()),
getValueOrNull(table),
getValueOrNull(query),
storageIntegrationName.get(),
stagingBucketRunDir,
quotationMark.get());

String output = snowflakeServices.getBatchService().read(config);
LOG.info("Running Snowflake COPY INTO stage: {}", stagingBucketRunDir);
String globPattern = snowflakeServices.getBatchService().read(batchConfig);

List<MatchResult.Metadata> files = FileSystems.match(globPattern).metadata();
LOG.info("Snowflake COPY INTO produced {} files", files.size());

return files.stream()
.map(
metadata ->
new SnowflakeBoundedSource<T>(
dataSourceProviderFn,
query,
table,
storageIntegrationName,
stagingBucketName,
tmpDirName,
snowflakeServices,
quotationMark,
csvMapper,
coder,
metadata.resourceId().toString(),
metadata.sizeBytes()))
.collect(Collectors.toList());
}

context.output(output);
@Override
public long getEstimatedSizeBytes(PipelineOptions options) {
return fileSize;
}
}

/**
* Parses {@code String} from incoming data in {@link PCollection} to have proper format for CSV
* files.
*/
public static class MapCsvToStringArrayFn extends DoFn<String, String[]> {
private ValueProvider<String> quoteChar;
@Override
public BoundedReader<T> createReader(PipelineOptions options) throws IOException {
if (filePath == null) {
throw new IOException("Cannot create reader from unsplit parent source");
}
return new SnowflakeFileReader<>(this);
}

public MapCsvToStringArrayFn(ValueProvider<String> quoteChar) {
this.quoteChar = quoteChar;
@Override
public Coder<T> getOutputCoder() {
return coder;
}

@ProcessElement
public void processElement(ProcessContext c) throws IOException {
String csvLine = c.element();
CSVParser parser = new CSVParserBuilder().withQuoteChar(quoteChar.get().charAt(0)).build();
String[] parts = parser.parseLine(csvLine);
c.output(parts);
@Override
public void validate() {
// Validation is done in SnowflakeIO.Read.checkArguments()
}
}

private static class MapStringArrayToUserDataFn<T> extends DoFn<String[], T> {
private final CsvMapper<T> csvMapper;
private static class SnowflakeFileReader<T> extends BoundedReader<T> {
private final SnowflakeBoundedSource<T> source;
private transient BufferedReader reader;
private transient CSVParser csvParser;
private T current;

public MapStringArrayToUserDataFn(CsvMapper<T> csvMapper) {
this.csvMapper = csvMapper;
}
SnowflakeFileReader(SnowflakeBoundedSource<T> source) {
this.source = source;
}

@ProcessElement
public void processElement(ProcessContext context) throws Exception {
context.output(csvMapper.mapRow(context.element()));
@Override
public boolean start() throws IOException {
ResourceId resourceId = FileSystems.matchNewResource(source.filePath, false);
ReadableByteChannel channel = FileSystems.open(resourceId);
InputStream inputStream = new GZIPInputStream(Channels.newInputStream(channel));

reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
csvParser =
new CSVParserBuilder().withQuoteChar(source.quotationMark.get().charAt(0)).build();

return advance();
}

@Override
public boolean advance() throws IOException {
String line = reader.readLine();
if (line == null) {
return false;
}
try {
String[] parts = csvParser.parseLine(line);
current = source.csvMapper.mapRow(parts);
return true;
} catch (Exception e) {
throw new IOException("Error mapping CSV row: " + line, e);
}
}

@Override
public T getCurrent() {
return current;
}

@Override
public void close() throws IOException {
if (reader != null) {
reader.close();
}
}

@Override
public BoundedSource<T> getCurrentSource() {
return source;
}
}
}

Expand Down
Loading
Loading