Skip to content

Commit

Permalink
added java bean class fallback support
Browse files Browse the repository at this point in the history
  • Loading branch information
Jolanrensen committed Mar 30, 2024
1 parent 68e830a commit 48db819
Showing 1 changed file with 88 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@

package org.jetbrains.kotlinx.spark.api

import org.apache.commons.lang3.reflect.TypeUtils.*
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.EncoderField
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.JavaBeanEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.types.DataType
Expand All @@ -49,19 +51,23 @@ import org.jetbrains.kotlinx.spark.api.plugin.annotations.ColumnName
import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify
import scala.reflect.ClassTag
import java.io.Serializable
import java.util.*
import javax.annotation.Nonnull
import kotlin.reflect.KClass
import kotlin.reflect.KMutableProperty
import kotlin.reflect.KProperty1
import kotlin.reflect.KType
import kotlin.reflect.KTypeProjection
import kotlin.reflect.full.createType
import kotlin.reflect.full.declaredMemberFunctions
import kotlin.reflect.full.declaredMemberProperties
import kotlin.reflect.full.hasAnnotation
import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.full.primaryConstructor
import kotlin.reflect.full.staticFunctions
import kotlin.reflect.full.withNullability
import kotlin.reflect.jvm.javaGetter
import kotlin.reflect.jvm.javaMethod
import kotlin.reflect.jvm.jvmName
import kotlin.reflect.typeOf
Expand Down Expand Up @@ -163,6 +169,7 @@ object KotlinTypeInference : Serializable {
*
* @return an [AgnosticEncoder] for the given [kType].
*/
@Suppress("UNCHECKED_CAST")
fun <T> encoderFor(kType: KType): AgnosticEncoder<T> =
encoderFor(
currentType = kType,
Expand Down Expand Up @@ -562,10 +569,6 @@ object KotlinTypeInference : Serializable {
}

kClass.isData -> {
// TODO provide warnings for non-Sparkify annotated classes
// TODO especially Pair and Triple, promote people to use Tuple2 and Tuple3 or use "getFirst" etc. as column name

if (currentType in seenTypeSet) throw IllegalStateException("Circular reference detected for type $currentType")
val constructor = kClass.primaryConstructor!!
val kParameters = constructor.parameters
// todo filter for transient?
Expand All @@ -586,7 +589,7 @@ object KotlinTypeInference : Serializable {
)

val paramName = param.name
val readMethodName = prop.getter.javaMethod!!.name
val readMethodName = prop.javaGetter!!.name
val writeMethodName = (prop as? KMutableProperty<*>)?.setter?.javaMethod?.name

EncoderField(
Expand Down Expand Up @@ -636,13 +639,87 @@ object KotlinTypeInference : Serializable {
}

// java bean class
// currentType.classifier is KClass<*> -> {
// TODO()
//
// JavaBeanEncoder()
// }
else -> {
if (currentType in seenTypeSet)
throw IllegalStateException("Circular reference detected for type $currentType")

val properties = getJavaBeanReadableProperties(kClass)
val fields = properties.map {
val encoder = encoderFor(
currentType = it.type,
seenTypeSet = seenTypeSet + currentType,
typeVariables = typeVariables,
)

EncoderField(
/* name = */ it.propName,
/* enc = */ encoder,
/* nullable = */ encoder.nullable() && !it.hasNonnull,
/* metadata = */ Metadata.empty(),
/* readMethod = */ it.getterName.toOption(),
/* writeMethod = */ it.setterName.toOption(),
)
}

JavaBeanEncoder<Any>(
ClassTag.apply(jClass),
fields.asScalaSeq(),
)
}

// else -> throw IllegalArgumentException("No encoder found for type $currentType")
}
}

else -> throw IllegalArgumentException("No encoder found for type $currentType")
private data class JavaReadableProperty(
val propName: String,
val getterName: String,
val setterName: String?,
val type: KType,
val hasNonnull: Boolean,
)

private fun getJavaBeanReadableProperties(klass: KClass<*>): List<JavaReadableProperty> {
val functions = klass.declaredMemberFunctions.filter {
it.name.startsWith("get") || it.name.startsWith("is") || it.name.startsWith("set")
}

val properties = functions.mapNotNull { getter ->
if (getter.name.startsWith("set")) return@mapNotNull null

val propName = getter.name
.removePrefix("get")
.removePrefix("is")
.replaceFirstChar { it.lowercase() }
val setter = functions.find {
it.name == "set${propName.replaceFirstChar { it.uppercase() }}"
}

JavaReadableProperty(
propName = propName,
getterName = getter.name,
setterName = setter?.name,
type = getter.returnType,
hasNonnull = getter.hasAnnotation<Nonnull>(),
)
}

// Aside from java get/set functions, attempt to get kotlin properties as well, for non data classes
val kotlinProps = klass.declaredMemberProperties
.filter { it.getter.javaMethod != null } // filter kotlin-facing props
.map {
val hasSetter = (it as? KMutableProperty<*>)?.setter != null
val nameSuffix = it.name.removePrefix("is").replaceFirstChar { it.uppercase() }

JavaReadableProperty(
propName = it.name,
getterName = if (it.name.startsWith("is")) it.name else "get$nameSuffix",
setterName = if (hasSetter) "set$nameSuffix" else null,
type = it.returnType,
hasNonnull = it.hasAnnotation<Nonnull>(),
)
}

return properties + kotlinProps
}
}

0 comments on commit 48db819

Please sign in to comment.