Create and traverse avro schemas once per task (#296) #834

Open · wants to merge 2 commits into main
132 changes: 105 additions & 27 deletions api/src/main/scala/ai/chronon/api/Row.scala
@@ -40,6 +40,36 @@ trait Row {
}
}

/**
 * SchemaTraverser aids in the traversal of the given SchemaType.
 * In some cases (e.g. Avro), it is more performant to create the
 * top-level schema once and then traverse it top-to-bottom, rather
 * than recreating it at each node.
 *
 * This helper trait allows the Row.to function to traverse SchemaType
 * without leaking details of the SchemaType structure.
 */
trait SchemaTraverser[SchemaType] {

def currentNode: SchemaType

// Returns the equivalent SchemaType representation of the given field
def getField(field: StructField): SchemaTraverser[SchemaType]

// Returns the inner type of the current collection field type.
// Throws if the current type is not a collection.
def getCollectionType: SchemaTraverser[SchemaType]

// Returns the key type of the current map field type.
// Throws if the current type is not a map.
def getMapKeyType: SchemaTraverser[SchemaType]

// Returns the value type of the current map field type.
// Throws if the current type is not a map.
def getMapValueType: SchemaTraverser[SchemaType]

}
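
For illustration only, here is a minimal hypothetical SchemaTraverser over chronon's own DataType (not part of this change); it shows the contract each method is expected to satisfy, with every call returning a traverser positioned at the corresponding child node:

// Hypothetical sketch, not part of the PR: a traverser over ai.chronon.api.DataType.
case class DataTypeTraverser(currentNode: DataType) extends SchemaTraverser[DataType] {

  // The struct field already carries its own type, so simply descend into it.
  override def getField(field: StructField): SchemaTraverser[DataType] =
    copy(field.fieldType)

  override def getCollectionType: SchemaTraverser[DataType] = currentNode match {
    case ListType(elemType) => copy(elemType)
    case other              => throw new UnsupportedOperationException(s"$other is not a list")
  }

  override def getMapKeyType: SchemaTraverser[DataType] = currentNode match {
    case MapType(keyType, _) => copy(keyType)
    case other               => throw new UnsupportedOperationException(s"$other is not a map")
  }

  override def getMapValueType: SchemaTraverser[DataType] = currentNode match {
    case MapType(_, valueType) => copy(valueType)
    case other                 => throw new UnsupportedOperationException(s"$other is not a map")
  }
}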

object Row {
// recursively traverse a logical struct, and convert it to chronon's row type
def from[CompositeType, BinaryType, ArrayType, StringType](
@@ -95,49 +125,71 @@ object Row {
}

// recursively traverse a chronon dataType value, and convert it to an external type
def to[StructType, BinaryType, ListType, MapType](value: Any,
def to[StructType, BinaryType, ListType, MapType, OutputSchema](value: Any,
dataType: DataType,
composer: (Iterator[Any], DataType) => StructType,
composer: (Iterator[Any], DataType, Option[OutputSchema]) => StructType,
binarizer: Array[Byte] => BinaryType,
collector: (Iterator[Any], Int) => ListType,
mapper: (util.Map[Any, Any] => MapType),
extraneousRecord: Any => Array[Any] = null): Any = {
extraneousRecord: Any => Array[Any] = null,
schemaTraverser: Option[SchemaTraverser[OutputSchema]] = None
): Any = {

if (value == null) return null
def edit(value: Any, dataType: DataType): Any =
to(value, dataType, composer, binarizer, collector, mapper, extraneousRecord)

def getFieldSchema(f: StructField) = schemaTraverser.map(_.getField(f))

def edit(value: Any, dataType: DataType, subTreeTraverser: Option[SchemaTraverser[OutputSchema]]): Any =
to(value, dataType, composer, binarizer, collector, mapper, extraneousRecord, subTreeTraverser)

dataType match {
case StructType(_, fields) =>
value match {
case arr: Array[Any] =>
composer(arr.iterator.zipWithIndex.map { case (value, idx) => edit(value, fields(idx).fieldType) },
dataType)
composer(
arr.iterator.zipWithIndex.map {
case (value, idx) => edit(value, fields(idx).fieldType, getFieldSchema(fields(idx)))
},
dataType,
schemaTraverser.map(_.currentNode)
)
case list: util.ArrayList[Any] =>
composer(list
.iterator()
.asScala
.zipWithIndex
.map { case (value, idx) => edit(value, fields(idx).fieldType) },
dataType)
case list: List[Any] =>
composer(list.iterator.zipWithIndex
.map { case (value, idx) => edit(value, fields(idx).fieldType) },
dataType)
composer(
list
.iterator()
.asScala
.zipWithIndex
.map { case (value, idx) => edit(value, fields(idx).fieldType, getFieldSchema(fields(idx))) },
dataType,
schemaTraverser.map(_.currentNode)
)
case value: Any =>
assert(extraneousRecord != null, s"No handler for $value of class ${value.getClass}")
composer(extraneousRecord(value).iterator.zipWithIndex.map {
case (value, idx) => edit(value, fields(idx).fieldType)
},
dataType)
composer(
extraneousRecord(value).iterator.zipWithIndex.map {
case (value, idx) => edit(value, fields(idx).fieldType, getFieldSchema(fields(idx)))
},
dataType,
schemaTraverser.map(_.currentNode)
)
}
case ListType(elemType) =>
value match {
case list: util.ArrayList[Any] =>
collector(list.iterator().asScala.map(edit(_, elemType)), list.size())
case arr: Array[_] => // avro only recognizes arrayList for its ArrayType/ListType
collector(arr.iterator.map(edit(_, elemType)), arr.length)
collector(
list.iterator().asScala.map(edit(_, elemType, schemaTraverser.map(_.getCollectionType))),
list.size()
)
case arr: Array[_] => // avro only recognizes arrayList for its ArrayType/ListType
collector(
arr.iterator.map(edit(_, elemType, schemaTraverser.map(_.getCollectionType))),
arr.length
)
case arr: mutable.WrappedArray[Any] => // handles the wrapped array type from transform function in spark sql
collector(arr.iterator.map(edit(_, elemType)), arr.length)
collector(
arr.iterator.map(edit(_, elemType, schemaTraverser.map(_.getCollectionType))),
arr.length
)
}
case MapType(keyType, valueType) =>
value match {
@@ -147,12 +199,38 @@ object Row {
.entrySet()
.iterator()
.asScala
.foreach { entry => newMap.put(edit(entry.getKey, keyType), edit(entry.getValue, valueType)) }
.foreach {
entry => newMap.put(
edit(
entry.getKey,
keyType,
schemaTraverser.map(_.getMapKeyType)
),
edit(
entry.getValue,
valueType,
schemaTraverser.map(_.getMapValueType)
)
)
}
mapper(newMap)
case map: collection.immutable.Map[Any, Any] =>
val newMap = new util.HashMap[Any, Any](map.size)
map
.foreach { entry => newMap.put(edit(entry._1, keyType), edit(entry._2, valueType)) }
.foreach {
entry => newMap.put(
edit(
entry._1,
keyType,
schemaTraverser.map(_.getMapKeyType)
),
edit(
entry._2,
valueType,
schemaTraverser.map(_.getMapValueType)
)
)
}
mapper(newMap)
}
case BinaryType => binarizer(value.asInstanceOf[Array[Byte]])
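As a usage sketch (hypothetical caller code, assuming ai.chronon.api primitives such as StringType and DoubleType), the new signature stays source-compatible: extraneousRecord and schemaTraverser keep their defaults, and the composer simply receives an extra Option that it may ignore:

// Hypothetical caller, not part of the PR.
import scala.collection.JavaConverters._
import ai.chronon.api._

val userType = StructType("user",
  Array(StructField("id", StringType), StructField("score", DoubleType)))

val asMap = Row.to[Map[String, Any], Array[Byte], Seq[Any], Map[Any, Any], Unit](
  Array[Any]("u1", 0.5),
  userType,
  { (values: Iterator[Any], dt: DataType, _: Option[Unit]) =>
    // Composer ignores the schema hint and builds a name -> value map.
    dt.asInstanceOf[StructType].fields.map(_.name).zip(values.toSeq).toMap
  },
  { bytes: Array[Byte] => bytes },
  { (elems: Iterator[Any], size: Int) => elems.toSeq },
  { m: java.util.Map[Any, Any] => m.asScala.toMap }
) // extraneousRecord and schemaTraverser are left at their defaults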
52 changes: 44 additions & 8 deletions online/src/main/scala/ai/chronon/online/AvroConversions.scala
@@ -111,14 +111,14 @@ object AvroConversions {
s"Cannot convert chronon type $dataType to avro type. Cast it to string please")
}
}

def fromChrononRow(value: Any, dataType: DataType, extraneousRecord: Any => Array[Any] = null): Any = {
def fromChrononRow(value: Any, dataType: DataType, topLevelSchema: Schema, extraneousRecord: Any => Array[Any] = null): Any = {
// But this also has to happen at the recursive depth - data type and schema inside the compositor need to
Row.to[GenericRecord, ByteBuffer, util.ArrayList[Any], util.Map[Any, Any]](
Row.to[GenericRecord, ByteBuffer, util.ArrayList[Any], util.Map[Any, Any], Schema](
value,
dataType,
{ (data: Iterator[Any], elemDataType: DataType) =>
val schema = AvroConversions.fromChrononSchema(elemDataType)
{ (data: Iterator[Any], elemDataType: DataType, providedSchema: Option[Schema]) =>
val schema = providedSchema.getOrElse(AvroConversions.fromChrononSchema(elemDataType))
val record = new GenericData.Record(schema)
data.zipWithIndex.foreach {
case (value1, idx) => record.put(idx, value1)
@@ -132,7 +132,8 @@ object AvroConversions {
result
},
{ m: util.Map[Any, Any] => m },
extraneousRecord
extraneousRecord,
Some(AvroSchemaTraverser(topLevelSchema))
)
}

@@ -167,7 +168,7 @@ object AvroConversions {
def encodeBytes(schema: StructType, extraneousRecord: Any => Array[Any] = null): Any => Array[Byte] = {
val codec: AvroCodec = new AvroCodec(fromChrononSchema(schema).toString(true));
{ data: Any =>
val record = fromChrononRow(data, codec.chrononSchema, extraneousRecord).asInstanceOf[GenericData.Record]
val record = fromChrononRow(data, codec.chrononSchema, codec.schema, extraneousRecord).asInstanceOf[GenericData.Record]
val bytes = codec.encodeBinary(record)
bytes
}
@@ -176,9 +177,44 @@ object AvroConversions {
def encodeJson(schema: StructType, extraneousRecord: Any => Array[Any] = null): Any => String = {
val codec: AvroCodec = new AvroCodec(fromChrononSchema(schema).toString(true));
{ data: Any =>
val record = fromChrononRow(data, codec.chrononSchema, extraneousRecord).asInstanceOf[GenericData.Record]
val record = fromChrononRow(data, codec.chrononSchema, codec.schema, extraneousRecord).asInstanceOf[GenericData.Record]
val json = codec.encodeJson(record)
json
}
}
}

case class AvroSchemaTraverser(currentNode: Schema) extends SchemaTraverser[Schema] {

// We only use union types for nullable fields, and always
// unbox them when writing the actual schema out.
private def unboxUnion(maybeUnion: Schema): Schema =
if (maybeUnion.getType == Schema.Type.UNION) {
maybeUnion.getTypes.get(1)
} else {
maybeUnion
}

override def getField(field: StructField): SchemaTraverser[Schema] = copy(
unboxUnion(currentNode.getField(field.name).schema())
)

override def getCollectionType: SchemaTraverser[Schema] = copy(
unboxUnion(currentNode.getElementType)
)

// Avro map keys are always strings.
override def getMapKeyType: SchemaTraverser[Schema] = if (currentNode.getType == Schema.Type.MAP) {
copy(
Schema.create(Schema.Type.STRING)
)
} else {
throw new UnsupportedOperationException(
s"Current node ${currentNode.getName} is a ${currentNode.getType}, not a ${Schema.Type.MAP}"
)
}

override def getMapValueType: SchemaTraverser[Schema] = copy(
unboxUnion(currentNode.getValueType)
)
}
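
A quick usage sketch of the pattern this change enables (hypothetical driver code; assumes org.apache.avro.Schema, GenericData, and the ai.chronon.api types are in scope): the Avro schema is derived from the chronon schema once per task and then reused for every row, with AvroSchemaTraverser walking the pre-built schema instead of rebuilding it at each node:

// Build the Avro schema once...
val chrononSchema = StructType("event",
  Array(StructField("user_id", StringType), StructField("amount", DoubleType)))
val avroSchema: Schema = AvroConversions.fromChrononSchema(chrononSchema)

// ...then convert many rows against it.
val rows: Seq[Array[Any]] = Seq(Array[Any]("u1", 1.0), Array[Any]("u2", 2.5))
val records = rows.map { row =>
  AvroConversions.fromChrononRow(row, chrononSchema, avroSchema).asInstanceOf[GenericData.Record]
}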
2 changes: 1 addition & 1 deletion online/src/main/scala/ai/chronon/online/Fetcher.scala
@@ -291,7 +291,7 @@ class Fetcher(val kvStore: KVStore,
elem
}
}
val avroRecord = AvroConversions.fromChrononRow(data, schema).asInstanceOf[GenericRecord]
val avroRecord = AvroConversions.fromChrononRow(data, schema, codec.schema).asInstanceOf[GenericRecord]
codec.encodeBinary(avroRecord)
}

@@ -137,10 +137,10 @@ object SparkConversions {
})

def toSparkRow(value: Any, dataType: api.DataType, extraneousRecord: Any => Array[Any] = null): Any = {
api.Row.to[GenericRow, Array[Byte], Array[Any], mutable.Map[Any, Any]](
api.Row.to[GenericRow, Array[Byte], Array[Any], mutable.Map[Any, Any], StructType](
value,
dataType,
{ (data: Iterator[Any], _) => new GenericRow(data.toArray) },
{ (data: Iterator[Any], _, _) => new GenericRow(data.toArray) },
{ bytes: Array[Byte] => bytes },
{ (elems: Iterator[Any], size: Int) =>
val result = new Array[Any](size)
@@ -91,7 +91,7 @@ class InMemoryStream {
input.addData(inputDf.collect.map { row: Row =>
val bytes =
encodeRecord(avroSchema)(
AvroConversions.fromChrononRow(row, schema, GenericRowHandler.func).asInstanceOf[GenericData.Record])
AvroConversions.fromChrononRow(row, schema, avroSchema, GenericRowHandler.func).asInstanceOf[GenericData.Record])
bytes
})
input.toDF