Commit 6b8d9cb

support dataset option

jinchengchenghh committed Jun 17, 2024
1 parent 02585cd commit 6b8d9cb
Showing 12 changed files with 470 additions and 24 deletions.
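In effect, the commit threads an optional, protobuf-serialized FragmentScanOptions through both the dataset-factory and scanner JNI entry points. A minimal sketch of the resulting Java API, assuming a CSV-specific FragmentScanOptions implementation among the changed files not expanded on this page:

import java.util.Optional;
import org.apache.arrow.dataset.file.FileFormat;
import org.apache.arrow.dataset.file.FileSystemDatasetFactory;
import org.apache.arrow.dataset.jni.NativeMemoryPool;
import org.apache.arrow.dataset.scanner.FragmentScanOptions;
import org.apache.arrow.memory.RootAllocator;

public class CsvOptionExample {
  public static void main(String[] args) {
    // Hypothetical: a CSV-specific FragmentScanOptions built elsewhere.
    Optional<FragmentScanOptions> csvOptions = Optional.empty();
    try (RootAllocator allocator = new RootAllocator()) {
      FileSystemDatasetFactory factory = new FileSystemDatasetFactory(
          allocator, NativeMemoryPool.getDefault(), FileFormat.CSV,
          "file:///tmp/example.csv", csvOptions);
      // The options, if present, are serialized to a protobuf Struct and
      // applied to the native CsvFileFormat, as jni_wrapper.cc shows below.
    }
  }
}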
7 changes: 7 additions & 0 deletions java/dataset/pom.xml
@@ -25,6 +25,7 @@
<arrow.cpp.build.dir>../../../cpp/release-build/</arrow.cpp.build.dir>
<parquet.version>1.13.1</parquet.version>
<avro.version>1.11.3</avro.version>
<protobuf.version>3.25.3</protobuf.version>
</properties>

<dependencies>
@@ -48,6 +49,12 @@
<groupId>org.immutables</groupId>
<artifactId>value-annotations</artifactId>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-netty</artifactId>
152 changes: 142 additions & 10 deletions java/dataset/src/main/cpp/jni_wrapper.cc
@@ -19,12 +19,15 @@
#include <utility>
#include <unordered_map>

#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/struct.pb.h>

#include "arrow/array.h"
#include "arrow/array/concatenate.h"
#include "arrow/c/bridge.h"
#include "arrow/c/helpers.h"
#include "arrow/dataset/api.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_csv.h"
#include "arrow/filesystem/api.h"
#include "arrow/filesystem/path_util.h"
#include "arrow/engine/substrait/util.h"
@@ -363,6 +366,92 @@ std::shared_ptr<arrow::Buffer> LoadArrowBufferFromByteBuffer(JNIEnv* env, jobject
return buffer;
}

arrow::Status ParseFromBufferImpl(const arrow::Buffer& buf, const std::string& full_name,
google::protobuf::Message* message) {
google::protobuf::io::ArrayInputStream buf_stream{buf.data(),
static_cast<int>(buf.size())};

if (message->ParseFromZeroCopyStream(&buf_stream)) {
return arrow::Status::OK();
}
return arrow::Status::IOError("ParseFromZeroCopyStream failed for ", full_name);
}

template <typename Message>
arrow::Result<Message> ParseFromBuffer(const arrow::Buffer& buf) {
Message message;
ARROW_RETURN_NOT_OK(
ParseFromBufferImpl(buf, Message::descriptor()->full_name(), &message));
return message;
}

arrow::Status FromProto(const google::protobuf::Struct& pb_struct,
                        std::unordered_map<std::string, std::string>& out) {
  for (const auto& [name, value] : pb_struct.fields()) {
    out.emplace(name, value.string_value());
  }
  return arrow::Status::OK();
}

/// \brief Deserialize a serialized Protobuf Struct message into a string map.
///
/// \param[in] buf a buffer containing the protobuf serialization of a Protobuf Struct
/// \param[out] out the map to deserialize into
arrow::Status DeserializeMap(const arrow::Buffer& buf,
                             std::unordered_map<std::string, std::string>& out) {
  ARROW_ASSIGN_OR_RAISE(auto pb_struct, ParseFromBuffer<google::protobuf::Struct>(buf));
  return FromProto(pb_struct, out);
}
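The producing side of this buffer lives in Java. A minimal sketch of what a FragmentScanOptions.serialize() implementation can look like, using the protobuf-java dependency added to pom.xml above (serializeConfigs is a hypothetical helper name; the concrete implementation is among the files not expanded on this page):

// Pack a String->String config map as a google.protobuf.Struct, the inverse
// of DeserializeMap above. LoadArrowBufferFromByteBuffer on the JNI side
// reads from a direct ByteBuffer.
static java.nio.ByteBuffer serializeConfigs(java.util.Map<String, String> configs) {
  com.google.protobuf.Struct.Builder builder = com.google.protobuf.Struct.newBuilder();
  configs.forEach((name, value) -> builder.putFields(
      name, com.google.protobuf.Value.newBuilder().setStringValue(value).build()));
  byte[] bytes = builder.build().toByteArray();
  java.nio.ByteBuffer buf = java.nio.ByteBuffer.allocateDirect(bytes.length);
  buf.put(bytes);
  buf.flip();
  return buf;
}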

inline bool ParseBool(const std::string& value) { return value == "true"; }

/// \brief Construct FragmentScanOptions from config map
#ifdef ARROW_CSV
arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
ToCsvFragmentScanOptions(const std::unordered_map<std::string, std::string>& configs) {
  auto options = std::make_shared<arrow::dataset::CsvFragmentScanOptions>();
  for (const auto& [key, value] : configs) {
if (key == "delimiter") {
      options->parse_options.delimiter = value[0];
} else if (key == "quoting") {
options->parse_options.quoting = ParseBool(value);
} else if (key == "column_types") {
      int64_t schema_address = std::stoll(value);
ArrowSchema* cSchema = reinterpret_cast<ArrowSchema*>(schema_address);
ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(cSchema));
auto& column_types = options->convert_options.column_types;
for (auto field : schema->fields()) {
column_types[field->name()] = field->type();
}
} else if (key == "strings_can_be_null") {
options->convert_options.strings_can_be_null = ParseBool(value);
} else {
return arrow::Status::Invalid("Config " + it.first + " is not supported.");
}
}
return options;
}
#endif
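For orientation, a config map that would exercise the branches above might look like the following. Note that "column_types" is special-cased: its value is the address of an exported ArrowSchema, so in practice the Java wrapper fills it in rather than the user writing it by hand.

// Example configs for ToCsvFragmentScanOptions; keys match the branches above.
java.util.Map<String, String> configs = new java.util.HashMap<>();
configs.put("delimiter", "|");
configs.put("quoting", "false");
configs.put("strings_can_be_null", "true");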

arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
GetFragmentScanOptions(jint file_format_id,
const std::unordered_map<std::string, std::string>& configs) {
switch (file_format_id) {
#ifdef ARROW_CSV
case 3:
return ToCsvFragmentScanOptions(configs);
#endif
default:
return arrow::Status::Invalid("Illegal file format id: ", file_format_id);
}
}
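The ids switched on here are the Java FileFormat enum ids (CSV is 3), and NativeDataset below passes -1 when no fragment scan options were supplied:

// file_format_id as passed from Java; FileFormat.CSV.id() matches "case 3" above.
int csvId = org.apache.arrow.dataset.file.FileFormat.CSV.id();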

/*
* Class: org_apache_arrow_dataset_jni_NativeMemoryPool
* Method: getDefaultMemoryPool
@@ -501,12 +590,13 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset
/*
* Class: org_apache_arrow_dataset_jni_JniWrapper
* Method: createScanner
* Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)J
 * Signature:
 * (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJLjava/nio/ByteBuffer;J)J
*/
JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns,
jobject substrait_projection, jobject substrait_filter,
jlong batch_size, jlong memory_pool_id) {
jobject substrait_projection, jobject substrait_filter, jlong batch_size,
jlong file_format_id, jobject options, jlong memory_pool_id) {
JNI_METHOD_START
arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
if (pool == nullptr) {
@@ -555,6 +645,14 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
}
JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr));
}
if (file_format_id != -1 && options != nullptr) {
std::unordered_map<std::string, std::string> option_map;
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
JniAssertOkOrThrow(DeserializeMap(*buffer, option_map));
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
JniAssertOkOrThrow(scanner_builder->FragmentScanOptions(scan_options));
}
JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));

auto scanner = JniGetOrThrow(scanner_builder->Finish());
@@ -668,14 +766,31 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_ensureS3Finalized
/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: (Ljava/lang/String;II)J
 * Signature: (Ljava/lang/String;ILjava/nio/ByteBuffer;)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljava_lang_String_2I(
JNIEnv* env, jobject, jstring uri, jint file_format_id) {
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory(
JNIEnv* env, jobject, jstring uri, jint file_format_id, jobject options) {
JNI_METHOD_START
std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (options != nullptr) {
std::unordered_map<std::string, std::string> option_map;
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
JniAssertOkOrThrow(DeserializeMap(*buffer, option_map));
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
file_format->default_fragment_scan_options = scan_options;
#ifdef ARROW_CSV
if (file_format_id == 3) {
std::shared_ptr<arrow::dataset::CsvFileFormat> csv_file_format =
std::dynamic_pointer_cast<arrow::dataset::CsvFileFormat>(file_format);
csv_file_format->parse_options =
std::dynamic_pointer_cast<arrow::dataset::CsvFragmentScanOptions>(scan_options)
->parse_options;
}
#endif
}
arrow::dataset::FileSystemFactoryOptions options;
std::shared_ptr<arrow::dataset::DatasetFactory> d =
JniGetOrThrow(arrow::dataset::FileSystemDatasetFactory::Make(
Expand All @@ -686,16 +801,33 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljav

/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: ([Ljava/lang/String;II)J
* Method: makeFileSystemDatasetFactoryWithFiles
 * Signature: ([Ljava/lang/String;ILjava/nio/ByteBuffer;)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory___3Ljava_lang_String_2I(
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id) {
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactoryWithFiles(
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id, jobject options) {
JNI_METHOD_START

std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (options != nullptr) {
std::unordered_map<std::string, std::string> option_map;
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
JniAssertOkOrThrow(DeserializeMap(*buffer, option_map));
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
file_format->default_fragment_scan_options = scan_options;
#ifdef ARROW_CSV
if (file_format_id == 3) {
std::shared_ptr<arrow::dataset::CsvFileFormat> csv_file_format =
std::dynamic_pointer_cast<arrow::dataset::CsvFileFormat>(file_format);
csv_file_format->parse_options =
std::dynamic_pointer_cast<arrow::dataset::CsvFragmentScanOptions>(scan_options)
->parse_options;
}
#endif
}
arrow::dataset::FileSystemFactoryOptions options;

std::vector<std::string> uri_vec = ToStringVector(env, uris);
java/dataset/src/main/java/org/apache/arrow/dataset/file/FileSystemDatasetFactory.java
@@ -17,8 +17,11 @@

package org.apache.arrow.dataset.file;

import java.util.Optional;

import org.apache.arrow.dataset.jni.NativeDatasetFactory;
import org.apache.arrow.dataset.jni.NativeMemoryPool;
import org.apache.arrow.dataset.scanner.FragmentScanOptions;
import org.apache.arrow.memory.BufferAllocator;

/**
@@ -27,21 +30,34 @@
public class FileSystemDatasetFactory extends NativeDatasetFactory {

public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
String uri) {
super(allocator, memoryPool, createNative(format, uri));
String uri, Optional<FragmentScanOptions> fragmentScanOptions) {
super(allocator, memoryPool, createNative(format, uri, fragmentScanOptions));
}

public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
String uri) {
super(allocator, memoryPool, createNative(format, uri, Optional.empty()));
}

public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
String[] uris, Optional<FragmentScanOptions> fragmentScanOptions) {
super(allocator, memoryPool, createNative(format, uris, fragmentScanOptions));
}

public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
String[] uris) {
super(allocator, memoryPool, createNative(format, uris));
super(allocator, memoryPool, createNative(format, uris, Optional.empty()));
}

private static long createNative(FileFormat format, String uri) {
return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id());
private static long createNative(FileFormat format, String uri, Optional<FragmentScanOptions> fragmentScanOptions) {
return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id(),
fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null));
}

private static long createNative(FileFormat format, String[] uris) {
return JniWrapper.get().makeFileSystemDatasetFactory(uris, format.id());
private static long createNative(FileFormat format, String[] uris,
Optional<FragmentScanOptions> fragmentScanOptions) {
return JniWrapper.get().makeFileSystemDatasetFactoryWithFiles(uris, format.id(),
fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null));
}

}
java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
@@ -17,6 +17,8 @@

package org.apache.arrow.dataset.file;

import java.nio.ByteBuffer;

import org.apache.arrow.dataset.jni.JniLoader;

/**
@@ -39,22 +41,26 @@ private JniWrapper() {
* intermediate shared_ptr of the factory instance.
*
* @param uri file uri to read, either a file or a directory
* @param fileFormat file format ID
* @param fileFormat file format ID.
* @param serializedFragmentScanOptions serialized FragmentScanOptions, or null if none.
* @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
* @see FileFormat
*/
public native long makeFileSystemDatasetFactory(String uri, int fileFormat);
public native long makeFileSystemDatasetFactory(String uri, int fileFormat,
ByteBuffer serializedFragmentScanOptions);

/**
* Create FileSystemDatasetFactory and return its native pointer. The pointer points to an
* intermediate shared_ptr of the factory instance.
*
* @param uris List of file uris to read, each path pointing to an individual file
* @param fileFormat file format ID
* @param fileFormat file format ID.
* @param serializedFragmentScanOptions serialized FragmentScanOptions, or null if none.
* @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
* @see FileFormat
*/
public native long makeFileSystemDatasetFactory(String[] uris, int fileFormat);
public native long makeFileSystemDatasetFactoryWithFiles(String[] uris, int fileFormat,
ByteBuffer serializedFragmentScanOptions);

/**
* Write the content in a {@link org.apache.arrow.c.ArrowArrayStream} into files. This internally
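A note on the rename above: overloading a native method forces JNI to bind mangled export names, which is why the old C++ exports were Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljava_lang_String_2I and ..._3Ljava_lang_String_2I. Giving the String[] variant its own Java name, makeFileSystemDatasetFactoryWithFiles, lets both native functions keep short unmangled symbols, as seen in jni_wrapper.cc above.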
java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
@@ -76,11 +76,14 @@ private JniWrapper() {
* @param substraitProjection substrait extended expression to evaluate for projecting new columns
* @param substraitFilter substrait extended expression to evaluate for applying a filter
* @param batchSize batch size of scanned record batches.
* @param fileFormat file format ID, or -1 when no fragment scan options are supplied.
* @param serializedFragmentScanOptions serialized FragmentScanOptions, or null if none.
* @param memoryPool identifier of memory pool used in the native scanner.
* @return the native pointer of the arrow::dataset::Scanner instance.
*/
public native long createScanner(long datasetId, String[] columns, ByteBuffer substraitProjection,
ByteBuffer substraitFilter, long batchSize, long memoryPool);
ByteBuffer substraitFilter, long batchSize, long fileFormat,
ByteBuffer serializedFragmentScanOptions, long memoryPool);

/**
* Get a serialized schema from native instance of a Scanner.
java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java
@@ -17,6 +17,9 @@

package org.apache.arrow.dataset.jni;

import java.nio.ByteBuffer;

import org.apache.arrow.dataset.scanner.FragmentScanOptions;
import org.apache.arrow.dataset.scanner.ScanOptions;
import org.apache.arrow.dataset.source.Dataset;

@@ -40,11 +43,18 @@ public synchronized NativeScanner newScan(ScanOptions options) {
if (closed) {
throw new NativeInstanceReleasedException();
}

int fileFormat = -1;
ByteBuffer serialized = null;
if (options.getFragmentScanOptions().isPresent()) {
FragmentScanOptions fragmentScanOptions = options.getFragmentScanOptions().get();
fileFormat = fragmentScanOptions.fileFormatId();
serialized = fragmentScanOptions.serialize();
}
long scannerId = JniWrapper.get().createScanner(datasetId, options.getColumns().orElse(null),
options.getSubstraitProjection().orElse(null),
options.getSubstraitFilter().orElse(null),
options.getBatchSize(), context.getMemoryPool().getNativeInstanceId());
options.getBatchSize(), fileFormat, serialized,
context.getMemoryPool().getNativeInstanceId());

return new NativeScanner(context, scannerId);
}
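End to end on the scan side, a minimal sketch. It assumes ScanOptions gained a fragmentScanOptions builder option in this commit; that builder change is among the files not expanded on this page.

// Hypothetical: attach CSV options to a scan. newScan above extracts the
// file format id and the serialized options and hands both to createScanner.
ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768)
    .columns(Optional.of(new String[] {"id", "name"}))
    .fragmentScanOptions(csvFragmentScanOptions)  // assumed setter
    .build();
try (NativeScanner scanner = dataset.newScan(options)) {
  // consume scanner.scanBatches() as usual
}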