From 6b8d9cbd1ea0f0e46086cad3dd4fb5b61a8680bf Mon Sep 17 00:00:00 2001
From: Chengcheng Jin
Date: Mon, 17 Jun 2024 15:50:15 +0000
Subject: [PATCH] support dataset option

---
 java/dataset/pom.xml                          |   7 +
 java/dataset/src/main/cpp/jni_wrapper.cc      | 152 ++++++++++++++++--
 .../file/FileSystemDatasetFactory.java        |  30 +++-
 .../apache/arrow/dataset/file/JniWrapper.java |  14 +-
 .../apache/arrow/dataset/jni/JniWrapper.java  |   5 +-
 .../arrow/dataset/jni/NativeDataset.java      |  14 +-
 .../dataset/scanner/FragmentScanOptions.java  |  53 ++++++
 .../arrow/dataset/scanner/ScanOptions.java    |  21 +++
 .../scanner/csv/CsvConvertOptions.java        |  51 ++++++
 .../scanner/csv/CsvFragmentScanOptions.java   |  97 +++++++++++
 .../substrait/TestAceroSubstraitConsumer.java |  46 ++++++
 .../src/test/resources/data/student.csv      |   4 +
 12 files changed, 470 insertions(+), 24 deletions(-)
 create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java
 create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvConvertOptions.java
 create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java
 create mode 100644 java/dataset/src/test/resources/data/student.csv

diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml
index 3dea16204a4db..269bf8f0e0b02 100644
--- a/java/dataset/pom.xml
+++ b/java/dataset/pom.xml
@@ -25,6 +25,7 @@
     <arrow.cpp.build.dir>../../../cpp/release-build/</arrow.cpp.build.dir>
     <parquet.version>1.13.1</parquet.version>
     <avro.version>1.11.3</avro.version>
+    <protobuf.version>3.25.3</protobuf.version>
   </properties>
 
   <dependencies>
@@ -48,6 +49,12 @@
       <groupId>org.immutables</groupId>
       <artifactId>value-annotations</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>${protobuf.version}</version>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.arrow</groupId>
      <artifactId>arrow-memory-netty</artifactId>
diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc
index 79efbeb74fc54..399f2188dc7c3 100644
--- a/java/dataset/src/main/cpp/jni_wrapper.cc
+++ b/java/dataset/src/main/cpp/jni_wrapper.cc
@@ -19,12 +19,15 @@
 #include <mutex>
 #include <unordered_map>
 
+#include <google/protobuf/struct.pb.h>
+
 #include "arrow/array.h"
 #include "arrow/array/concatenate.h"
 #include "arrow/c/bridge.h"
 #include "arrow/c/helpers.h"
 #include "arrow/dataset/api.h"
 #include "arrow/dataset/file_base.h"
+#include "arrow/dataset/file_csv.h"
 #include "arrow/filesystem/api.h"
 #include "arrow/filesystem/path_util.h"
 #include "arrow/engine/substrait/util.h"
@@ -363,6 +366,92 @@ std::shared_ptr<arrow::Buffer> LoadArrowBufferFromByteBuffer(JNIEnv* env, jobject
   return buffer;
 }
 
+arrow::Status ParseFromBufferImpl(const arrow::Buffer& buf, const std::string& full_name,
+                                  google::protobuf::Message* message) {
+  google::protobuf::io::ArrayInputStream buf_stream{buf.data(),
+                                                    static_cast<int>(buf.size())};
+  if (message->ParseFromZeroCopyStream(&buf_stream)) {
+    return arrow::Status::OK();
+  }
+  return arrow::Status::IOError("ParseFromZeroCopyStream failed for ", full_name);
+}
+
+template <typename Message>
+arrow::Result<Message> ParseFromBuffer(const arrow::Buffer& buf) {
+  Message message;
+  ARROW_RETURN_NOT_OK(
+      ParseFromBufferImpl(buf, Message::descriptor()->full_name(), &message));
+  return message;
+}
+
+arrow::Status FromProto(const google::protobuf::Struct& struct_message,
+                        std::unordered_map<std::string, std::string>& out) {
+  if (struct_message.fields().empty()) {
+    return arrow::Status::OK();
+  }
+  for (const auto& [name, value] : struct_message.fields()) {
+    out.emplace(name, value.string_value());
+  }
+  return arrow::Status::OK();
+}
+
+/// \brief Deserialize a Protobuf Struct message into a string map.
+///
+/// \param[in] buf a buffer containing the protobuf serialization of a Protobuf Struct
+/// \param[out] out the map the deserialized entries are stored in.
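+///
+/// The expected payload is a google.protobuf.Struct whose fields all carry
+/// string values, for example {"delimiter": ";", "quoting": "false"} (values
+/// illustrative); non-string values deserialize as empty strings because only
+/// string_value() is read.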
+arrow::Status DeserializeMap(const arrow::Buffer& buf,
+                             std::unordered_map<std::string, std::string>& out) {
+  ARROW_ASSIGN_OR_RAISE(auto struct_message,
+                        ParseFromBuffer<google::protobuf::Struct>(buf));
+  return FromProto(struct_message, out);
+}
+
+inline bool ParseBool(const std::string& value) { return value == "true"; }
+
+/// \brief Construct FragmentScanOptions from the config map.
+#ifdef ARROW_CSV
+arrow::Result<std::shared_ptr<arrow::dataset::CsvFragmentScanOptions>>
+ToCsvFragmentScanOptions(const std::unordered_map<std::string, std::string>& configs) {
+  std::shared_ptr<arrow::dataset::CsvFragmentScanOptions> options =
+      std::make_shared<arrow::dataset::CsvFragmentScanOptions>();
+  for (auto const& it : configs) {
+    auto& key = it.first;
+    auto& value = it.second;
+    if (key == "delimiter") {
+      options->parse_options.delimiter = value.front();
+    } else if (key == "quoting") {
+      options->parse_options.quoting = ParseBool(value);
+    } else if (key == "column_types") {
+      // The value is the address of an ArrowSchema exported through the C data
+      // interface by the Java side; import it to obtain per-column types.
+      int64_t schema_address = std::stoll(value);
+      ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
+      ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(c_schema));
+      auto& column_types = options->convert_options.column_types;
+      for (const auto& field : schema->fields()) {
+        column_types[field->name()] = field->type();
+      }
+    } else if (key == "strings_can_be_null") {
+      options->convert_options.strings_can_be_null = ParseBool(value);
+    } else {
+      return arrow::Status::Invalid("Config ", key, " is not supported.");
+    }
+  }
+  return options;
+}
+#endif
+
+arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
+GetFragmentScanOptions(jint file_format_id,
+                       const std::unordered_map<std::string, std::string>& configs) {
+  switch (file_format_id) {
+#ifdef ARROW_CSV
+    case 3:  // matches FileFormat.CSV.id() on the Java side
+      return ToCsvFragmentScanOptions(configs);
+#endif
+    default:
+      return arrow::Status::Invalid("Illegal file format id: ", file_format_id);
+  }
+}
+
 /*
  * Class:     org_apache_arrow_dataset_jni_NativeMemoryPool
  * Method:    getDefaultMemoryPool
@@ -501,12 +590,13 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset
 /*
  * Class:     org_apache_arrow_dataset_jni_JniWrapper
  * Method:    createScanner
- * Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)J
+ * Signature:
+ * (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJLjava/nio/ByteBuffer;J)J
  */
 JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
     JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns,
-    jobject substrait_projection, jobject substrait_filter,
-    jlong batch_size, jlong memory_pool_id) {
+    jobject substrait_projection, jobject substrait_filter, jlong batch_size,
+    jlong file_format_id, jobject options, jlong memory_pool_id) {
   JNI_METHOD_START
   arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
   if (pool == nullptr) {
@@ -555,6 +645,14 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScann
     }
     JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr));
   }
+  if (file_format_id != -1 && options != nullptr) {
+    std::unordered_map<std::string, std::string> option_map;
+    std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
+    JniAssertOkOrThrow(DeserializeMap(*buffer, option_map));
+    std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
+        JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
+    JniAssertOkOrThrow(scanner_builder->FragmentScanOptions(scan_options));
+  }
   JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));
 
   auto scanner = JniGetOrThrow(scanner_builder->Finish());
@@ -668,14 +766,31 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_ensureS3Finalized
 /*
  * Class:     org_apache_arrow_dataset_file_JniWrapper
  * Method:    makeFileSystemDatasetFactory
- * Signature: (Ljava/lang/String;II)J
+ * Signature:
+ * (Ljava/lang/String;ILjava/nio/ByteBuffer;)J
  */
 JNIEXPORT jlong JNICALL
-Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljava_lang_String_2I(
-    JNIEnv* env, jobject, jstring uri, jint file_format_id) {
+Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory(
+    JNIEnv* env, jobject, jstring uri, jint file_format_id,
+    jobject serialized_options) {
   JNI_METHOD_START
   std::shared_ptr<arrow::dataset::FileFormat> file_format =
       JniGetOrThrow(GetFileFormat(file_format_id));
+  if (serialized_options != nullptr) {
+    std::unordered_map<std::string, std::string> option_map;
+    std::shared_ptr<arrow::Buffer> buffer =
+        LoadArrowBufferFromByteBuffer(env, serialized_options);
+    JniAssertOkOrThrow(DeserializeMap(*buffer, option_map));
+    std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
+        JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
+    file_format->default_fragment_scan_options = scan_options;
+#ifdef ARROW_CSV
+    if (file_format_id == 3) {
+      // Copy the CSV parse options onto the file format so they are also used
+      // when the factory inspects files to infer the schema.
+      std::shared_ptr<arrow::dataset::CsvFileFormat> csv_file_format =
+          std::dynamic_pointer_cast<arrow::dataset::CsvFileFormat>(file_format);
+      csv_file_format->parse_options =
+          std::dynamic_pointer_cast<arrow::dataset::CsvFragmentScanOptions>(scan_options)
+              ->parse_options;
+    }
+#endif
+  }
   arrow::dataset::FileSystemFactoryOptions options;
   std::shared_ptr<arrow::dataset::DatasetFactory> d =
       JniGetOrThrow(arrow::dataset::FileSystemDatasetFactory::Make(
@@ -686,16 +801,33 @@
 
 /*
  * Class:     org_apache_arrow_dataset_file_JniWrapper
- * Method:    makeFileSystemDatasetFactory
- * Signature: ([Ljava/lang/String;II)J
+ * Method:    makeFileSystemDatasetFactoryWithFiles
+ * Signature: ([Ljava/lang/String;ILjava/nio/ByteBuffer;)J
  */
 JNIEXPORT jlong JNICALL
-Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory___3Ljava_lang_String_2I(
-    JNIEnv* env, jobject, jobjectArray uris, jint file_format_id) {
+Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactoryWithFiles(
+    JNIEnv* env, jobject, jobjectArray uris, jint file_format_id,
+    jobject serialized_options) {
   JNI_METHOD_START
   std::shared_ptr<arrow::dataset::FileFormat> file_format =
       JniGetOrThrow(GetFileFormat(file_format_id));
+  if (serialized_options != nullptr) {
+    std::unordered_map<std::string, std::string> option_map;
+    std::shared_ptr<arrow::Buffer> buffer =
+        LoadArrowBufferFromByteBuffer(env, serialized_options);
+    JniAssertOkOrThrow(DeserializeMap(*buffer, option_map));
+    std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
+        JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
+    file_format->default_fragment_scan_options = scan_options;
+#ifdef ARROW_CSV
+    if (file_format_id == 3) {
+      // Same as above: make the parse options visible at discovery time.
+      std::shared_ptr<arrow::dataset::CsvFileFormat> csv_file_format =
+          std::dynamic_pointer_cast<arrow::dataset::CsvFileFormat>(file_format);
+      csv_file_format->parse_options =
+          std::dynamic_pointer_cast<arrow::dataset::CsvFragmentScanOptions>(scan_options)
+              ->parse_options;
+    }
+#endif
+  }
   arrow::dataset::FileSystemFactoryOptions options;
   std::vector<std::string> uri_vec = ToStringVector(env, uris);
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileSystemDatasetFactory.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileSystemDatasetFactory.java
index aa315690592ee..a0b6fb168eca9 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileSystemDatasetFactory.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileSystemDatasetFactory.java
@@ -17,8 +17,11 @@
 
 package org.apache.arrow.dataset.file;
 
+import java.util.Optional;
+
 import org.apache.arrow.dataset.jni.NativeDatasetFactory;
 import org.apache.arrow.dataset.jni.NativeMemoryPool;
+import org.apache.arrow.dataset.scanner.FragmentScanOptions;
 import org.apache.arrow.memory.BufferAllocator;
 
 /**
@@ -27,21 +30,34 @@ public class FileSystemDatasetFactory extends NativeDatasetFactory {
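+  /**
+   * Create a factory for a single file or directory URI, optionally carrying
+   * format-specific scan options (for example a CSV delimiter); the options,
+   * when present, are serialized and applied on the native side.
+   */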
   public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
-                                  String uri) {
-    super(allocator, memoryPool, createNative(format, uri));
+                                  String uri, Optional<FragmentScanOptions> fragmentScanOptions) {
+    super(allocator, memoryPool, createNative(format, uri, fragmentScanOptions));
+  }
+
+  public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
+                                  String uri) {
+    super(allocator, memoryPool, createNative(format, uri, Optional.empty()));
+  }
+
+  public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
+                                  String[] uris, Optional<FragmentScanOptions> fragmentScanOptions) {
+    super(allocator, memoryPool, createNative(format, uris, fragmentScanOptions));
   }
 
   public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
                                   String[] uris) {
-    super(allocator, memoryPool, createNative(format, uris));
+    super(allocator, memoryPool, createNative(format, uris, Optional.empty()));
   }
 
-  private static long createNative(FileFormat format, String uri) {
-    return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id());
+  private static long createNative(FileFormat format, String uri,
+                                   Optional<FragmentScanOptions> fragmentScanOptions) {
+    return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id(),
+        fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null));
   }
 
-  private static long createNative(FileFormat format, String[] uris) {
-    return JniWrapper.get().makeFileSystemDatasetFactory(uris, format.id());
+  private static long createNative(FileFormat format, String[] uris,
+                                   Optional<FragmentScanOptions> fragmentScanOptions) {
+    return JniWrapper.get().makeFileSystemDatasetFactoryWithFiles(uris, format.id(),
+        fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null));
   }
 
 }
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
index c3a1a4e58a140..3448ee2f24f82 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
@@ -17,6 +17,8 @@
 
 package org.apache.arrow.dataset.file;
 
+import java.nio.ByteBuffer;
+
 import org.apache.arrow.dataset.jni.JniLoader;
 
 /**
@@ -39,22 +41,26 @@ private JniWrapper() {
    * intermediate shared_ptr of the factory instance.
    *
    * @param uri file uri to read, either a file or a directory
-   * @param fileFormat file format ID
+   * @param fileFormat file format ID.
+   * @param serializedFragmentScanOptions the serialized FragmentScanOptions, or null if none.
    * @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
    * @see FileFormat
    */
-  public native long makeFileSystemDatasetFactory(String uri, int fileFormat);
+  public native long makeFileSystemDatasetFactory(String uri, int fileFormat,
+      ByteBuffer serializedFragmentScanOptions);
 
   /**
    * Create FileSystemDatasetFactory and return its native pointer. The pointer is pointing to a
    * intermediate shared_ptr of the factory instance.
    *
    * @param uris List of file uris to read, each path pointing to an individual file
-   * @param fileFormat file format ID
+   * @param fileFormat file format ID.
+   * @param serializedFragmentScanOptions the serialized FragmentScanOptions, or null if none.
    * @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
    * @see FileFormat
    */
-  public native long makeFileSystemDatasetFactory(String[] uris, int fileFormat);
+  public native long makeFileSystemDatasetFactoryWithFiles(String[] uris, int fileFormat,
+      ByteBuffer serializedFragmentScanOptions);
 
   /**
    * Write the content in a {@link org.apache.arrow.c.ArrowArrayStream} into files. This internally
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
index 637a3e8f22a9a..0d53d6fd83790 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java
@@ -76,11 +76,14 @@ private JniWrapper() {
    * @param substraitProjection substrait extended expression to evaluate for project new columns
    * @param substraitFilter substrait extended expression to evaluate for apply filter
    * @param batchSize batch size of scanned record batches.
+   * @param fileFormat file format ID, or -1 when no fragment scan options are supplied.
+   * @param serializedFragmentScanOptions the serialized FragmentScanOptions, or null if none.
    * @param memoryPool identifier of memory pool used in the native scanner.
    * @return the native pointer of the arrow::dataset::Scanner instance.
    */
   public native long createScanner(long datasetId, String[] columns, ByteBuffer substraitProjection,
-      ByteBuffer substraitFilter, long batchSize, long memoryPool);
+      ByteBuffer substraitFilter, long batchSize, long fileFormat,
+      ByteBuffer serializedFragmentScanOptions, long memoryPool);
 
   /**
    * Get a serialized schema from native instance of a Scanner.
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java
index d9abad9971c4e..3a96fe768761c 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java
@@ -17,6 +17,9 @@
 
 package org.apache.arrow.dataset.jni;
 
+import java.nio.ByteBuffer;
+
+import org.apache.arrow.dataset.scanner.FragmentScanOptions;
 import org.apache.arrow.dataset.scanner.ScanOptions;
 import org.apache.arrow.dataset.source.Dataset;
 
@@ -40,11 +43,18 @@ public synchronized NativeScanner newScan(ScanOptions options) {
     if (closed) {
       throw new NativeInstanceReleasedException();
     }
-
+    int fileFormat = -1;
+    ByteBuffer serialized = null;
+    if (options.getFragmentScanOptions().isPresent()) {
+      FragmentScanOptions fragmentScanOptions = options.getFragmentScanOptions().get();
+      fileFormat = fragmentScanOptions.fileFormatId();
+      serialized = fragmentScanOptions.serialize();
+    }
     long scannerId = JniWrapper.get().createScanner(datasetId, options.getColumns().orElse(null),
         options.getSubstraitProjection().orElse(null), options.getSubstraitFilter().orElse(null),
-        options.getBatchSize(), context.getMemoryPool().getNativeInstanceId());
+        options.getBatchSize(), fileFormat, serialized,
+        context.getMemoryPool().getNativeInstanceId());
     return new NativeScanner(context, scannerId);
   }
 
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java
new file mode 100644
index 0000000000000..bbfbb84a3d1ef
--- /dev/null
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset.scanner;
+
+import java.nio.ByteBuffer;
+import java.util.Map;
+
+import com.google.protobuf.Struct;
+import com.google.protobuf.Value;
+
+/**
+ * File-format-specific scan options, such as CSV options. The options are
+ * serialized to a Protobuf Struct and passed to native code through JNI.
+ */
+public interface FragmentScanOptions {
+  String typeName();
+
+  int fileFormatId();
+
+  ByteBuffer serialize();
+
+  /**
+   * Serialize the config map to a Protobuf Struct.
+   *
+   * @param config config map
+   * @return buffer passed as a JNI call argument; must be a DirectByteBuffer
+   */
+  default ByteBuffer serializeMap(Map<String, String> config) {
+    if (config.isEmpty()) {
+      return null;
+    }
+
+    Struct.Builder builder = Struct.newBuilder();
+    config.forEach((k, v) -> builder.putFields(k, Value.newBuilder().setStringValue(v).build()));
+    Struct struct = builder.build();
+
+    ByteBuffer buf = ByteBuffer.allocateDirect(struct.getSerializedSize());
+    buf.put(struct.toByteArray());
+    return buf;
+  }
+}
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java
index 995d05ac3b314..6072da5aa1b71 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java
@@ -31,6 +31,8 @@ public class ScanOptions {
   private final Optional<ByteBuffer> substraitProjection;
   private final Optional<ByteBuffer> substraitFilter;
 
+  private final Optional<FragmentScanOptions> fragmentScanOptions;
+
   /**
    * Constructor.
    * @param columns Projected columns. Empty for scanning all columns.
@@ -61,6 +63,7 @@ public ScanOptions(long batchSize, Optional<String[]> columns) {
     this.columns = columns;
     this.substraitProjection = Optional.empty();
     this.substraitFilter = Optional.empty();
+    this.fragmentScanOptions = Optional.empty();
   }
 
   public ScanOptions(long batchSize) {
@@ -83,6 +86,10 @@ public Optional<ByteBuffer> getSubstraitFilter() {
     return substraitFilter;
   }
 
+  public Optional<FragmentScanOptions> getFragmentScanOptions() {
+    return fragmentScanOptions;
+  }
+
   /**
    * Builder for Options used during scanning.
    */
@@ -91,6 +98,7 @@ public static class Builder {
     private Optional<String[]> columns;
     private ByteBuffer substraitProjection;
     private ByteBuffer substraitFilter;
+    private FragmentScanOptions fragmentScanOptions;
 
     /**
      * Constructor.
@@ -136,6 +144,18 @@ public Builder substraitFilter(ByteBuffer substraitFilter) {
       return this;
     }
 
+    /**
+     * Set the FragmentScanOptions.
+     *
+     * @param fragmentScanOptions fragment scan options
+     * @return this Builder, for method chaining.
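+     *
+     *     <p>Example (values illustrative):
+     *     <pre>{@code
+     *     ScanOptions options = new ScanOptions.Builder(32768)
+     *         .fragmentScanOptions(csvFragmentScanOptions)
+     *         .build();
+     *     }</pre>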
+     */
+    public Builder fragmentScanOptions(FragmentScanOptions fragmentScanOptions) {
+      Preconditions.checkNotNull(fragmentScanOptions);
+      this.fragmentScanOptions = fragmentScanOptions;
+      return this;
+    }
+
     public ScanOptions build() {
       return new ScanOptions(this);
     }
@@ -146,5 +166,6 @@ private ScanOptions(Builder builder) {
     columns = builder.columns;
     substraitProjection = Optional.ofNullable(builder.substraitProjection);
     substraitFilter = Optional.ofNullable(builder.substraitFilter);
+    fragmentScanOptions = Optional.ofNullable(builder.fragmentScanOptions);
   }
 }
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvConvertOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvConvertOptions.java
new file mode 100644
index 0000000000000..08e35ede2adb5
--- /dev/null
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvConvertOptions.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset.scanner.csv;
+
+import java.util.Map;
+import java.util.Optional;
+
+import org.apache.arrow.c.ArrowSchema;
+
+/**
+ * Options for converting CSV data, held as a string map that is serialized
+ * and applied to the native CSV convert options.
+ */
+public class CsvConvertOptions {
+
+  private final Map<String, String> configs;
+
+  private Optional<ArrowSchema> cSchema = Optional.empty();
+
+  public CsvConvertOptions(Map<String, String> configs) {
+    this.configs = configs;
+  }
+
+  public Optional<ArrowSchema> getArrowSchema() {
+    return cSchema;
+  }
+
+  public Map<String, String> getConfigs() {
+    return configs;
+  }
+
+  public void set(String key, String value) {
+    configs.put(key, value);
+  }
+
+  public void setArrowSchema(ArrowSchema cSchema) {
+    this.cSchema = Optional.of(cSchema);
+  }
+
+}
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java
new file mode 100644
index 0000000000000..88973f0a04a41
--- /dev/null
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset.scanner.csv;
+
+import java.io.Serializable;
+import java.nio.ByteBuffer;
+import java.util.Locale;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import org.apache.arrow.dataset.file.FileFormat;
+import org.apache.arrow.dataset.scanner.FragmentScanOptions;
+
+/** CSV scan options, mapped to the C++ struct arrow::dataset::CsvFragmentScanOptions. */
+public class CsvFragmentScanOptions implements Serializable, FragmentScanOptions {
+  private final CsvConvertOptions convertOptions;
+  private final Map<String, String> readOptions;
+  private final Map<String, String> parseOptions;
+
+  /**
+   * Constructor.
+   *
+   * @param convertOptions maps to the convert_options member of the C++ struct
+   * @param readOptions maps to the read_options member of the C++ struct
+   * @param parseOptions maps to the parse_options member of the C++ struct
+   */
+  public CsvFragmentScanOptions(CsvConvertOptions convertOptions,
+                                Map<String, String> readOptions,
+                                Map<String, String> parseOptions) {
+    this.convertOptions = convertOptions;
+    this.readOptions = readOptions;
+    this.parseOptions = parseOptions;
+  }
+
+  public String typeName() {
+    return FileFormat.CSV.name().toLowerCase(Locale.ROOT);
+  }
+
+  /**
+   * File format id.
+   *
+   * @return id
+   */
+  public int fileFormatId() {
+    return FileFormat.CSV.id();
+  }
+
+  /**
+   * Serialize these options to a ByteBuffer, which is then passed to the JNI call.
+   *
+   * @return a DirectByteBuffer, as required by the JNI side
+   */
+  public ByteBuffer serialize() {
+    Map<String, String> options = Stream.concat(Stream.concat(readOptions.entrySet().stream(),
+        parseOptions.entrySet().stream()),
+        convertOptions.getConfigs().entrySet().stream()).collect(
+        Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+    if (convertOptions.getArrowSchema().isPresent()) {
+      // The ArrowSchema is exported through the C data interface; its native
+      // address is passed as a string and imported back on the C++ side.
+      options.put("column_types", Long.toString(convertOptions.getArrowSchema().get().memoryAddress()));
+    }
+    return serializeMap(options);
+  }
+
+  public static CsvFragmentScanOptions deserialize(String serialized) {
+    throw new UnsupportedOperationException("Not implemented yet");
+  }
+
+  public CsvConvertOptions getConvertOptions() {
+    return convertOptions;
+  }
+
+  public Map<String, String> getReadOptions() {
+    return readOptions;
+  }
+
+  public Map<String, String> getParseOptions() {
+    return parseOptions;
+  }
+
+}
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java b/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java
index 0fba72892cdc6..e7903b7a4eda7 100644
--- a/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/substrait/TestAceroSubstraitConsumer.java
@@ -31,6 +31,9 @@
 import java.util.Map;
 import java.util.Optional;
 
+import org.apache.arrow.c.ArrowSchema;
+import org.apache.arrow.c.CDataDictionaryProvider;
+import org.apache.arrow.c.Data;
 import org.apache.arrow.dataset.ParquetWriteSupport;
 import org.apache.arrow.dataset.TestDataset;
 import org.apache.arrow.dataset.file.FileFormat;
@@ -38,8 +41,11 @@
 import org.apache.arrow.dataset.jni.NativeMemoryPool;
 import org.apache.arrow.dataset.scanner.ScanOptions;
 import org.apache.arrow.dataset.scanner.Scanner;
+import org.apache.arrow.dataset.scanner.csv.CsvConvertOptions;
+import org.apache.arrow.dataset.scanner.csv.CsvFragmentScanOptions;
 import org.apache.arrow.dataset.source.Dataset;
 import org.apache.arrow.dataset.source.DatasetFactory;
+import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
@@ -49,6 +55,8 @@
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
+import com.google.common.collect.ImmutableMap;
+
 public class TestAceroSubstraitConsumer extends TestDataset {
 
   @ClassRule
@@ -457,4 +465,42 @@ private static ByteBuffer getByteBuffer(String base64EncodedSubstrait) {
     substraitExpression.put(decodedSubstrait);
     return substraitExpression;
   }
+
+  @Test
+  public void testCsvConvertOptions() throws Exception {
+    final Schema schema = new Schema(Arrays.asList(
+        Field.nullable("Id", new ArrowType.Int(32, true)),
+        Field.nullable("Name", new ArrowType.Utf8()),
+        Field.nullable("Language", new ArrowType.Utf8())
+    ), null);
+    String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv";
+    BufferAllocator allocator = rootAllocator();
+    try (ArrowSchema cSchema = ArrowSchema.allocateNew(allocator);
+         CDataDictionaryProvider provider = new CDataDictionaryProvider()) {
+      Data.exportSchema(allocator, schema, provider, cSchema);
+      CsvConvertOptions convertOptions = new CsvConvertOptions(ImmutableMap.of("delimiter", ";"));
+      convertOptions.setArrowSchema(cSchema);
+      CsvFragmentScanOptions fragmentScanOptions = new CsvFragmentScanOptions(
+          convertOptions, ImmutableMap.of(), ImmutableMap.of());
+      ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768)
+          .columns(Optional.empty())
+          .fragmentScanOptions(fragmentScanOptions)
+          .build();
+      try (
+          DatasetFactory datasetFactory = new FileSystemDatasetFactory(allocator, NativeMemoryPool.getDefault(),
+              FileFormat.CSV, path);
+          Dataset dataset = datasetFactory.finish();
+          Scanner scanner = dataset.newScan(options);
+          ArrowReader reader = scanner.scanBatches()
+      ) {
+        assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields());
+        int rowCount = 0;
+        while (reader.loadNextBatch()) {
+          assertEquals("[1, 2, 3]", reader.getVectorSchemaRoot().getVector("Id").toString());
+          rowCount += reader.getVectorSchemaRoot().getRowCount();
+        }
+        assertEquals(3, rowCount);
+      }
+    }
+  }
 }
diff --git a/java/dataset/src/test/resources/data/student.csv b/java/dataset/src/test/resources/data/student.csv
new file mode 100644
index 0000000000000..3291946092156
--- /dev/null
+++ b/java/dataset/src/test/resources/data/student.csv
@@ -0,0 +1,4 @@
+Id;Name;Language
+1;Juno;Java
+2;Peter;Python
+3;Celin;C++