
Commit

fixing more tests. Can now remain at Kotlin 2.0 if we set -Xlambdas=class, which can be done with the Gradle plugin
Jolanrensen committed Mar 23, 2024
1 parent 7069a9a commit df021c0
Showing 12 changed files with 103 additions and 52 deletions.
3 changes: 1 addition & 2 deletions buildSrc/src/main/kotlin/Versions.kt
@@ -2,8 +2,7 @@ object Versions : Dsl<Versions> {
     const val project = "2.0.0-SNAPSHOT"
     const val kotlinSparkApiGradlePlugin = "2.0.0-SNAPSHOT"
     const val groupID = "org.jetbrains.kotlinx.spark"
-    // const val kotlin = "2.0.0-Beta5" // todo issues with NonSerializable lambdas
-    const val kotlin = "1.9.23"
+    const val kotlin = "2.0.0-Beta5"
     const val jvmTarget = "8"
     const val jupyterJvmTarget = "8"
     inline val spark get() = System.getProperty("spark") as String
@@ -20,6 +20,9 @@ class SparkKotlinCompilerGradlePlugin : KotlinCompilerPluginSupportPlugin {
             compilerOptions {
                 // Make sure the parameters of data classes are visible to scala
                 javaParameters.set(true)
+
+                // Avoid NotSerializableException by making lambdas serializable
+                freeCompilerArgs.add("-Xlambdas=class")
             }
         }
     }
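Note: for a project that does not apply this plugin, the same flag can be set directly in its own build script. A minimal Gradle Kotlin DSL sketch (the task wiring is our illustration, not part of this commit):

    import org.jetbrains.kotlin.gradle.tasks.KotlinCompile

    // Apply -Xlambdas=class to every Kotlin compilation so lambdas become
    // proper anonymous classes, which Java serialization can handle.
    tasks.withType<KotlinCompile>().configureEach {
        compilerOptions {
            freeCompilerArgs.add("-Xlambdas=class")
        }
    }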
1 change: 0 additions & 1 deletion kotlin-spark-api/build.gradle.kts
@@ -147,7 +147,6 @@ tasks.compileTestKotlin {
 kotlin {
     jvmToolchain {
         languageVersion = JavaLanguageVersion.of(Versions.jvmTarget)
-
     }
 }
@@ -46,6 +46,7 @@ import org.apache.spark.sql.types.UDTRegistration
 import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.unsafe.types.CalendarInterval
 import scala.reflect.ClassTag
+import java.io.Serializable
 import kotlin.reflect.KClass
 import kotlin.reflect.KMutableProperty
 import kotlin.reflect.KType
@@ -122,7 +123,10 @@ fun schemaFor(kType: KType): DataType = kotlinEncoderFor<Any?>(kType).schema().u…
 @Deprecated("Use schemaFor instead", ReplaceWith("schemaFor(kType)"))
 fun schema(kType: KType) = schemaFor(kType)

-object KotlinTypeInference {
+object KotlinTypeInference : Serializable {
+
+    // https://blog.stylingandroid.com/kotlin-serializable-objects/
+    private fun readResolve(): Any = KotlinTypeInference

     /**
      * @param kClass the class for which to infer the encoder.
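Note: readResolve is the standard JVM serialization hook for preserving singleton identity; without it, deserializing a Kotlin object yields a second instance. A self-contained sketch of the pattern (not from this commit):

    import java.io.*

    object Singleton : Serializable {
        // Hand the canonical instance back to the deserializer
        private fun readResolve(): Any = Singleton
    }

    fun main() {
        val bytes = ByteArrayOutputStream().also { bos ->
            ObjectOutputStream(bos).use { it.writeObject(Singleton) }
        }.toByteArray()
        val copy = ObjectInputStream(ByteArrayInputStream(bytes)).use { it.readObject() }
        println(copy === Singleton) // true; false without readResolve
    }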
@@ -20,7 +20,7 @@ inline fun <reified T : Number> JavaRDD<T>.toJavaDoubleRDD(): JavaDoubleRDD =

 /** Utility method to convert [JavaDoubleRDD] to [JavaRDD]<[Double]>. */
 @Suppress("UNCHECKED_CAST")
-fun JavaDoubleRDD.toDoubleRDD(): JavaRDD<Double> =
+inline fun JavaDoubleRDD.toDoubleRDD(): JavaRDD<Double> =
     JavaDoubleRDD.toRDD(this).toJavaRDD() as JavaRDD<Double>

 /** Add up the elements in this RDD. */
@@ -44,6 +44,7 @@ import org.apache.spark.streaming.Durations
 import org.apache.spark.streaming.api.java.JavaStreamingContext
 import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
 import org.jetbrains.kotlinx.spark.api.tuples.*
+import scala.reflect.ClassTag
 import java.io.Serializable

 /**
@@ -406,7 +407,7 @@ private fun getDefaultHadoopConf(): Configuration {
  * @return `Broadcast` object, a read-only variable cached on each machine
  */
 inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
-    sparkContext.broadcast(value, kotlinEncoderFor<T>().clsTag())
+    sparkContext.broadcast(value, ClassTag.apply(T::class.java))
 } catch (e: ClassNotFoundException) {
     JavaSparkContext(sparkContext).broadcast(value)
 }
@@ -426,7 +427,7 @@ inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
     DeprecationLevel.WARNING
 )
 inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = try {
-    broadcast(value, kotlinEncoderFor<T>().clsTag())
+    broadcast(value, ClassTag.apply(T::class.java))
 } catch (e: ClassNotFoundException) {
     JavaSparkContext(this).broadcast(value)
 }
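Note: the broadcast helpers now build the ClassTag directly from the reified type instead of constructing a full encoder only to read its class tag. A minimal sketch of the idiom; the helper name classTagOf is illustrative, not from this commit:

    import scala.reflect.ClassTag

    // Derive a Scala ClassTag from a reified Kotlin type parameter
    inline fun <reified T> classTagOf(): ClassTag<T> = ClassTag.apply(T::class.java)

    // Usage, assuming a live SparkContext `sc`:
    // val numbers = sc.broadcast(listOf(1, 2, 3), classTagOf<List<Int>>())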
@@ -69,9 +69,9 @@ class TypeOfUDFParameterNotSupportedException(kClass: KClass<*>, parameterName:
 )

 @JvmName("arrayColumnAsSeq")
-fun <DsType, T> TypedColumn<DsType, Array<T>>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
+inline fun <DsType, reified T> TypedColumn<DsType, Array<T>>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
 @JvmName("iterableColumnAsSeq")
-fun <DsType, T, I : Iterable<T>> TypedColumn<DsType, I>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
+inline fun <DsType, reified T, I : Iterable<T>> TypedColumn<DsType, I>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
 @JvmName("byteArrayColumnAsSeq")
 fun <DsType> TypedColumn<DsType, ByteArray>.asSeq(): TypedColumn<DsType, Seq<Byte>> = typed()
 @JvmName("charArrayColumnAsSeq")
@@ -41,6 +41,9 @@ import java.time.Period

 class EncodingTest : ShouldSpec({

+    @Sparkify
+    data class SparkifiedPair<T, U>(val first: T, val second: U)
+
     context("encoders") {
         withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {

@@ -134,8 +137,8 @@ class EncodingTest : ShouldSpec({
         }

         should("be able to serialize Date") {
-            val datePair = Date.valueOf("2020-02-10") to 5
-            val dataset: Dataset<Pair<Date, Int>> = dsOf(datePair)
+            val datePair = SparkifiedPair(Date.valueOf("2020-02-10"), 5)
+            val dataset: Dataset<SparkifiedPair<Date, Int>> = dsOf(datePair)
             dataset.collectAsList() shouldBe listOf(datePair)
         }

@@ -213,6 +216,8 @@ class EncodingTest : ShouldSpec({

     context("Give proper names to columns of data classes") {

+        infix fun <A, B> A.to(other: B) = SparkifiedPair(this, other)
+
         should("Be able to serialize pairs") {
             val pairs = listOf(
                 1 to "1",
@@ -653,25 +658,25 @@ class EncodingTest : ShouldSpec({
         }

         should("handle arrays of generics") {
-            data class Test<Z>(val id: Long, val data: Array<Pair<Z, Int>>)
+            data class Test<Z>(val id: Long, val data: Array<SparkifiedPair<Z, Int>>)

-            val result = listOf(Test(1, arrayOf(5.1 to 6, 6.1 to 7)))
+            val result = listOf(Test(1, arrayOf(SparkifiedPair(5.1, 6), SparkifiedPair(6.1, 7))))
                 .toDS()
                 .map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
                 .map { it.second }
                 .collectAsList()
-            expect(result).toContain.inOrder.only.values(5.1 to 6)
+            expect(result).toContain.inOrder.only.values(SparkifiedPair(5.1, 6))
         }

         should("handle lists of generics") {
-            data class Test<Z>(val id: Long, val data: List<Pair<Z, Int>>)
+            data class Test<Z>(val id: Long, val data: List<SparkifiedPair<Z, Int>>)

-            val result = listOf(Test(1, listOf(5.1 to 6, 6.1 to 7)))
+            val result = listOf(Test(1, listOf(SparkifiedPair(5.1, 6), SparkifiedPair(6.1, 7))))
                 .toDS()
                 .map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
                 .map { it.second }
                 .collectAsList()
-            expect(result).toContain.inOrder.only.values(5.1 to 6)
+            expect(result).toContain.inOrder.only.values(SparkifiedPair(5.1, 6))
         }

         should("handle boxed arrays") {
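Note: the locally declared infix `to` above shadows kotlin.to inside that context block, so existing `a to b` call sites keep compiling while now producing @Sparkify pairs instead of kotlin.Pair, which the Kotlin 2.0 encoder presumably no longer accepts. A standalone sketch of the shadowing mechanism, without the annotation:

    data class SparkifiedPair<T, U>(val first: T, val second: U)

    // Declared in the same file, this infix `to` wins over kotlin.to
    // from the default imports during overload resolution.
    infix fun <A, B> A.to(other: B) = SparkifiedPair(this, other)

    fun main() {
        val p = 1 to "one"
        println(p) // SparkifiedPair(first=1, second=one), not kotlin.Pair
    }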
@@ -6,11 +6,15 @@ import io.kotest.matchers.shouldBe
 import org.apache.spark.api.java.JavaRDD
 import org.jetbrains.kotlinx.spark.api.tuples.*
 import scala.Tuple2
+import java.io.Serializable

-class RddTest : ShouldSpec({
+class RddTest : Serializable, ShouldSpec({
     context("RDD extension functions") {

-        withSpark(logLevel = SparkLogLevel.DEBUG) {
+        withSpark(
+            props = mapOf("spark.sql.codegen.wholeStage" to false),
+            logLevel = SparkLogLevel.DEBUG,
+        ) {

             context("Key/value") {
                 should("work with spark example") {
@@ -39,6 +39,7 @@ import scala.Tuple2
 import java.io.File
 import java.io.Serializable
 import java.nio.charset.StandardCharsets
+import java.nio.file.Files
 import java.util.*
 import java.util.concurrent.atomic.AtomicBoolean

@@ -201,10 +202,10 @@ class StreamingTest : ShouldSpec({

 private val scalaCompatVersion = SCALA_COMPAT_VERSION
 private val sparkVersion = SPARK_VERSION
-private fun createTempDir() = File.createTempFile(
-    System.getProperty("java.io.tmpdir"),
-    "spark_${scalaCompatVersion}_${sparkVersion}"
-).apply { deleteOnExit() }
+private fun createTempDir() =
+    Files.createTempDirectory("spark_${scalaCompatVersion}_${sparkVersion}")
+        .toFile()
+        .also { it.deleteOnExit() }

 private fun checkpointFile(checkpointDir: String, checkpointTime: Time): Path {
@@ -215,7 +216,10 @@ private fun checkpointFile(checkpointDir: String, checkpointTime: Time): Path {
     return checkpointFileMethod.invoke(module, checkpointDir, checkpointTime) as Path
 }

-private fun getCheckpointFiles(checkpointDir: String, fs: scala.Option<FileSystem>): scala.collection.immutable.Seq<Path> {
+private fun getCheckpointFiles(
+    checkpointDir: String,
+    fs: scala.Option<FileSystem>
+): scala.collection.immutable.Seq<Path> {
     val klass = Class.forName("org.apache.spark.streaming.Checkpoint$")
     val moduleField = klass.getField("MODULE$").also { it.isAccessible = true }
     val module = moduleField.get(null)
@@ -227,7 +231,11 @@ private fun getCheckpointFiles(checkpointDir: String, fs: scala.Option<FileSyste…
 private fun createCorruptedCheckpoint(): String {
     val checkpointDirectory = createTempDir().absolutePath
     val fakeCheckpointFile = checkpointFile(checkpointDirectory, Time(1000))
-    FileUtils.write(File(fakeCheckpointFile.toString()), "spark_corrupt_${scalaCompatVersion}_${sparkVersion}", StandardCharsets.UTF_8)
+    FileUtils.write(
+        File(fakeCheckpointFile.toString()),
+        "spark_corrupt_${scalaCompatVersion}_${sparkVersion}",
+        StandardCharsets.UTF_8
+    )
     assert(getCheckpointFiles(checkpointDirectory, (null as FileSystem?).toOption()).nonEmpty())
     return checkpointDirectory
 }
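Note: the createTempDir rewrite also fixes a latent bug: File.createTempFile creates an empty file (and the old call passed the tmpdir path as the name prefix), while checkpointing needs a real directory. A minimal sketch of the corrected idiom:

    import java.io.File
    import java.nio.file.Files

    // Create an actual temporary directory and schedule it for removal on JVM exit
    fun tempDir(prefix: String): File =
        Files.createTempDirectory(prefix).toFile().also { it.deleteOnExit() }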
@@ -30,16 +30,22 @@ import org.jetbrains.kotlinx.spark.api.struct.model.ElementType.ComplexElement
 import org.jetbrains.kotlinx.spark.api.struct.model.ElementType.SimpleElement
 import org.jetbrains.kotlinx.spark.api.struct.model.Struct
 import org.jetbrains.kotlinx.spark.api.struct.model.StructField
 import kotlin.reflect.typeOf


 @OptIn(ExperimentalStdlibApi::class)
 class TypeInferenceTest : ShouldSpec({
+    @Sparkify
+    data class SparkifiedPair<T, U>(val first: T, val second: U)
+
+    @Sparkify
+    data class SparkifiedTriple<T, U, V>(val first: T, val second: U, val third: V)
+
     context("org.jetbrains.spark.api.org.jetbrains.spark.api.schema") {
-        @Sparkify data class Test2<T>(val vala2: T, val para2: Pair<T, String>)
-        @Sparkify data class Test<T>(val vala: T, val tripl1: Triple<T, Test2<Long>, T>)
+        @Sparkify
+        data class Test2<T>(val vala2: T, val para2: SparkifiedPair<T, String>)
+        @Sparkify
+        data class Test<T>(val vala: T, val tripl1: SparkifiedTriple<T, Test2<Long>, T>)

-        val struct = Struct.fromJson(schemaFor<Pair<String, Test<Int>>>().prettyJson())!!
+        val struct = Struct.fromJson(schemaFor<SparkifiedPair<String, Test<Int>>>().prettyJson())!!
         should("contain correct typings") {
             expect(struct.fields).notToEqualNull().toContain.inAnyOrder.only.entries(
                 hasField("first", "string"),
@@ -65,12 +71,15 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("org.jetbrains.spark.api.org.jetbrains.spark.api.schema with more complex data") {
-        @Sparkify data class Single<T>(val vala3: T)
-        @Sparkify
-        data class Test2<T>(val vala2: T, val para2: Pair<T, Single<Double>>)
-        @Sparkify data class Test<T>(val vala: T, val tripl1: Triple<T, Test2<Long>, T>)
+        @Sparkify
+        data class Single<T>(val vala3: T)
+
+        @Sparkify
+        data class Test2<T>(val vala2: T, val para2: SparkifiedPair<T, Single<Double>>)
+        @Sparkify
+        data class Test<T>(val vala: T, val tripl1: SparkifiedTriple<T, Test2<Long>, T>)

-        val struct = Struct.fromJson(schemaFor<Pair<String, Test<Int>>>().prettyJson())!!
+        val struct = Struct.fromJson(schemaFor<SparkifiedPair<String, Test<Int>>>().prettyJson())!!
         should("contain correct typings") {
             expect(struct.fields).notToEqualNull().toContain.inAnyOrder.only.entries(
                 hasField("first", "string"),
@@ -99,7 +108,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("org.jetbrains.spark.api.org.jetbrains.spark.api.schema without generics") {
-        data class Test(val a: String, val b: Int, val c: Double)
+        @Sparkify data class Test(val a: String, val b: Int, val c: Double)

         val struct = Struct.fromJson(schemaFor<Test>().prettyJson())!!
         should("return correct types too") {
@@ -120,7 +129,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("type with list of Pairs int to long") {
-        val struct = Struct.fromJson(schemaFor<List<Pair<Int, Long>>>().prettyJson())!!
+        val struct = Struct.fromJson(schemaFor<List<SparkifiedPair<Int, Long>>>().prettyJson())!!
         should("return correct types too") {
             expect(struct) {
                 isOfType("array")
@@ -134,7 +143,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("type with list of generic data class with E generic name") {
-        data class Test<E>(val e: E)
+        @Sparkify data class Test<E>(val e: E)

         val struct = Struct.fromJson(schemaFor<List<Test<String>>>().prettyJson())!!
         should("return correct types too") {
@@ -180,7 +189,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("data class with props in order lon → lat") {
-        data class Test(val lon: Double, val lat: Double)
+        @Sparkify data class Test(val lon: Double, val lat: Double)

         val struct = Struct.fromJson(schemaFor<Test>().prettyJson())!!
         should("Not change order of fields") {
@@ -191,7 +200,7 @@ class TypeInferenceTest : ShouldSpec({
         }
     }
     context("data class with nullable list inside") {
-        data class Sample(val optionList: List<Int>?)
+        @Sparkify data class Sample(val optionList: List<Int>?)

         val struct = Struct.fromJson(schemaFor<Sample>().prettyJson())!!
@@ -223,8 +232,8 @@ class TypeInferenceTest : ShouldSpec({
                 .feature("element name", { name() }) { toEqual("optionList") }
                 .feature("field type", { dataType() }, {
                     this
-                        .isA<ArrayType>()
-                        .feature("element type", { elementType() }) { isA<IntegerType>() }
+                        .toBeAnInstanceOf<ArrayType>()
+                        .feature("element type", { elementType() }) { toBeAnInstanceOf<IntegerType>() }
                         .feature("element nullable", { containsNull() }) { toEqual(expected = false) }
                 })
                 .feature("optionList nullable", { nullable() }) { toEqual(true) }
@@ -258,5 +267,5 @@ private fun hasStruct(

 private fun hasField(name: String, type: String): Expect<StructField>.() -> Unit = {
     feature { f(it::name) }.toEqual(name)
-    feature { f(it::type) }.isA<TypeName>().feature { f(it::value) }.toEqual(type)
+    feature { f(it::type) }.toBeAnInstanceOf<TypeName>().feature { f(it::value) }.toEqual(type)
 }