diff --git a/api/src/main/scala/ai/chronon/api/Row.scala b/api/src/main/scala/ai/chronon/api/Row.scala
index 73a69d860..86d24142e 100644
--- a/api/src/main/scala/ai/chronon/api/Row.scala
+++ b/api/src/main/scala/ai/chronon/api/Row.scala
@@ -40,6 +40,36 @@ trait Row {
   }
 }
 
+/**
+ * SchemaTraverser aids in the traversal of the given SchemaType.
+ * In some cases (eg avro), it is more performant to create the
+ * top-level schema once and then traverse it top-to-bottom, rather
+ * than recreating at each node.
+ *
+ * This helper trait allows the Row.to function to traverse SchemaType
+ * without leaking details of the SchemaType structure.
+ */
+trait SchemaTraverser[SchemaType] {
+
+  def currentNode: SchemaType
+
+  // Returns the equivalent SchemaType representation of the given field
+  def getField(field: StructField): SchemaTraverser[SchemaType]
+
+  // Returns the inner type of the current collection field type.
+  // Throws if the current type is not a collection.
+  def getCollectionType: SchemaTraverser[SchemaType]
+
+  // Returns the key type of the current map field type.
+  // Throws if the current type is not a map.
+  def getMapKeyType: SchemaTraverser[SchemaType]
+
+  // Returns the value type of the current map field type.
+  // Throws if the current type is not a map.
+  def getMapValueType: SchemaTraverser[SchemaType]
+
+}
+
 object Row {
   // recursively traverse a logical struct, and convert it chronon's row type
   def from[CompositeType, BinaryType, ArrayType, StringType](
@@ -95,49 +125,71 @@ object Row {
   }
 
   // recursively traverse a chronon dataType value, and convert it to an external type
-  def to[StructType, BinaryType, ListType, MapType](value: Any,
+  def to[StructType, BinaryType, ListType, MapType, OutputSchema](value: Any,
                                                     dataType: DataType,
-                                                    composer: (Iterator[Any], DataType) => StructType,
+                                                    composer: (Iterator[Any], DataType, Option[OutputSchema]) => StructType,
                                                     binarizer: Array[Byte] => BinaryType,
                                                     collector: (Iterator[Any], Int) => ListType,
                                                     mapper: (util.Map[Any, Any] => MapType),
-                                                    extraneousRecord: Any => Array[Any] = null): Any = {
+                                                    extraneousRecord: Any => Array[Any] = null,
+                                                    schemaTraverser: Option[SchemaTraverser[OutputSchema]] = None
+  ): Any = {
     if (value == null) return null
-    def edit(value: Any, dataType: DataType): Any =
-      to(value, dataType, composer, binarizer, collector, mapper, extraneousRecord)
+
+    def getFieldSchema(f: StructField) = schemaTraverser.map(_.getField(f))
+
+    def edit(value: Any, dataType: DataType, subTreeTraverser: Option[SchemaTraverser[OutputSchema]]): Any =
+      to(value, dataType, composer, binarizer, collector, mapper, extraneousRecord, subTreeTraverser)
+
     dataType match {
       case StructType(_, fields) =>
         value match {
           case arr: Array[Any] =>
-            composer(arr.iterator.zipWithIndex.map { case (value, idx) => edit(value, fields(idx).fieldType) },
-                     dataType)
+            composer(
+              arr.iterator.zipWithIndex.map {
+                case (value, idx) => edit(value, fields(idx).fieldType, getFieldSchema(fields(idx)))
+              },
+              dataType,
+              schemaTraverser.map(_.currentNode)
+            )
           case list: util.ArrayList[Any] =>
-            composer(list
-                       .iterator()
-                       .asScala
-                       .zipWithIndex
-                       .map { case (value, idx) => edit(value, fields(idx).fieldType) },
-                     dataType)
-          case list: List[Any] =>
-            composer(list.iterator.zipWithIndex
-                       .map { case (value, idx) => edit(value, fields(idx).fieldType) },
-                     dataType)
+            composer(
+              list
+                .iterator()
+                .asScala
+                .zipWithIndex
+                .map { case (value, idx) => edit(value, fields(idx).fieldType, getFieldSchema(fields(idx))) },
+              dataType,
+              schemaTraverser.map(_.currentNode)
+            )
           case value: Any =>
             assert(extraneousRecord != null, s"No handler for $value of class ${value.getClass}")
-            composer(extraneousRecord(value).iterator.zipWithIndex.map {
-                       case (value, idx) => edit(value, fields(idx).fieldType)
-                     },
-                     dataType)
+            composer(
+              extraneousRecord(value).iterator.zipWithIndex.map {
+                case (value, idx) => edit(value, fields(idx).fieldType, getFieldSchema(fields(idx)))
+              },
+              dataType,
+              schemaTraverser.map(_.currentNode)
+            )
         }
       case ListType(elemType) =>
         value match {
           case list: util.ArrayList[Any] =>
-            collector(list.iterator().asScala.map(edit(_, elemType)), list.size())
-          case arr: Array[_] => // avro only recognizes arrayList for its ArrayType/ListType
-            collector(arr.iterator.map(edit(_, elemType)), arr.length)
+            collector(
+              list.iterator().asScala.map(edit(_, elemType, schemaTraverser.map(_.getCollectionType))),
+              list.size()
+            )
+          case arr: Array[_] => // avro only recognizes arrayList for its ArrayType/ListType
+            collector(
+              arr.iterator.map(edit(_, elemType, schemaTraverser.map(_.getCollectionType))),
+              arr.length
+            )
           case arr: mutable.WrappedArray[Any] => // handles the wrapped array type from transform function in spark sql
-            collector(arr.iterator.map(edit(_, elemType)), arr.length)
+            collector(
+              arr.iterator.map(edit(_, elemType, schemaTraverser.map(_.getCollectionType))),
+              arr.length
+            )
         }
       case MapType(keyType, valueType) =>
         value match {
@@ -147,12 +199,38 @@ object Row {
               .entrySet()
               .iterator()
               .asScala
-              .foreach { entry => newMap.put(edit(entry.getKey, keyType), edit(entry.getValue, valueType)) }
+              .foreach {
+                entry => newMap.put(
+                  edit(
+                    entry.getKey,
+                    keyType,
+                    schemaTraverser.map(_.getMapKeyType)
+                  ),
+                  edit(
+                    entry.getValue,
+                    valueType,
+                    schemaTraverser.map(_.getMapValueType)
+                  )
+                )
+              }
             mapper(newMap)
           case map: collection.immutable.Map[Any, Any] =>
             val newMap = new util.HashMap[Any, Any](map.size)
             map
-              .foreach { entry => newMap.put(edit(entry._1, keyType), edit(entry._2, valueType)) }
+              .foreach {
+                entry => newMap.put(
+                  edit(
+                    entry._1,
+                    keyType,
+                    schemaTraverser.map(_.getMapKeyType)
+                  ),
+                  edit(
+                    entry._2,
+                    valueType,
+                    schemaTraverser.map(_.getMapValueType)
+                  )
+                )
+              }
             mapper(newMap)
         }
       case BinaryType => binarizer(value.asInstanceOf[Array[Byte]])
diff --git a/online/src/main/scala/ai/chronon/online/AvroConversions.scala b/online/src/main/scala/ai/chronon/online/AvroConversions.scala
index 8582e71ec..b65c29d3d 100644
--- a/online/src/main/scala/ai/chronon/online/AvroConversions.scala
+++ b/online/src/main/scala/ai/chronon/online/AvroConversions.scala
@@ -111,14 +111,14 @@ object AvroConversions {
           s"Cannot convert chronon type $dataType to avro type. Cast it to string please")
     }
   }
-
-  def fromChrononRow(value: Any, dataType: DataType, extraneousRecord: Any => Array[Any] = null): Any = {
+
+  def fromChrononRow(value: Any, dataType: DataType, topLevelSchema: Schema, extraneousRecord: Any => Array[Any] = null): Any = {
     // But this also has to happen at the recursive depth - data type and schema inside the compositor need to
-    Row.to[GenericRecord, ByteBuffer, util.ArrayList[Any], util.Map[Any, Any]](
+    Row.to[GenericRecord, ByteBuffer, util.ArrayList[Any], util.Map[Any, Any], Schema](
       value,
       dataType,
-      { (data: Iterator[Any], elemDataType: DataType) =>
-        val schema = AvroConversions.fromChrononSchema(elemDataType)
+      { (data: Iterator[Any], elemDataType: DataType, providedSchema: Option[Schema]) =>
+        val schema = providedSchema.getOrElse(AvroConversions.fromChrononSchema(elemDataType))
         val record = new GenericData.Record(schema)
         data.zipWithIndex.foreach {
           case (value1, idx) => record.put(idx, value1)
@@ -132,7 +132,8 @@ object AvroConversions {
         result
       },
       { m: util.Map[Any, Any] => m },
-      extraneousRecord
+      extraneousRecord,
+      Some(AvroSchemaTraverser(topLevelSchema))
     )
   }
 
@@ -167,7 +168,7 @@ object AvroConversions {
   def encodeBytes(schema: StructType, extraneousRecord: Any => Array[Any] = null): Any => Array[Byte] = {
     val codec: AvroCodec = new AvroCodec(fromChrononSchema(schema).toString(true));
     { data: Any =>
-      val record = fromChrononRow(data, codec.chrononSchema, extraneousRecord).asInstanceOf[GenericData.Record]
+      val record = fromChrononRow(data, codec.chrononSchema, codec.schema, extraneousRecord).asInstanceOf[GenericData.Record]
       val bytes = codec.encodeBinary(record)
       bytes
     }
@@ -176,9 +177,44 @@ object AvroConversions {
   def encodeJson(schema: StructType, extraneousRecord: Any => Array[Any] = null): Any => String = {
     val codec: AvroCodec = new AvroCodec(fromChrononSchema(schema).toString(true));
     { data: Any =>
-      val record = fromChrononRow(data, codec.chrononSchema, extraneousRecord).asInstanceOf[GenericData.Record]
+      val record = fromChrononRow(data, codec.chrononSchema, codec.schema, extraneousRecord).asInstanceOf[GenericData.Record]
       val json = codec.encodeJson(record)
       json
     }
   }
 }
+
+case class AvroSchemaTraverser(currentNode: Schema) extends SchemaTraverser[Schema] {
+
+  // We only use union types for nullable fields, and always
+  // unbox them when writing the actual schema out.
+  private def unboxUnion(maybeUnion: Schema): Schema =
+    if (maybeUnion.getType == Schema.Type.UNION) {
+      maybeUnion.getTypes.get(1)
+    } else {
+      maybeUnion
+    }
+
+  override def getField(field: StructField): SchemaTraverser[Schema] = copy(
+    unboxUnion(currentNode.getField(field.name).schema())
+  )
+
+  override def getCollectionType: SchemaTraverser[Schema] = copy(
+    unboxUnion(currentNode.getElementType)
+  )
+
+  // Avro map keys are always strings.
+  override def getMapKeyType: SchemaTraverser[Schema] = if (currentNode.getType == Schema.Type.MAP) {
+    copy(
+      Schema.create(Schema.Type.STRING)
+    )
+  } else {
+    throw new UnsupportedOperationException(
+      s"Current node ${currentNode.getName} is a ${currentNode.getType}, not a ${Schema.Type.MAP}"
+    )
+  }
+
+  override def getMapValueType: SchemaTraverser[Schema] = copy(
+    unboxUnion(currentNode.getValueType)
+  )
+}
diff --git a/online/src/main/scala/ai/chronon/online/Fetcher.scala b/online/src/main/scala/ai/chronon/online/Fetcher.scala
index cc4a5de7d..e86eaecbb 100644
--- a/online/src/main/scala/ai/chronon/online/Fetcher.scala
+++ b/online/src/main/scala/ai/chronon/online/Fetcher.scala
@@ -291,7 +291,7 @@ class Fetcher(val kvStore: KVStore,
         elem
       }
     }
-    val avroRecord = AvroConversions.fromChrononRow(data, schema).asInstanceOf[GenericRecord]
+    val avroRecord = AvroConversions.fromChrononRow(data, schema, codec.schema).asInstanceOf[GenericRecord]
     codec.encodeBinary(avroRecord)
   }
diff --git a/online/src/main/scala/ai/chronon/online/SparkConversions.scala b/online/src/main/scala/ai/chronon/online/SparkConversions.scala
index 7aa00657c..86fb91041 100644
--- a/online/src/main/scala/ai/chronon/online/SparkConversions.scala
+++ b/online/src/main/scala/ai/chronon/online/SparkConversions.scala
@@ -137,10 +137,10 @@ object SparkConversions {
     })
 
   def toSparkRow(value: Any, dataType: api.DataType, extraneousRecord: Any => Array[Any] = null): Any = {
-    api.Row.to[GenericRow, Array[Byte], Array[Any], mutable.Map[Any, Any]](
+    api.Row.to[GenericRow, Array[Byte], Array[Any], mutable.Map[Any, Any], StructType](
      value,
      dataType,
-      { (data: Iterator[Any], _) => new GenericRow(data.toArray) },
+      { (data: Iterator[Any], _, _) => new GenericRow(data.toArray) },
      { bytes: Array[Byte] => bytes },
      { (elems: Iterator[Any], size: Int) =>
        val result = new Array[Any](size)
diff --git a/spark/src/test/scala/ai/chronon/spark/test/InMemoryStream.scala b/spark/src/test/scala/ai/chronon/spark/test/InMemoryStream.scala
index 2dbb11619..b73fa41cc 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/InMemoryStream.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/InMemoryStream.scala
@@ -91,7 +91,7 @@ class InMemoryStream {
     input.addData(inputDf.collect.map { row: Row =>
       val bytes = encodeRecord(avroSchema)(
-        AvroConversions.fromChrononRow(row, schema, GenericRowHandler.func).asInstanceOf[GenericData.Record])
+        AvroConversions.fromChrononRow(row, schema, avroSchema, GenericRowHandler.func).asInstanceOf[GenericData.Record])
       bytes
     })
     input.toDF
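
Reviewer note (not part of the diff): the new topLevelSchema parameter lets a caller derive the Avro schema once and reuse it for every row, while AvroSchemaTraverser hands the matching sub-schema to the composer at each node instead of re-deriving it via fromChrononSchema. A minimal sketch of what a call site might look like, using a hypothetical two-field schema (the field names and values below are illustrative, not taken from this PR):

    import ai.chronon.api.{LongType, StringType, StructField, StructType}
    import ai.chronon.online.AvroConversions
    import org.apache.avro.generic.GenericData

    // Hypothetical event schema, for illustration only.
    val eventType = StructType(
      "user_event",
      Array(
        StructField("user_id", StringType),
        StructField("ts", LongType)
      )
    )

    // Derive the top-level Avro schema once...
    val avroSchema = AvroConversions.fromChrononSchema(eventType)

    // ...then reuse it for each row instead of recreating schemas per node.
    val record = AvroConversions
      .fromChrononRow(Array[Any]("user-123", 1700000000000L), eventType, avroSchema)
      .asInstanceOf[GenericData.Record]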