From df021c0d0ec8cfd697bc183b47f541895f0b097c Mon Sep 17 00:00:00 2001
From: Jolan Rensen
Date: Sat, 23 Mar 2024 15:19:02 +0100
Subject: [PATCH] Fix more tests. We can now remain on Kotlin 2.0 if we set
 -Xlambdas=class, which can be done via the Gradle plugin.

---
 buildSrc/src/main/kotlin/Versions.kt          |  3 +-
 .../SparkKotlinCompilerGradlePlugin.kt        |  3 ++
 kotlin-spark-api/build.gradle.kts             |  1 -
 .../jetbrains/kotlinx/spark/api/Encoding.kt   |  6 ++-
 .../jetbrains/kotlinx/spark/api/RddDouble.kt  |  2 +-
 .../kotlinx/spark/api/SparkSession.kt         |  5 ++-
 .../kotlinx/spark/api/UserDefinedFunction.kt  |  4 +-
 .../kotlinx/spark/api/EncodingTest.kt         | 21 +++++----
 .../jetbrains/kotlinx/spark/api/RddTest.kt    |  8 +++-
 .../kotlinx/spark/api/StreamingTest.kt        | 20 ++++++---
 .../kotlinx/spark/api/TypeInferenceTest.kt    | 43 +++++++++++--------
 .../jetbrains/kotlinx/spark/api/UDFTest.kt    | 39 ++++++++++++-----
 12 files changed, 103 insertions(+), 52 deletions(-)

diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt
index 72b315f7..59eab276 100644
--- a/buildSrc/src/main/kotlin/Versions.kt
+++ b/buildSrc/src/main/kotlin/Versions.kt
@@ -2,8 +2,7 @@ object Versions : Dsl<Versions> {
     const val project = "2.0.0-SNAPSHOT"
     const val kotlinSparkApiGradlePlugin = "2.0.0-SNAPSHOT"
     const val groupID = "org.jetbrains.kotlinx.spark"
-//    const val kotlin = "2.0.0-Beta5" // todo issues with NonSerializable lambdas
-    const val kotlin = "1.9.23"
+    const val kotlin = "2.0.0-Beta5"
     const val jvmTarget = "8"
     const val jupyterJvmTarget = "8"
     inline val spark get() = System.getProperty("spark") as String
diff --git a/gradle-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/gradlePlugin/SparkKotlinCompilerGradlePlugin.kt b/gradle-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/gradlePlugin/SparkKotlinCompilerGradlePlugin.kt
index 9dffe65a..23b83c41 100644
--- a/gradle-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/gradlePlugin/SparkKotlinCompilerGradlePlugin.kt
+++ b/gradle-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/gradlePlugin/SparkKotlinCompilerGradlePlugin.kt
@@ -20,6 +20,9 @@ class SparkKotlinCompilerGradlePlugin : KotlinCompilerPluginSupportPlugin {
             compilerOptions {
                 // Make sure the parameters of data classes are visible to scala
                 javaParameters.set(true)
+
+                // Avoid NotSerializableException by making lambdas serializable
+                freeCompilerArgs.add("-Xlambdas=class")
             }
         }
     }
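Note: the plugin change above is the heart of this patch. For a project that does not apply the Spark Kotlin compiler Gradle plugin, the same two compiler options can be set by hand. The snippet below is a minimal sketch using the standard Kotlin JVM Gradle DSL; it is not code taken from this patch.

    // build.gradle.kts (illustrative only; the plugin above does this automatically)
    import org.jetbrains.kotlin.gradle.tasks.KotlinCompile

    tasks.withType<KotlinCompile>().configureEach {
        compilerOptions {
            // Keep constructor parameter names so Scala/Spark can see them
            javaParameters.set(true)
            // Compile lambdas to classes so they can implement java.io.Serializable
            freeCompilerArgs.add("-Xlambdas=class")
        }
    }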
diff --git a/kotlin-spark-api/build.gradle.kts b/kotlin-spark-api/build.gradle.kts
index f1fc85ba..812af551 100644
--- a/kotlin-spark-api/build.gradle.kts
+++ b/kotlin-spark-api/build.gradle.kts
@@ -147,7 +147,6 @@ tasks.compileTestKotlin {
 kotlin {
     jvmToolchain {
         languageVersion = JavaLanguageVersion.of(Versions.jvmTarget)
-
     }
 }
diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt
index c2b9f972..492e79db 100644
--- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt
@@ -46,6 +46,7 @@ import org.apache.spark.sql.types.UDTRegistration
 import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.unsafe.types.CalendarInterval
 import scala.reflect.ClassTag
+import java.io.Serializable
 import kotlin.reflect.KClass
 import kotlin.reflect.KMutableProperty
 import kotlin.reflect.KType
@@ -122,7 +123,10 @@ fun schemaFor(kType: KType): DataType = kotlinEncoderFor<Any?>(kType).schema().unwrap()
 @Deprecated("Use schemaFor instead", ReplaceWith("schemaFor(kType)"))
 fun schema(kType: KType) = schemaFor(kType)
 
-object KotlinTypeInference {
+object KotlinTypeInference : Serializable {
+
+    // https://blog.stylingandroid.com/kotlin-serializable-objects/
+    private fun readResolve(): Any = KotlinTypeInference
 
     /**
      * @param kClass the class for which to infer the encoder.
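Note: KotlinTypeInference presumably ends up captured by the now-serializable lambdas, hence the Serializable supertype; the readResolve (see the linked blog post) keeps the object a singleton across a serialization round-trip. A self-contained sketch of the same pattern, using a hypothetical Registry object:

    import java.io.*

    object Registry : Serializable {
        // Java serialization calls this when reading the object back;
        // returning the canonical instance preserves the singleton invariant.
        private fun readResolve(): Any = Registry
    }

    fun main() {
        val bytes = ByteArrayOutputStream()
            .also { bos -> ObjectOutputStream(bos).use { it.writeObject(Registry) } }
            .toByteArray()
        val copy = ObjectInputStream(ByteArrayInputStream(bytes)).use { it.readObject() }
        println(copy === Registry) // true, but only because of readResolve
    }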
diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/RddDouble.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/RddDouble.kt
index 3ba3ab72..6bc28203 100644
--- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/RddDouble.kt
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/RddDouble.kt
@@ -20,7 +20,7 @@ inline fun JavaRDD<Double>.toJavaDoubleRDD(): JavaDoubleRDD =
 
 /** Utility method to convert [JavaDoubleRDD] to [JavaRDD]<[Double]>. */
 @Suppress("UNCHECKED_CAST")
-fun JavaDoubleRDD.toDoubleRDD(): JavaRDD<Double> =
+inline fun JavaDoubleRDD.toDoubleRDD(): JavaRDD<Double> =
     JavaDoubleRDD.toRDD(this).toJavaRDD() as JavaRDD<Double>
 
 /** Add up the elements in this RDD. */
diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt
index dde819a8..393f945f 100644
--- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt
@@ -44,6 +44,7 @@ import org.apache.spark.streaming.Durations
 import org.apache.spark.streaming.api.java.JavaStreamingContext
 import org.jetbrains.kotlinx.spark.api.SparkLogLevel.ERROR
 import org.jetbrains.kotlinx.spark.api.tuples.*
+import scala.reflect.ClassTag
 import java.io.Serializable
 
 /**
@@ -406,7 +407,7 @@ private fun getDefaultHadoopConf(): Configuration {
  * @return `Broadcast` object, a read-only variable cached on each machine
  */
 inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
-    sparkContext.broadcast(value, kotlinEncoderFor<T>().clsTag())
+    sparkContext.broadcast(value, ClassTag.apply(T::class.java))
 } catch (e: ClassNotFoundException) {
     JavaSparkContext(sparkContext).broadcast(value)
 }
@@ -426,7 +427,7 @@ inline fun <reified T> SparkSession.broadcast(value: T): Broadcast<T> = try {
     DeprecationLevel.WARNING
 )
 inline fun <reified T> SparkContext.broadcast(value: T): Broadcast<T> = try {
-    broadcast(value, kotlinEncoderFor<T>().clsTag())
+    broadcast(value, ClassTag.apply(T::class.java))
 } catch (e: ClassNotFoundException) {
     JavaSparkContext(this).broadcast(value)
 }
\ No newline at end of file
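Note: building the ClassTag directly from the reified class means broadcast no longer needs an encoder for T, only that T is serializable. A usage sketch (the data values are made up; withSpark and dsOf are the API's own helpers):

    import org.apache.spark.broadcast.Broadcast
    import org.jetbrains.kotlinx.spark.api.*

    fun main() = withSpark {
        val lookup: Broadcast<Map<String, Int>> = spark.broadcast(mapOf("a" to 1, "b" to 2))
        val counts = dsOf("a", "b", "b")
            .map { lookup.value[it] ?: 0 } // Int has a built-in encoder
            .collectAsList()
        println(counts) // [1, 2, 2]
    }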
diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UserDefinedFunction.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UserDefinedFunction.kt
index 5c1a3071..60e8f7c8 100644
--- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UserDefinedFunction.kt
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/UserDefinedFunction.kt
@@ -69,9 +69,9 @@ class TypeOfUDFParameterNotSupportedException(kClass: KClass<*>, parameterName:
 )
 
 @JvmName("arrayColumnAsSeq")
-fun <DsType, T> TypedColumn<DsType, Array<T>>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
+inline fun <DsType, T> TypedColumn<DsType, Array<T>>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
 @JvmName("iterableColumnAsSeq")
-fun <DsType, T, I : Iterable<T>> TypedColumn<DsType, I>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
+inline fun <DsType, T, I : Iterable<T>> TypedColumn<DsType, I>.asSeq(): TypedColumn<DsType, Seq<T>> = typed()
 @JvmName("byteArrayColumnAsSeq")
 fun <DsType> TypedColumn<DsType, ByteArray>.asSeq(): TypedColumn<DsType, Seq<Byte>> = typed()
 @JvmName("charArrayColumnAsSeq")
diff --git a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt
index 2fc9b791..295faa19 100644
--- a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt
+++ b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt
@@ -41,6 +41,9 @@ import java.time.Period
 
 class EncodingTest : ShouldSpec({
 
+    @Sparkify
+    data class SparkifiedPair<T, U>(val first: T, val second: U)
+
     context("encoders") {
         withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
 
@@ -134,8 +137,8 @@ class EncodingTest : ShouldSpec({
             }
 
             should("be able to serialize Date") {
-                val datePair = Date.valueOf("2020-02-10") to 5
-                val dataset: Dataset<Pair<Date, Int>> = dsOf(datePair)
+                val datePair = SparkifiedPair(Date.valueOf("2020-02-10"), 5)
+                val dataset: Dataset<SparkifiedPair<Date, Int>> = dsOf(datePair)
                 dataset.collectAsList() shouldBe listOf(datePair)
             }
 
@@ -213,6 +216,8 @@ class EncodingTest : ShouldSpec({
 
     context("Give proper names to columns of data classes") {
 
+        infix fun <A, B> A.to(other: B) = SparkifiedPair(this, other)
+
         should("Be able to serialize pairs") {
             val pairs = listOf(
                 1 to "1",
@@ -653,25 +658,25 @@ class EncodingTest : ShouldSpec({
             }
 
             should("handle arrays of generics") {
-                data class Test<T>(val id: Long, val data: Array<Pair<T, Int>>)
+                data class Test<T>(val id: Long, val data: Array<SparkifiedPair<T, Int>>)
 
-                val result = listOf(Test(1, arrayOf(5.1 to 6, 6.1 to 7)))
+                val result = listOf(Test(1, arrayOf(SparkifiedPair(5.1, 6), SparkifiedPair(6.1, 7))))
                     .toDS()
                     .map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
                     .map { it.second }
                     .collectAsList()
-                expect(result).toContain.inOrder.only.values(5.1 to 6)
+                expect(result).toContain.inOrder.only.values(SparkifiedPair(5.1, 6))
             }
 
             should("handle lists of generics") {
-                data class Test<T>(val id: Long, val data: List<Pair<T, Int>>)
+                data class Test<T>(val id: Long, val data: List<SparkifiedPair<T, Int>>)
 
-                val result = listOf(Test(1, listOf(5.1 to 6, 6.1 to 7)))
+                val result = listOf(Test(1, listOf(SparkifiedPair(5.1, 6), SparkifiedPair(6.1, 7))))
                     .toDS()
                     .map { it.id to it.data.firstOrNull { liEl -> liEl.first < 6 } }
                     .map { it.second }
                     .collectAsList()
-                expect(result).toContain.inOrder.only.values(5.1 to 6)
+                expect(result).toContain.inOrder.only.values(SparkifiedPair(5.1, 6))
             }
 
             should("handle boxed arrays") {
diff --git a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/RddTest.kt b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/RddTest.kt
index 5f9b6d94..e3a9b87e 100644
--- a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/RddTest.kt
+++ b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/RddTest.kt
@@ -6,11 +6,15 @@ import io.kotest.matchers.shouldBe
 import org.apache.spark.api.java.JavaRDD
 import org.jetbrains.kotlinx.spark.api.tuples.*
 import scala.Tuple2
+import java.io.Serializable
 
-class RddTest : ShouldSpec({
+class RddTest : Serializable, ShouldSpec({
     context("RDD extension functions") {
 
-        withSpark(logLevel = SparkLogLevel.DEBUG) {
+        withSpark(
+            props = mapOf("spark.sql.codegen.wholeStage" to false),
+            logLevel = SparkLogLevel.DEBUG,
+        ) {
 
             context("Key/value") {
                 should("work with spark example") {
diff --git a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/StreamingTest.kt b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/StreamingTest.kt
index a27d080b..86542aa8 100644
--- a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/StreamingTest.kt
+++ b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/StreamingTest.kt
@@ -39,6 +39,7 @@ import scala.Tuple2
 import java.io.File
 import java.io.Serializable
 import java.nio.charset.StandardCharsets
+import java.nio.file.Files
 import java.util.*
 import java.util.concurrent.atomic.AtomicBoolean
 
@@ -201,10 +202,10 @@ class StreamingTest : ShouldSpec({
 
 private val scalaCompatVersion = SCALA_COMPAT_VERSION
 private val sparkVersion = SPARK_VERSION
-private fun createTempDir() = File.createTempFile(
-    System.getProperty("java.io.tmpdir"),
-    "spark_${scalaCompatVersion}_${sparkVersion}"
-).apply { deleteOnExit() }
+private fun createTempDir() =
+    Files.createTempDirectory("spark_${scalaCompatVersion}_${sparkVersion}")
+        .toFile()
+        .also { it.deleteOnExit() }
 
 private fun checkpointFile(checkpointDir: String, checkpointTime: Time): Path {
     val klass = Class.forName("org.apache.spark.streaming.Checkpoint$")
@@ -215,7 +216,10 @@ private fun checkpointFile(checkpointDir: String, checkpointTime: Time): Path {
     return checkpointFileMethod.invoke(module, checkpointDir, checkpointTime) as Path
 }
 
-private fun getCheckpointFiles(checkpointDir: String, fs: scala.Option<FileSystem>): scala.collection.immutable.Seq<String> {
+private fun getCheckpointFiles(
+    checkpointDir: String,
+    fs: scala.Option<FileSystem>
+): scala.collection.immutable.Seq<String> {
     val klass = Class.forName("org.apache.spark.streaming.Checkpoint$")
     val moduleField = klass.getField("MODULE$").also { it.isAccessible = true }
     val module = moduleField.get(null)
@@ -227,7 +231,11 @@ private fun getCheckpointFiles(
     checkpointDir: String,
     fs: scala.Option<FileSystem>
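Note: the createTempDir rewrite above fixes a genuine bug rather than a Kotlin 2.0 issue: File.createTempFile creates an empty file (and was being handed the tmpdir path as a name prefix), while checkpointing needs a directory. A minimal sketch of the difference (the prefix string is illustrative):

    import java.io.File
    import java.nio.file.Files

    fun main() {
        // Old approach: a zero-byte file, not usable as a checkpoint dir
        val notADir: File = File.createTempFile("spark_tmp", null).apply { deleteOnExit() }
        println(notADir.isDirectory) // false

        // New approach: an actual directory
        val dir: File = Files.createTempDirectory("spark_tmp").toFile().also { it.deleteOnExit() }
        println(dir.isDirectory) // true
    }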
should("return correct types too") { @@ -120,7 +129,7 @@ class TypeInferenceTest : ShouldSpec({ } } context("type with list of Pairs int to long") { - val struct = Struct.fromJson(schemaFor>>().prettyJson())!! + val struct = Struct.fromJson(schemaFor>>().prettyJson())!! should("return correct types too") { expect(struct) { isOfType("array") @@ -134,7 +143,7 @@ class TypeInferenceTest : ShouldSpec({ } } context("type with list of generic data class with E generic name") { - data class Test(val e: E) + @Sparkify data class Test(val e: E) val struct = Struct.fromJson(schemaFor>>().prettyJson())!! should("return correct types too") { @@ -180,7 +189,7 @@ class TypeInferenceTest : ShouldSpec({ } } context("data class with props in order lon → lat") { - data class Test(val lon: Double, val lat: Double) + @Sparkify data class Test(val lon: Double, val lat: Double) val struct = Struct.fromJson(schemaFor().prettyJson())!! should("Not change order of fields") { @@ -191,7 +200,7 @@ class TypeInferenceTest : ShouldSpec({ } } context("data class with nullable list inside") { - data class Sample(val optionList: List?) + @Sparkify data class Sample(val optionList: List?) val struct = Struct.fromJson(schemaFor().prettyJson())!! @@ -223,8 +232,8 @@ class TypeInferenceTest : ShouldSpec({ .feature("element name", { name() }) { toEqual("optionList") } .feature("field type", { dataType() }, { this - .isA() - .feature("element type", { elementType() }) { isA() } + .toBeAnInstanceOf() + .feature("element type", { elementType() }) { toBeAnInstanceOf() } .feature("element nullable", { containsNull() }) { toEqual(expected = false) } }) .feature("optionList nullable", { nullable() }) { toEqual(true) } @@ -258,5 +267,5 @@ private fun hasStruct( private fun hasField(name: String, type: String): Expect.() -> Unit = { feature { f(it::name) }.toEqual(name) - feature { f(it::type) }.isA().feature { f(it::value) }.toEqual(type) + feature { f(it::type) }.toBeAnInstanceOf().feature { f(it::value) }.toEqual(type) } diff --git a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFTest.kt b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFTest.kt index 26b79c30..8bac0408 100644 --- a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFTest.kt +++ b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFTest.kt @@ -35,6 +35,7 @@ import org.apache.spark.sql.expressions.Aggregator import org.intellij.lang.annotations.Language import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify import org.junit.jupiter.api.assertThrows +import scala.Product import scala.collection.Seq import java.io.Serializable import kotlin.random.Random @@ -235,7 +236,8 @@ class UDFTest : ShouldSpec({ udf.register(::stringIntDiff) @Language("SQL") - val result = spark.sql("SELECT stringIntDiff(first, second) FROM test1").to().collectAsList() + val result = + spark.sql("SELECT stringIntDiff(getFirst, getSecond) FROM test1").to().collectAsList() result shouldBe listOf(96, 96) } } @@ -304,7 +306,8 @@ class UDFTest : ShouldSpec({ ) ds should beOfType>() - "${nameConcatAge.name}(${NormalClass::name.name}, ${NormalClass::age.name})" shouldBe ds.columns().single() + "${nameConcatAge.name}(${NormalClass::name.name}, ${NormalClass::age.name})" shouldBe ds.columns() + .single() val collectAsList = ds.collectAsList() collectAsList[0] shouldBe "a-10" @@ -329,7 +332,8 @@ class UDFTest : ShouldSpec({ ) ds should beOfType>() - "${nameConcatAge.name}(${NormalClass::name.name}, 
diff --git a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFTest.kt b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFTest.kt
index 26b79c30..8bac0408 100644
--- a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFTest.kt
+++ b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UDFTest.kt
@@ -35,6 +35,7 @@ import org.apache.spark.sql.expressions.Aggregator
 import org.intellij.lang.annotations.Language
 import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify
 import org.junit.jupiter.api.assertThrows
+import scala.Product
 import scala.collection.Seq
 import java.io.Serializable
 import kotlin.random.Random
@@ -235,7 +236,8 @@ class UDFTest : ShouldSpec({
                 udf.register(::stringIntDiff)
 
                 @Language("SQL")
-                val result = spark.sql("SELECT stringIntDiff(first, second) FROM test1").to<Int>().collectAsList()
+                val result =
+                    spark.sql("SELECT stringIntDiff(getFirst, getSecond) FROM test1").to<Int>().collectAsList()
                 result shouldBe listOf(96, 96)
             }
         }
@@ -304,7 +306,8 @@ class UDFTest : ShouldSpec({
                 )
 
                 ds should beOfType<Dataset<String>>()
-                "${nameConcatAge.name}(${NormalClass::name.name}, ${NormalClass::age.name})" shouldBe ds.columns().single()
+                "${nameConcatAge.name}(${NormalClass::name.name}, ${NormalClass::age.name})" shouldBe ds.columns()
+                    .single()
 
                 val collectAsList = ds.collectAsList()
                 collectAsList[0] shouldBe "a-10"
@@ -329,7 +332,8 @@ class UDFTest : ShouldSpec({
                 )
 
                 ds should beOfType<Dataset<Row>>()
-                "${nameConcatAge.name}(${NormalClass::name.name}, ${NormalClass::age.name})" shouldBe ds.columns().single()
+                "${nameConcatAge.name}(${NormalClass::name.name}, ${NormalClass::age.name})" shouldBe ds.columns()
+                    .single()
 
                 val collectAsList = ds.collectAsList()
                 collectAsList[0].getAs<String>(0) shouldBe "a-10"
@@ -354,7 +358,8 @@ class UDFTest : ShouldSpec({
                 )
 
                 ds should beOfType<Dataset<Row>>()
-                "${nameConcatAge.name}(${NormalClass::name.name}, ${NormalClass::age.name})" shouldBe ds.columns().single()
+                "${nameConcatAge.name}(${NormalClass::name.name}, ${NormalClass::age.name})" shouldBe ds.columns()
+                    .single()
 
                 val collectAsList = ds.collectAsList()
                 collectAsList[0].getAs<String>(0) shouldBe "a-10"
@@ -419,13 +424,14 @@ class UDFTest : ShouldSpec({
 
     context("udf return data class") {
         withSpark(logLevel = SparkLogLevel.DEBUG) {
+            /** TODO [org.apache.spark.sql.catalyst.CatalystTypeConverters.StructConverter.toCatalystImpl] needs it to be a [scala.Product] */
             should("return NormalClass") {
                 listOf("a" to 1, "b" to 2).toDS().toDF().createOrReplaceTempView("test2")
 
                 udf.register("toNormalClass") { name: String, age: Int ->
                     NormalClass(age, name)
                 }
-                spark.sql("select toNormalClass(first, second) from test2").show()
+                spark.sql("select toNormalClass(getFirst, getSecond) from test2").show()
             }
 
             should("not return NormalClass when not registered") {
                 val toNormalClass2 = udf("toNormalClass2", ::NormalClass)
 
                 shouldThrow<AnalysisException> {
-                    spark.sql("select toNormalClass2(first, second) from test2").show()
+                    spark.sql("select toNormalClass2(getFirst, getSecond) from test2").show()
                 }
             }
 
+            /** TODO [org.apache.spark.sql.catalyst.CatalystTypeConverters.StructConverter.toCatalystImpl] needs it to be a [scala.Product] */
             should("return NormalClass using accessed by delegate") {
                 listOf(1 to "a", 2 to "b").toDS().toDF().createOrReplaceTempView("test2")
                 val toNormalClass3 = udf("toNormalClass3", ::NormalClass)
                 toNormalClass3.register()
 
-                spark.sql("select toNormalClass3(first, second) from test2").show()
+                spark.sql("select toNormalClass3(getFirst, getSecond) from test2").show()
             }
         }
     }
@@ -643,7 +650,6 @@ class UDFTest : ShouldSpec({
             }
         }
 
-
     }
 }
 
@@ -1262,8 +1268,10 @@ class UDFTest : ShouldSpec({
     }
 })
 
-@Sparkify data class Employee(val name: String, val salary: Long)
-@Sparkify data class Average(var sum: Long, var count: Long)
+@Sparkify
+data class Employee(val name: String, val salary: Long)
+@Sparkify
+data class Average(var sum: Long, var count: Long)
 
 private object MyAverage : Aggregator<Employee, Average, Double>() {
     // A zero value for this aggregation. Should satisfy the property that any b + zero = b
@@ -1322,6 +1330,17 @@ data class NormalClass(
     val age: Int,
     val name: String
 )
+// : Product {
+//     override fun canEqual(that: Any?): Boolean = that is NormalClass
+//
+//     override fun productElement(n: Int): Any =
+//         when (n) {
+//             0 -> age
+//             1 -> name
+//             else -> throw IndexOutOfBoundsException(n.toString())
+//         }
+//     override fun productArity(): Int = 2
+// }
 
 private val firstByteVal = { a: ByteArray -> a.firstOrNull() }
 private val firstShortVal = { a: ShortArray -> a.firstOrNull() }
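Closing note on two recurring changes in this test file. First, the SQL now selects getFirst/getSecond instead of first/second: a plain Kotlin Pair is not @Sparkify-annotated, so its columns are presumably derived from the Java getter names, while @Sparkify data classes keep their property names. An illustrative comparison (these classes are not part of the patch):

    data class PlainPair(val first: String, val second: Int)   // encoded via getters: getFirst, getSecond

    @Sparkify
    data class NicePair(val first: String, val second: Int)    // property names kept: first, second

Second, the TODO comments and the commented-out scala.Product implementation on NormalClass record a known limitation: CatalystTypeConverters.StructConverter.toCatalystImpl expects UDF results to be scala.Product instances, something the @Sparkify compiler plugin would have to generate for returned data classes.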