Commit

Include minimal changeset

RustedBones committed Sep 20, 2024
1 parent 2312dbe commit a4ad8d3
Showing 10 changed files with 326 additions and 354 deletions.
@@ -35,6 +35,8 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.avro.Conversions;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
@@ -58,6 +60,7 @@
*/
class BigQueryAvroUtils {

// org.apache.avro.LogicalType
static class DateTimeLogicalType extends LogicalType {
public DateTimeLogicalType() {
super("datetime");
@@ -112,6 +115,7 @@ static Schema getPrimitiveType(TableFieldSchema schema, Boolean useAvroLogicalTy
} else if (bqType.equals("NUMERIC")) {
logicalType = LogicalTypes.decimal(38, 9);
} else {
// BIGNUMERIC
logicalType = LogicalTypes.decimal(77, 38);
}
return logicalType.addToSchema(SchemaBuilder.builder().bytesType());
@@ -230,6 +234,41 @@ private static String formatTime(long timeMicros) {
return LocalTime.ofNanoOfDay(timeMicros * 1000).format(formatter);
}

static TableSchema trimBigQueryTableSchema(TableSchema inputSchema, Schema avroSchema) {
List<TableFieldSchema> subSchemas =
inputSchema.getFields().stream()
.flatMap(fieldSchema -> mapTableFieldSchema(fieldSchema, avroSchema))
.collect(Collectors.toList());

return new TableSchema().setFields(subSchemas);
}

private static Stream<TableFieldSchema> mapTableFieldSchema(
TableFieldSchema fieldSchema, Schema avroSchema) {
Field avroFieldSchema = avroSchema.getField(fieldSchema.getName());
if (avroFieldSchema == null) {
return Stream.empty();
} else if (avroFieldSchema.schema().getType() != Type.RECORD) {
return Stream.of(fieldSchema);
}

List<TableFieldSchema> subSchemas =
fieldSchema.getFields().stream()
.flatMap(subSchema -> mapTableFieldSchema(subSchema, avroFieldSchema.schema()))
.collect(Collectors.toList());

TableFieldSchema output =
new TableFieldSchema()
.setCategories(fieldSchema.getCategories())
.setDescription(fieldSchema.getDescription())
.setFields(subSchemas)
.setMode(fieldSchema.getMode())
.setName(fieldSchema.getName())
.setType(fieldSchema.getType());

return Stream.of(output);
}
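
The new trimBigQueryTableSchema walks the BigQuery schema recursively and keeps only the fields that also appear in the Avro read schema, pruning nested RECORDs the same way. A minimal sketch of the intended behavior (field names are made up, and since the helper is package-private this assumes same-package access; imports assumed: the BigQuery model's TableSchema/TableFieldSchema, Avro's Schema/SchemaBuilder, java.util.Arrays):

  // Hypothetical BigQuery schema: id, a nested user record, and payload.
  TableFieldSchema user =
      new TableFieldSchema()
          .setName("user")
          .setType("RECORD")
          .setFields(
              Arrays.asList(
                  new TableFieldSchema().setName("name").setType("STRING"),
                  new TableFieldSchema().setName("email").setType("STRING")));
  TableSchema bqSchema =
      new TableSchema()
          .setFields(
              Arrays.asList(
                  new TableFieldSchema().setName("id").setType("INTEGER"),
                  user,
                  new TableFieldSchema().setName("payload").setType("STRING")));

  // Avro read schema that only covers id and user.name.
  Schema avroSchema =
      SchemaBuilder.record("root")
          .fields()
          .requiredLong("id")
          .name("user")
          .type(SchemaBuilder.record("user").fields().requiredString("name").endRecord())
          .noDefault()
          .endRecord();

  // "payload" and "user.email" are absent from the Avro schema, so both are
  // dropped; the result keeps only "id" and "user.name".
  TableSchema trimmed = BigQueryAvroUtils.trimBigQueryTableSchema(bqSchema, avroSchema);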

/**
* Utility function to convert from an Avro {@link GenericRecord} to a BigQuery {@link TableRow}.
*
@@ -303,8 +342,6 @@ private static Object convertRequiredField(String name, Schema schema, Object v)
// INTEGER type maps to an Avro LONG type.
checkNotNull(v, "REQUIRED field %s should not be null", name);

// For historical reasons, don't validate avroLogicalType except for with NUMERIC.
// BigQuery represents NUMERIC in Avro format as BYTES with a DECIMAL logical type.
Type type = schema.getType();
LogicalType logicalType = schema.getLogicalType();
switch (type) {
@@ -1272,12 +1272,8 @@ public PCollection<T> expand(PBegin input) {

Schema beamSchema = null;
if (getTypeDescriptor() != null && getToBeamRowFn() != null && getFromBeamRowFn() != null) {
TableSchema tableSchema = sourceDef.getTableSchema(bqOptions);
ValueProvider<List<String>> selectedFields = getSelectedFields();
if (selectedFields != null) {
tableSchema = BigQueryUtils.trimSchema(tableSchema, selectedFields.get());
}
beamSchema = BigQueryUtils.fromTableSchema(tableSchema);
beamSchema = sourceDef.getBeamSchema(bqOptions);
beamSchema = getFinalSchema(beamSchema, getSelectedFields());
}

final Coder<T> coder = inferCoder(p.getCoderRegistry());
@@ -1442,6 +1438,24 @@ void cleanup(PassThroughThenCleanup.ContextContainer c) throws Exception {
return rows;
}

private static Schema getFinalSchema(
Schema beamSchema, ValueProvider<List<String>> selectedFields) {
List<Schema.Field> flds =
beamSchema.getFields().stream()
.filter(
field -> {
if (selectedFields != null
&& selectedFields.isAccessible()
&& selectedFields.get() != null) {
return selectedFields.get().contains(field.getName());
} else {
return true;
}
})
.collect(Collectors.toList());
return Schema.builder().addFields(flds).build();
}
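
getFinalSchema keeps a top-level field only when the selected-fields provider is non-null, accessible, and contains the field's name; otherwise the full schema passes through. An illustrative sketch (the method is private to BigQueryIO and the field names are hypothetical; imports assumed: org.apache.beam.sdk.schemas.Schema, ValueProvider, java.util.Arrays):

  Schema beamSchema =
      Schema.builder()
          .addInt64Field("id")
          .addStringField("name")
          .addStringField("payload")
          .build();

  // Selecting a subset keeps only those fields...
  ValueProvider<List<String>> selected =
      ValueProvider.StaticValueProvider.of(Arrays.asList("id", "name"));
  Schema pruned = getFinalSchema(beamSchema, selected);  // fields: id, name

  // ...while a null provider (nothing selected) keeps the full schema.
  Schema full = getFinalSchema(beamSchema, null);        // fields: id, name, payload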

private PCollection<T> expandForDirectRead(
PBegin input, Coder<T> outputCoder, Schema beamSchema, BigQueryOptions bqOptions) {
ValueProvider<TableReference> tableProvider = getTableProvider();
@@ -178,7 +178,7 @@ public <T> BigQuerySourceBase<T> toSource(

/** {@inheritDoc} */
@Override
public TableSchema getTableSchema(BigQueryOptions bqOptions) {
public Schema getBeamSchema(BigQueryOptions bqOptions) {
try {
JobStatistics stats =
BigQueryQueryHelper.dryRunQueryIfNeeded(
@@ -189,20 +189,14 @@ public TableSchema getTableSchema(BigQueryOptions bqOptions) {
flattenResults,
useLegacySql,
location);
return stats.getQuery().getSchema();
TableSchema tableSchema = stats.getQuery().getSchema();
return BigQueryUtils.fromTableSchema(tableSchema);
} catch (IOException | InterruptedException | NullPointerException e) {
throw new BigQuerySchemaRetrievalException(
"Exception while trying to retrieve schema of query", e);
}
}

/** {@inheritDoc} */
@Override
public Schema getBeamSchema(BigQueryOptions bqOptions) {
TableSchema tableSchema = getTableSchema(bqOptions);
return BigQueryUtils.fromTableSchema(tableSchema);
}

ValueProvider<String> getQuery() {
return query;
}
@@ -45,15 +45,6 @@ <T> BigQuerySourceBase<T> toSource(
SerializableFunction<TableSchema, AvroSource.DatumReaderFactory<T>> readerFactory,
boolean useAvroLogicalTypes);

/**
* Extract the {@link TableSchema} corresponding to this source.
*
* @param bqOptions BigQueryOptions
* @return table schema of the source
* @throws BigQuerySchemaRetrievalException if schema retrieval fails
*/
TableSchema getTableSchema(BigQueryOptions bqOptions);

/**
* Extract the Beam {@link Schema} corresponding to this source.
*
@@ -28,7 +28,10 @@
import com.google.cloud.bigquery.storage.v1.ReadStream;
import java.io.IOException;
import java.util.List;
import org.apache.avro.Schema;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.arrow.ArrowConversion;
import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient;
import org.apache.beam.sdk.metrics.Lineage;
@@ -179,17 +182,30 @@ public List<BigQueryStorageStreamSource<T>> split(
LOG.info("Read session returned {} streams", readSession.getStreamsList().size());
}

// TODO: this is inconsistent with method above, where it can be null
Preconditions.checkStateNotNull(targetTable);
TableSchema tableSchema = targetTable.getSchema();
if (selectedFieldsProvider != null) {
tableSchema = BigQueryUtils.trimSchema(tableSchema, selectedFieldsProvider.get());
Schema sessionSchema;
if (readSession.getDataFormat() == DataFormat.ARROW) {
org.apache.arrow.vector.types.pojo.Schema schema =
ArrowConversion.arrowSchemaFromInput(
readSession.getArrowSchema().getSerializedSchema().newInput());
org.apache.beam.sdk.schemas.Schema beamSchema =
ArrowConversion.ArrowSchemaTranslator.toBeamSchema(schema);
sessionSchema = AvroUtils.toAvroSchema(beamSchema);
} else if (readSession.getDataFormat() == DataFormat.AVRO) {
sessionSchema = new Schema.Parser().parse(readSession.getAvroSchema().getSchema());
} else {
throw new IllegalArgumentException(
"data is not in a supported dataFormat: " + readSession.getDataFormat());
}

Preconditions.checkStateNotNull(
targetTable); // TODO: this is inconsistent with method above, where it can be null
TableSchema trimmedSchema =
BigQueryAvroUtils.trimBigQueryTableSchema(targetTable.getSchema(), sessionSchema);
List<BigQueryStorageStreamSource<T>> sources = Lists.newArrayList();
for (ReadStream readStream : readSession.getStreamsList()) {
sources.add(
BigQueryStorageStreamSource.create(
readSession, readStream, tableSchema, parseFn, outputCoder, bqServices));
readSession, readStream, trimmedSchema, parseFn, outputCoder, bqServices));
}

return ImmutableList.copyOf(sources);
@@ -102,22 +102,16 @@ public <T> BigQuerySourceBase<T> toSource(

/** {@inheritDoc} */
@Override
public TableSchema getTableSchema(BigQueryOptions bqOptions) {
public Schema getBeamSchema(BigQueryOptions bqOptions) {
try {
try (DatasetService datasetService = bqServices.getDatasetService(bqOptions)) {
TableReference tableRef = getTableReference(bqOptions);
Table table = datasetService.getTable(tableRef);
return Preconditions.checkStateNotNull(table).getSchema();
TableSchema tableSchema = Preconditions.checkStateNotNull(table).getSchema();
return BigQueryUtils.fromTableSchema(tableSchema);
}
} catch (Exception e) {
throw new BigQuerySchemaRetrievalException("Exception while trying to retrieve schema", e);
}
}

/** {@inheritDoc} */
@Override
public Schema getBeamSchema(BigQueryOptions bqOptions) {
TableSchema tableSchema = getTableSchema(bqOptions);
return BigQueryUtils.fromTableSchema(tableSchema);
}
}
@@ -310,44 +310,45 @@ static StandardSQLTypeName toStandardSQLTypeName(FieldType fieldType) {
*
* <p>Supports both standard and legacy SQL types.
*
* @param typeName Name of the type
* @param typeName Name of the type returned by {@link TableFieldSchema#getType()}
* @param nestedFields Nested fields for the given type (eg. RECORD type)
* @return Corresponding Beam {@link FieldType}
*/
private static FieldType fromTableFieldSchemaType(
String typeName, List<TableFieldSchema> nestedFields, SchemaConversionOptions options) {
// see
// https://googleapis.dev/java/google-api-services-bigquery/latest/com/google/api/services/bigquery/model/TableFieldSchema.html#getType--
switch (typeName) {
case "STRING":
return FieldType.STRING;
case "BYTES":
return FieldType.BYTES;
case "INT64":
case "INT":
case "SMALLINT":
case "INTEGER":
case "BIGINT":
case "TINYINT":
case "BYTEINT":
case "INT64":
return FieldType.INT64;
case "FLOAT":
case "FLOAT64":
case "FLOAT": // even if not a valid BQ type, it is used in the schema
return FieldType.DOUBLE;
case "BOOL":
case "BOOLEAN":
case "BOOL":
return FieldType.BOOLEAN;
case "NUMERIC":
case "BIGNUMERIC":
return FieldType.DECIMAL;
case "TIMESTAMP":
return FieldType.DATETIME;
case "TIME":
return FieldType.logicalType(SqlTypes.TIME);
case "DATE":
return FieldType.logicalType(SqlTypes.DATE);
case "TIME":
return FieldType.logicalType(SqlTypes.TIME);
case "DATETIME":
return FieldType.logicalType(SqlTypes.DATETIME);
case "STRUCT":
case "NUMERIC":
case "BIGNUMERIC":
return FieldType.DECIMAL;
case "GEOGRAPHY":
case "JSON":
// TODO Add metadata for custom sql types ?
return FieldType.STRING;
case "RECORD":
case "STRUCT":
if (options.getInferMaps() && nestedFields.size() == 2) {
TableFieldSchema key = nestedFields.get(0);
TableFieldSchema value = nestedFields.get(1);
@@ -358,13 +359,9 @@ private static FieldType fromTableFieldSchemaType(
fromTableFieldSchemaType(value.getType(), value.getFields(), options));
}
}

Schema rowSchema = fromTableFieldSchema(nestedFields, options);
return FieldType.row(rowSchema);
case "GEOGRAPHY":
case "JSON":
// TODO Add metadata for custom sql types
return FieldType.STRING;
case "RANGE": // TODO add support for range type
default:
throw new UnsupportedOperationException(
"Converting BigQuery type " + typeName + " to Beam type is unsupported");
@@ -463,8 +460,8 @@ public static org.apache.avro.Schema toGenericAvroSchema(TableSchema tableSchema

/** Convert a BigQuery {@link TableSchema} to an Avro {@link org.apache.avro.Schema}. */
public static org.apache.avro.Schema toGenericAvroSchema(
TableSchema tableSchema, Boolean stringLogicalTypes) {
return toGenericAvroSchema("root", tableSchema.getFields(), stringLogicalTypes);
TableSchema tableSchema, Boolean useAvroLogicalTypes) {
return toGenericAvroSchema("root", tableSchema.getFields(), useAvroLogicalTypes);
}
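
After the rename the flag reads as what it does: with useAvroLogicalTypes set to true, date/time-like columns map to Avro logical types (e.g. DATE to int/date, TIME to long/time-micros), while false keeps the string encodings used by BigQuery Avro exports. A short usage sketch with made-up column names:

  TableSchema tableSchema =
      new TableSchema()
          .setFields(
              Arrays.asList(
                  new TableFieldSchema().setName("d").setType("DATE"),
                  new TableFieldSchema().setName("t").setType("TIME")));

  // DATE/TIME carry Avro logical types here...
  org.apache.avro.Schema withLogicalTypes = BigQueryUtils.toGenericAvroSchema(tableSchema, true);
  // ...and are plain strings, matching BigQuery exports, here.
  org.apache.avro.Schema asStrings = BigQueryUtils.toGenericAvroSchema(tableSchema, false);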

/** Convert a list of BigQuery {@link TableFieldSchema} to Avro {@link org.apache.avro.Schema}. */
@@ -475,8 +472,8 @@ public static org.apache.avro.Schema toGenericAvroSchema(

/** Convert a list of BigQuery {@link TableFieldSchema} to Avro {@link org.apache.avro.Schema}. */
public static org.apache.avro.Schema toGenericAvroSchema(
String schemaName, List<TableFieldSchema> fieldSchemas, Boolean stringLogicalTypes) {
return BigQueryAvroUtils.toGenericAvroSchema(schemaName, fieldSchemas, stringLogicalTypes);
String schemaName, List<TableFieldSchema> fieldSchemas, Boolean useAvroLogicalTypes) {
return BigQueryAvroUtils.toGenericAvroSchema(schemaName, fieldSchemas, useAvroLogicalTypes);
}

private static final BigQueryIO.TypedRead.ToBeamRowFunction<TableRow>
@@ -1077,21 +1074,6 @@ private static Object convertAvroNumeric(Object value) {
return tableSpec;
}

static TableSchema trimSchema(TableSchema schema, @Nullable List<String> selectedFields) {
if (selectedFields == null || selectedFields.isEmpty()) {
return schema;
}

List<TableFieldSchema> fields = schema.getFields();
List<TableFieldSchema> trimmedFields = new ArrayList<>();
for (TableFieldSchema field : fields) {
if (selectedFields.contains(field.getName())) {
trimmedFields.add(field);
}
}
return new TableSchema().setFields(trimmedFields);
}

private static @Nullable ServiceCallMetric callMetricForMethod(
@Nullable TableReference tableReference, String method) {
if (tableReference != null) {