diff --git a/buildSrc/src/main/kotlin/Versions.kt b/buildSrc/src/main/kotlin/Versions.kt
index 59eab276..ed3ce444 100644
--- a/buildSrc/src/main/kotlin/Versions.kt
+++ b/buildSrc/src/main/kotlin/Versions.kt
@@ -2,7 +2,7 @@ object Versions : Dsl {
     const val project = "2.0.0-SNAPSHOT"
     const val kotlinSparkApiGradlePlugin = "2.0.0-SNAPSHOT"
     const val groupID = "org.jetbrains.kotlinx.spark"
-    const val kotlin = "2.0.0-Beta5"
+    const val kotlin = "2.0.0-RC1"
     const val jvmTarget = "8"
     const val jupyterJvmTarget = "8"
     inline val spark get() = System.getProperty("spark") as String
diff --git a/compiler-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/fir/DataClassSparkifySuperTypeGenerator.kt b/compiler-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/fir/DataClassSparkifySuperTypeGenerator.kt
index 339b3269..3a13c458 100644
--- a/compiler-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/fir/DataClassSparkifySuperTypeGenerator.kt
+++ b/compiler-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/fir/DataClassSparkifySuperTypeGenerator.kt
@@ -33,10 +33,10 @@ class DataClassSparkifySuperTypeGenerator(
         }
     }

-    context(TypeResolveServiceContainer)
     override fun computeAdditionalSupertypes(
         classLikeDeclaration: FirClassLikeDeclaration,
-        resolvedSupertypes: List<FirResolvedTypeRef>
+        resolvedSupertypes: List<FirResolvedTypeRef>,
+        typeResolver: TypeResolveService,
     ): List<FirResolvedTypeRef> = listOf(
         buildResolvedTypeRef {
             val scalaProduct = productFqNames.first().let {
@@ -48,7 +48,6 @@ class DataClassSparkifySuperTypeGenerator(
                 isNullable = false,
             )
         }
-    )

     override fun needTransformSupertypes(declaration: FirClassLikeDeclaration): Boolean =
diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts
index 4a0a2179..8683926a 100644
--- a/examples/build.gradle.kts
+++ b/examples/build.gradle.kts
@@ -3,7 +3,13 @@ import org.jetbrains.kotlin.gradle.tasks.KotlinCompile

 plugins {
     // Needs to be installed in the local maven repository or have the bootstrap jar on the classpath
     id("org.jetbrains.kotlinx.spark.api")
+    java
     kotlin("jvm")
+    kotlin("plugin.noarg") version Versions.kotlin
+}
+
+noArg {
+    annotation("org.jetbrains.kotlinx.spark.examples.NoArg")
 }

 kotlinSparkApi {
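For context, the `noArg { }` block above only takes effect for classes marked with the configured annotation, so the examples module is assumed to declare that annotation itself under the listed fully qualified name. A minimal sketch of what that would look like (the `Person` class is hypothetical, for illustration only):

```kotlin
package org.jetbrains.kotlinx.spark.examples

// Marker annotation matching the FQN configured in the noArg { } block.
annotation class NoArg

// The no-arg compiler plugin generates a synthetic zero-argument constructor
// for annotated classes. It is not callable from regular Kotlin code, but
// reflection-based frameworks (such as parts of Spark) can use it to
// instantiate the class without knowing its constructor shape.
@NoArg
data class Person(val name: String, val age: Int)
```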
diff --git a/gradle/bootstraps/compiler-plugin.jar b/gradle/bootstraps/compiler-plugin.jar
index 2ea26a00..af6fb778 100644
Binary files a/gradle/bootstraps/compiler-plugin.jar and b/gradle/bootstraps/compiler-plugin.jar differ
diff --git a/gradle/bootstraps/gradle-plugin.jar b/gradle/bootstraps/gradle-plugin.jar
index 7d1d0358..fe0cde25 100644
Binary files a/gradle/bootstraps/gradle-plugin.jar and b/gradle/bootstraps/gradle-plugin.jar differ
diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt
index 448751ae..6e2585dc 100644
--- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt
+++ b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt
@@ -19,21 +19,43 @@
  */
 package org.jetbrains.kotlinx.spark.api.jupyter

-import kotlinx.serialization.Serializable
-import kotlinx.serialization.json.*
+import org.apache.spark.api.java.JavaDoubleRDD
+import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.JavaRDDLike
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.intellij.lang.annotations.Language
-import org.jetbrains.kotlinx.jupyter.api.*
+import org.jetbrains.kotlinx.jupyter.api.Code
+import org.jetbrains.kotlinx.jupyter.api.FieldValue
+import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost
+import org.jetbrains.kotlinx.jupyter.api.MimeTypedResult
+import org.jetbrains.kotlinx.jupyter.api.Notebook
+import org.jetbrains.kotlinx.jupyter.api.VariableDeclaration
+import org.jetbrains.kotlinx.jupyter.api.createRendererByCompileTimeType
+import org.jetbrains.kotlinx.jupyter.api.declare
 import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
+import org.jetbrains.kotlinx.jupyter.api.textResult
+import org.jetbrains.kotlinx.spark.api.SparkSession
 import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.displayLimitName
 import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.displayTruncateName
 import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.scalaName
 import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.sparkName
 import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.sparkPropertiesName
 import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.versionName
-import kotlin.reflect.KProperty1
+import org.jetbrains.kotlinx.spark.api.kotlinEncoderFor
+import org.jetbrains.kotlinx.spark.api.plugin.annotations.ColumnName
+import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify
+import scala.Tuple2
+import kotlin.reflect.KClass
+import kotlin.reflect.KMutableProperty
+import kotlin.reflect.full.createType
+import kotlin.reflect.full.findAnnotation
+import kotlin.reflect.full.isSubtypeOf
+import kotlin.reflect.full.memberFunctions
+import kotlin.reflect.full.memberProperties
+import kotlin.reflect.full.primaryConstructor
+import kotlin.reflect.full.valueParameters
 import kotlin.reflect.typeOf
@@ -46,9 +68,6 @@ abstract class Integration(private val notebook: Notebook, private val options:
     protected val sparkVersion = /*$"\""+spark+"\""$*/ /*-*/ ""
     protected val version = /*$"\""+version+"\""$*/ /*-*/ ""

-    protected val displayLimitOld = "DISPLAY_LIMIT"
-    protected val displayTruncateOld = "DISPLAY_TRUNCATE"
-
     protected val properties: Properties
         get() = notebook
             .variablesState[sparkPropertiesName]!!
@@ -101,6 +120,7 @@ abstract class Integration(private val notebook: Notebook, private val options:
     )

     open val imports: Array<String> = arrayOf(
+        "org.jetbrains.kotlinx.spark.api.plugin.annotations.*",
         "org.jetbrains.kotlinx.spark.api.*",
         "org.jetbrains.kotlinx.spark.api.tuples.*",
         *(1..22).map { "scala.Tuple$it" }.toTypedArray(),
@@ -116,6 +136,9 @@ abstract class Integration(private val notebook: Notebook, private val options:
         "org.apache.spark.streaming.*",
     )

+    // Needs to be set by the integration
+    var spark: SparkSession? = null
+
     override fun Builder.onLoaded() {
         dependencies(*dependencies)
         import(*imports)
@@ -135,27 +158,6 @@ abstract class Integration(private val notebook: Notebook, private val options:
             )
         )

-        @Language("kts")
-        val _0 = execute(
-            """
-                @Deprecated("Use ${displayLimitName}=${properties.displayLimit} in %use magic or ${sparkPropertiesName}.${displayLimitName} = ${properties.displayLimit} instead", ReplaceWith("${sparkPropertiesName}.${displayLimitName}"))
-                var $displayLimitOld: Int
-                    get() = ${sparkPropertiesName}.${displayLimitName}
-                    set(value) {
-                        println("$displayLimitOld is deprecated: Use ${sparkPropertiesName}.${displayLimitName} instead")
-                        ${sparkPropertiesName}.${displayLimitName} = value
-                    }
-
-                @Deprecated("Use ${displayTruncateName}=${properties.displayTruncate} in %use magic or ${sparkPropertiesName}.${displayTruncateName} = ${properties.displayTruncate} instead", ReplaceWith("${sparkPropertiesName}.${displayTruncateName}"))
-                var $displayTruncateOld: Int
-                    get() = ${sparkPropertiesName}.${displayTruncateName}
-                    set(value) {
-                        println("$displayTruncateOld is deprecated: Use ${sparkPropertiesName}.${displayTruncateName} instead")
-                        ${sparkPropertiesName}.${displayTruncateName} = value
-                    }
-            """.trimIndent()
-        )
-
         onLoaded()
     }
@@ -180,27 +182,119 @@ abstract class Integration(private val notebook: Notebook, private val options:
             onShutdown()
         }

+        onClassAnnotation<Sparkify> {
+            for (klass in it) {
+                if (klass.isData) {
+                    execute(generateSparkifyClass(klass))
+                }
+            }
+        }
+
         // Render Dataset
         render<Dataset<*>> {
-            with(properties) {
-                HTML(it.toHtml(limit = displayLimit, truncate = displayTruncate))
-            }
+            renderDataset(it)
         }

-        render<RDD<*>> {
-            with(properties) {
-                HTML(it.toJavaRDD().toHtml(limit = displayLimit, truncate = displayTruncate))
+        // using compile time KType, convert this JavaRDDLike to Dataset and render it
+        notebook.renderersProcessor.registerWithoutOptimizing(
+            createRendererByCompileTimeType<JavaRDDLike<*, *>> {
+                if (spark == null) return@createRendererByCompileTimeType it.value.toString()
+
+                val rdd = (it.value as JavaRDDLike<*, *>).rdd()
+                val type = when {
+                    it.type.isSubtypeOf(typeOf<JavaDoubleRDD>()) ->
+                        typeOf<Double>()
+
+                    it.type.isSubtypeOf(typeOf<JavaPairRDD<*, *>>()) ->
+                        Tuple2::class.createType(
+                            listOf(
+                                it.type.arguments.first(),
+                                it.type.arguments.last(),
+                            )
+                        )
+
+                    it.type.isSubtypeOf(typeOf<JavaRDD<*>>()) ->
+                        it.type.arguments.first().type!!
+
+                    else -> it.type.arguments.first().type!!
+                }
+                val ds = spark!!.createDataset(rdd, kotlinEncoderFor(type))
+                renderDataset(ds)
             }
-        }
+        )
+
+        // using compile time KType, convert this RDD to Dataset and render it
+        notebook.renderersProcessor.registerWithoutOptimizing(
+            createRendererByCompileTimeType<RDD<*>> {
+                if (spark == null) return@createRendererByCompileTimeType it.value.toString()

-        render<JavaRDDLike<*, *>> {
-            with(properties) {
-                HTML(it.toHtml(limit = displayLimit, truncate = displayTruncate))
+                val rdd = it.value as RDD<*>
+                val type = it.type.arguments.first().type!!
+                val ds = spark!!.createDataset(rdd, kotlinEncoderFor(type))
+                renderDataset(ds)
             }
-        }
+        )
+
+        onLoadedAlsoDo()
+    }

+    private fun renderDataset(it: Dataset<*>): MimeTypedResult =
+        with(properties) {
+            val showFunction = Dataset::class
+                .memberFunctions
+                .firstOrNull { it.name == "showString" && it.valueParameters.size == 3 }
+
+            textResult(
+                if (showFunction != null) {
+                    showFunction.call(it, displayLimit, displayTruncate, false) as String
+                } else {
+                    // if the function cannot be called, make sure it will call println instead
+                    it.show(displayLimit, displayTruncate)
+                    ""
+                }
+            )
+        }
-        onLoadedAlsoDo()
+
+    // TODO wip
+    private fun generateSparkifyClass(klass: KClass<*>): Code {
+//        val name = "`${klass.simpleName!!}${'$'}Generated`"
+        val name = klass.simpleName
+        val constructorArgs = klass.primaryConstructor!!.parameters
+        val visibility = klass.visibility?.name?.lowercase() ?: ""
+        val memberProperties = klass.memberProperties
+
+        // pair each constructor parameter with the member property it declares
+        val properties = constructorArgs.associateWith { param ->
+            memberProperties.first { it.name == param.name }
+        }
+
+        val constructorParamsCode = properties.entries.joinToString("\n") { (param, prop) ->
+            // TODO check override
+            if (param.isOptional) TODO()
+            val modifier = if (prop is KMutableProperty<*>) "var" else "val"
+            val paramVisibility = prop.visibility?.name?.lowercase() ?: ""
+            val columnName = param.findAnnotation<ColumnName>()?.name ?: param.name!!
+
+            "|    @get:kotlin.jvm.JvmName(\"$columnName\") $paramVisibility $modifier ${param.name}: ${param.type},"
+        }
+
+        val productElementWhenParamsCode = properties.entries.joinToString("\n") { (param, _) ->
+            "|        ${param.index} -> this.${param.name}"
+        }
+
+        @Language("kotlin")
+        val code = """
+            |$visibility data class $name(
+            $constructorParamsCode
+            |): scala.Product, java.io.Serializable {
+            |    override fun canEqual(that: Any?): Boolean = that is $name
+            |    override fun productArity(): Int = ${constructorArgs.size}
+            |    override fun productElement(n: Int): Any = when (n) {
+            $productElementWhenParamsCode
+            |        else -> throw IndexOutOfBoundsException()
+            |    }
+            |}
+        """.trimMargin()
+        return code
+    }
 }
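For reference, this is roughly what the `generateSparkifyClass` string template above emits, shown for a hypothetical notebook class `@Sparkify data class Test(val a: Int, val b: String)` (the class name and properties are made up for illustration):

```kotlin
// The @get:JvmName annotations expose Scala-style getters (a() and b() instead
// of getA() and getB()), and implementing scala.Product lets Spark treat the
// class like a Scala case class.
public data class Test(
    @get:kotlin.jvm.JvmName("a") public val a: kotlin.Int,
    @get:kotlin.jvm.JvmName("b") public val b: kotlin.String,
): scala.Product, java.io.Serializable {
    override fun canEqual(that: Any?): Boolean = that is Test
    override fun productArity(): Int = 2
    override fun productElement(n: Int): Any = when (n) {
        0 -> this.a
        1 -> this.b
        else -> throw IndexOutOfBoundsException()
    }
}
```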
diff --git a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt
index cc308116..0c4eb096 100644
--- a/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt
+++ b/jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt
@@ -25,6 +25,7 @@ package org.jetbrains.kotlinx.spark.api.jupyter
 import org.intellij.lang.annotations.Language
 import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost
 import org.jetbrains.kotlinx.jupyter.api.Notebook
+import org.jetbrains.kotlinx.spark.api.SparkSession
 import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.appNameName
 import org.jetbrains.kotlinx.spark.api.jupyter.Properties.Companion.sparkMasterName
@@ -86,7 +87,7 @@ class SparkIntegration(notebook: Notebook, options: MutableMap<String, String?>)
             """
                 inline fun <reified T> dfOf(vararg arg: T): Dataset<Row> = spark.dfOf(*arg)""".trimIndent(),
             """
-                inline fun <reified T> emptyDataset(): Dataset<T> = spark.emptyDataset(encoder<T>())""".trimIndent(),
+                inline fun <reified T> emptyDataset(): Dataset<T> = spark.emptyDataset(kotlinEncoderFor<T>())""".trimIndent(),
             """
                 inline fun <reified T> dfOf(colNames: Array<String>, vararg arg: T): Dataset<Row> = spark.dfOf(colNames, *arg)""".trimIndent(),
             """
@@ -108,6 +109,8 @@ class SparkIntegration(notebook: Notebook, options: MutableMap<String, String?>)
             """
                 inline fun <RETURN, NAMED_UDF : NamedUserDefinedFunction<RETURN, *>> UserDefinedFunction<RETURN, NAMED_UDF>.register(name: String): NAMED_UDF =
                     spark.udf().register(name = name, udf = this)""".trimIndent(),
         ).map(::execute)
+
+        spark = execute("spark").value as SparkSession
     }

     override fun KotlinKernelHost.onShutdown() {
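A short sketch of how the helpers injected above behave in a notebook cell once the integration has loaded (the `Person` class is made up for illustration):

```kotlin
@Sparkify
data class Person(val name: String, val age: Int)

// dsOf comes from kotlin-spark-api; emptyDataset is the helper defined above,
// now backed by kotlinEncoderFor<T>() instead of the removed encoder<T>().
val people = dsOf(Person("Jane", 32), Person("John", 29))  // Dataset<Person>
val empty = emptyDataset<Person>()                         // Dataset<Person> with no rows
```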
diff --git a/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt b/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt
index b82512b7..3ffd37be 100644
--- a/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt
+++ b/jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt
@@ -32,15 +32,12 @@
 import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.streaming.api.java.JavaStreamingContext
 import org.intellij.lang.annotations.Language
 import org.jetbrains.kotlinx.jupyter.EvalRequestData
-import org.jetbrains.kotlinx.jupyter.MutableNotebook
 import org.jetbrains.kotlinx.jupyter.ReplForJupyter
-import org.jetbrains.kotlinx.jupyter.ReplForJupyterImpl
 import org.jetbrains.kotlinx.jupyter.api.Code
 import org.jetbrains.kotlinx.jupyter.api.MimeTypedResult
-import org.jetbrains.kotlinx.jupyter.libraries.EmptyResolutionInfoProvider
+import org.jetbrains.kotlinx.jupyter.api.MimeTypes
 import org.jetbrains.kotlinx.jupyter.repl.EvalResultEx
 import org.jetbrains.kotlinx.jupyter.repl.creating.createRepl
-import org.jetbrains.kotlinx.jupyter.testkit.JupyterReplTestCase
 import org.jetbrains.kotlinx.jupyter.testkit.ReplProvider
 import org.jetbrains.kotlinx.jupyter.util.PatternNameAcceptanceRule
 import org.jetbrains.kotlinx.spark.api.SparkSession
@@ -83,10 +80,11 @@ class JupyterTests : ShouldSpec({
     context("Jupyter") {
         withRepl {
+            exec("%trackExecution")

             should("Allow functions on local data classes") {
                 @Language("kts")
-                val klass = exec("""data class Test(val a: Int, val b: String)""")
+                val klass = exec("""@Sparkify data class Test(val a: Int, val b: String)""")

                 @Language("kts")
                 val ds = exec("""val ds = dsOf(Test(1, "hi"), Test(2, "something"))""")
@@ -112,7 +110,7 @@ class JupyterTests : ShouldSpec({
             should("render Datasets") {
                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
                         val ds = listOf(1, 2, 3).toDS()
                         ds
@@ -128,7 +126,7 @@ class JupyterTests : ShouldSpec({
             should("render JavaRDDs") {
                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
                         val rdd: JavaRDD<List<Int>> = listOf(
                             listOf(1, 2, 3),
@@ -145,7 +143,7 @@ class JupyterTests : ShouldSpec({
             should("render JavaRDDs with Arrays") {
                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
                         val rdd: JavaRDD<IntArray> = rddOf(
                             intArrayOf(1, 2, 3),
@@ -165,7 +163,7 @@ class JupyterTests : ShouldSpec({
                 @Language("kts")
                 val klass = exec(
                     """
-                        data class Test(
+                        @Sparkify data class Test(
                             val longFirstName: String,
                             val second: LongArray,
                             val somethingSpecial: Map<Int, String>,
@@ -174,7 +172,7 @@ class JupyterTests : ShouldSpec({
                 )

                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
                         val rdd = listOf(
@@ -185,29 +183,40 @@ class JupyterTests : ShouldSpec({
                         rdd
                     """.trimIndent()
                 )
-                html shouldContain "Test(longFirstName=aaaaaaaa..."
+                html shouldContain """
+                    +-------------+---------------+--------------------+
+                    |longFirstName|         second|    somethingSpecial|
+                    +-------------+---------------+--------------------+
+                    |    aaaaaaaaa|[1, 100000, 24]|{1 -> one, 2 -> two}|
+                    |    aaaaaaaaa|[1, 100000, 24]|{1 -> one, 2 -> two}|
+                    +-------------+---------------+--------------------+""".trimIndent()
             }

             should("render JavaPairRDDs") {
                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
                         val rdd: JavaPairRDD<Int, Int> = rddOf(
-                            c(1, 2).toTuple(),
-                            c(3, 4).toTuple(),
+                            t(1, 2),
+                            t(3, 4),
                         ).toJavaPairRDD()
                         rdd
                     """.trimIndent()
                 )
                 println(html)

-                html shouldContain "1, 2"
-                html shouldContain "3, 4"
+                html shouldContain """
+                    +---+---+
+                    | _1| _2|
+                    +---+---+
+                    |  1|  2|
+                    |  3|  4|
+                    +---+---+""".trimIndent()
             }

             should("render JavaDoubleRDD") {
                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
                         val rdd: JavaDoubleRDD = rddOf(1.0, 2.0, 3.0, 4.0,).toJavaDoubleRDD()
                         rdd
@@ -223,7 +232,7 @@ class JupyterTests : ShouldSpec({
             should("render Scala RDD") {
                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
                         val rdd: RDD<List<Int>> = rddOf(
                             listOf(1, 2, 3),
@@ -244,9 +253,9 @@ class JupyterTests : ShouldSpec({
                 val oldTruncation = exec("""sparkProperties.displayTruncate""") as Int

                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
-                        data class Test(val a: String)
+                        @Sparkify data class Test(val a: String)
                         sparkProperties.displayTruncate = 3
                         dsOf(Test("aaaaaaaaaa"))
                     """.trimIndent()
                 )
@@ -255,8 +264,8 @@ class JupyterTests : ShouldSpec({
                 @Language("kts")
                 val restoreTruncation = exec("""sparkProperties.displayTruncate = $oldTruncation""")

-                html shouldContain "aaa"
-                html shouldNotContain "aaaaaaaaaa"
+                html shouldContain "aaa"
+                html shouldNotContain "aaaaaaaaaa"
             }

             should("limit dataset rows using properties") {
                 @Language("kts")
                 val oldLimit = exec("""sparkProperties.displayLimit""") as Int

                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
-                        data class Test(val a: String)
+                        @Sparkify data class Test(val a: String)
                         sparkProperties.displayLimit = 3
                         dsOf(Test("a"), Test("b"), Test("c"), Test("d"), Test("e"))
                     """.trimIndent()
                 )
@@ -276,11 +285,11 @@ class JupyterTests : ShouldSpec({
                 @Language("kts")
                 val restoreLimit = exec("""sparkProperties.displayLimit = $oldLimit""")

-                html shouldContain "a"
-                html shouldContain "b"
-                html shouldContain "c"
-                html shouldNotContain "d"
-                html shouldNotContain "e"
+                html shouldContain "a|"
+                html shouldContain "b|"
+                html shouldContain "c|"
+                html shouldNotContain "d|"
+                html shouldNotContain "e|"
             }

             should("truncate rdd cells using properties") {
                 @Language("kts")
                 val oldTruncation = exec("""sparkProperties.displayTruncate""") as Int

                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
                         sparkProperties.displayTruncate = 3
                         rddOf("aaaaaaaaaa")
                     """.trimIndent()
                 )
@@ -299,8 +308,8 @@ class JupyterTests : ShouldSpec({
                 @Language("kts")
                 val restoreTruncation = exec("""sparkProperties.displayTruncate = $oldTruncation""")

-                html shouldContain "aaa"
-                html shouldNotContain "aaaaaaaaaa"
+                html shouldContain "aaa"
+                html shouldNotContain "aaaaaaaaaa"
             }

             should("limit rdd rows using properties") {
                 @Language("kts")
                 val oldLimit = exec("""sparkProperties.displayLimit""") as Int

                 @Language("kts")
-                val html = execHtml(
+                val html = execForDisplayText(
                     """
                         sparkProperties.displayLimit = 3
                         rddOf("a", "b", "c", "d", "e")
                     """.trimIndent()
                 )
@@ -319,11 +328,11 @@ class JupyterTests : ShouldSpec({
                 @Language("kts")
                 val restoreLimit = exec("""sparkProperties.displayLimit = $oldLimit""")

-                html shouldContain "a"
-                html shouldContain "b"
-                html shouldContain "c"
-                html shouldNotContain "d"
-                html shouldNotContain "e"
+                html shouldContain " a|"
+                html shouldContain " b|"
+                html shouldContain " c|"
+                html shouldNotContain " d|"
+                html shouldNotContain " e|"
             }

             @Language("kts")
@@ -391,7 +400,7 @@ class JupyterStreamingTests : ShouldSpec({
         }
     }

-    xshould("stream") {
+    should("stream") {

         @Language("kts")
         val value = exec(
@@ -458,4 +467,11 @@ private fun ReplForJupyter.execHtml(code: Code): String {
     return html
 }

+private fun ReplForJupyter.execForDisplayText(code: Code): String {
+    val res = exec(code)
+    val text = res[MimeTypes.PLAIN_TEXT]
+    text.shouldNotBeNull()
+    return text
+}
+
 class Counter(@Volatile var value: Int) : Serializable
diff --git a/settings.gradle.kts b/settings.gradle.kts
index 8ad32812..98776e06 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -35,7 +35,7 @@ rootProject.name = "kotlin-spark-api-parent_$versions"
 include("scala-helpers")
 include("scala-tuples-in-kotlin")
 include("kotlin-spark-api")
-//include("jupyter")
+include("jupyter")
 include("examples")
 include("compiler-plugin")
 include("gradle-plugin")
@@ -46,7 +46,7 @@ project(":scala-tuples-in-kotlin").name = "scala-tuples-in-kotlin_$scalaCompat"

 // spark+scala dependent
 project(":kotlin-spark-api").name = "kotlin-spark-api_$versions"
-//project(":jupyter").name = "jupyter_$versions"
+project(":jupyter").name = "jupyter_$versions"
 project(":examples").name = "examples_$versions"

 buildCache {