Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Add support for r2dbc driver connections #2266

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 221 additions & 86 deletions exposed-core/api/exposed-core.api

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package org.jetbrains.exposed.sql

import org.jetbrains.exposed.sql.statements.Statement
import org.jetbrains.exposed.sql.statements.StatementIterator
import org.jetbrains.exposed.sql.statements.StatementType
import org.jetbrains.exposed.sql.statements.api.ResultApi
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.sql.ResultSet

/** Base class representing an SQL query that returns a [ResultSet] when executed. */
/** Base class representing an SQL query that returns a result when executed. */
abstract class AbstractQuery<T : AbstractQuery<T>>(
targets: List<Table>
) : SizedIterable<ResultRow>, Statement<ResultSet>(StatementType.SELECT, targets) {
) : SizedIterable<ResultRow>, Statement<ResultApi>(StatementType.SELECT, targets) {
protected val transaction
get() = TransactionManager.current()

Expand Down Expand Up @@ -87,7 +88,7 @@ abstract class AbstractQuery<T : AbstractQuery<T>>(

protected var count: Boolean = false

protected abstract val queryToExecute: Statement<ResultSet>
protected abstract val queryToExecute: Statement<ResultApi>

override fun iterator(): Iterator<ResultRow> {
val resultIterator = ResultIterator(transaction.exec(queryToExecute)!!)
Expand All @@ -98,33 +99,19 @@ abstract class AbstractQuery<T : AbstractQuery<T>>(
}
}

private inner class ResultIterator(val rs: ResultSet) : Iterator<ResultRow> {
private var hasNext = false
set(value) {
field = value
if (!field) {
val statement = rs.statement
rs.close()
statement?.close()
transaction.openResultSetsCount--
}
}

private val fieldsIndex = set.realFields.toSet().mapIndexed { index, expression -> expression to index }.toMap()
private inner class ResultIterator(
rs: ResultApi
) : StatementIterator<ResultApi, Expression<*>, ResultRow>(rs) {
override val fieldIndex = set.realFields.toSet().mapIndexed { index, expression ->
expression to index
}.toMap()

init {
hasNext = rs.next()
hasNext = result.next()
if (hasNext) trackResultSet(transaction)
}

override operator fun next(): ResultRow {
if (!hasNext) throw NoSuchElementException()
val result = ResultRow.create(rs, fieldsIndex)
hasNext = rs.next()
return result
}

override fun hasNext(): Boolean = hasNext
override fun createResultRow(): ResultRow = ResultRow.create(result, fieldIndex)
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import org.jetbrains.exposed.dao.id.EntityIDFunctionProvider
import org.jetbrains.exposed.dao.id.IdTable
import org.jetbrains.exposed.sql.statements.api.ExposedBlob
import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi
import org.jetbrains.exposed.sql.statements.api.ResultApi
import org.jetbrains.exposed.sql.vendors.*
import java.io.InputStream
import java.math.BigDecimal
Expand All @@ -14,7 +15,6 @@ import java.math.RoundingMode
import java.nio.ByteBuffer
import java.sql.Blob
import java.sql.Clob
import java.sql.ResultSet
import java.sql.SQLException
import java.util.*
import kotlin.reflect.KClass
Expand Down Expand Up @@ -82,7 +82,7 @@ interface IColumnType<T> {
fun nonNullValueAsDefaultString(value: T & Any): String = nonNullValueToString(value)

/** Returns the object at the specified [index] in the [rs]. */
fun readObject(rs: ResultSet, index: Int): Any? = rs.getObject(index)
fun readObject(rs: ResultApi, index: Int): Any? = rs.getObject(index)

/** Sets the [value] at the specified [index] into the [stmt]. */
fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
Expand Down Expand Up @@ -259,7 +259,7 @@ class EntityIDColumnType<T : Comparable<T>>(
idColumn.table as IdTable<T>
)

override fun readObject(rs: ResultSet, index: Int): Any? = idColumn.columnType.readObject(rs, index)
override fun readObject(rs: ResultApi, index: Int): Any? = idColumn.columnType.readObject(rs, index)

override fun equals(other: Any?): Boolean {
if (other !is EntityIDColumnType<*>) return false
Expand Down Expand Up @@ -684,10 +684,6 @@ class DecimalColumnType(
) : ColumnType<BigDecimal>() {
override fun sqlType(): String = "DECIMAL($precision, $scale)"

override fun readObject(rs: ResultSet, index: Int): Any? {
return rs.getObject(index)
}

override fun valueFromDB(value: Any): BigDecimal = when (value) {
is BigDecimal -> value
is Double -> {
Expand Down Expand Up @@ -914,7 +910,7 @@ open class TextColumnType(
}
}

override fun readObject(rs: ResultSet, index: Int): Any? {
override fun readObject(rs: ResultApi, index: Int): Any? {
val value = super.readObject(rs, index)
return if (eagerLoading && value != null) {
valueFromDB(value)
Expand Down Expand Up @@ -946,12 +942,11 @@ open class LargeTextColumnType(
open class BasicBinaryColumnType : ColumnType<ByteArray>() {
override fun sqlType(): String = currentDialect.dataTypeProvider.binaryType()

override fun readObject(rs: ResultSet, index: Int): Any? = rs.getBytes(index)

override fun valueFromDB(value: Any): ByteArray = when (value) {
is Blob -> value.binaryStream.use { it.readBytes() }
is InputStream -> value.use { it.readBytes() }
is ByteArray -> value
is String -> value.toByteArray()
else -> error("Unexpected value $value of type ${value::class.qualifiedName}")
}

Expand Down Expand Up @@ -1016,10 +1011,11 @@ class BlobColumnType(

override fun nonNullValueToString(value: ExposedBlob): String = currentDialect.dataTypeProvider.hexToDb(value.hexString())

override fun readObject(rs: ResultSet, index: Int) = when {
currentDialect is SQLServerDialect -> rs.getBytes(index)?.let(::ExposedBlob)
currentDialect is PostgreSQLDialect && useObjectIdentifier -> rs.getBlob(index)?.binaryStream?.let(::ExposedBlob)
else -> rs.getBinaryStream(index)?.let(::ExposedBlob)
override fun readObject(rs: ResultApi, index: Int) = when {
currentDialect is PostgreSQLDialect && useObjectIdentifier -> {
rs.getObject(index, java.sql.Blob::class.java)?.binaryStream?.let(::ExposedBlob)
}
else -> rs.getObject(index)
}

override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
Expand Down Expand Up @@ -1049,11 +1045,6 @@ class UUIDColumnType : ColumnType<UUID>() {

override fun nonNullValueToString(value: UUID): String = "'$value'"

override fun readObject(rs: ResultSet, index: Int): Any? = when (currentDialect) {
is MariaDBDialect -> rs.getBytes(index)
else -> super.readObject(rs, index)
}

companion object {
private val uuidRegexp =
Regex("[0-9A-F]{8}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{4}-[0-9A-F]{12}", RegexOption.IGNORE_CASE)
Expand Down Expand Up @@ -1273,7 +1264,7 @@ class ArrayColumnType<E>(
return value.joinToString(",", prefix, "]") { delegate.valueAsDefaultString(it) }
}

override fun readObject(rs: ResultSet, index: Int): Any? = rs.getArray(index)
override fun readObject(rs: ResultApi, index: Int): Any? = rs.getObject(index, java.sql.Array::class.java)

override fun setParameter(stmt: PreparedStatementApi, index: Int, value: Any?) {
when {
Expand Down
78 changes: 14 additions & 64 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Database.kt
Original file line number Diff line number Diff line change
@@ -1,95 +1,46 @@
package org.jetbrains.exposed.sql

import org.jetbrains.annotations.TestOnly
import org.jetbrains.exposed.sql.statements.api.DatabaseApi
import org.jetbrains.exposed.sql.statements.api.ExposedConnection
import org.jetbrains.exposed.sql.statements.api.ExposedDatabaseMetadata
import org.jetbrains.exposed.sql.transactions.ThreadLocalTransactionManager
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.jetbrains.exposed.sql.vendors.*
import java.math.BigDecimal
import java.sql.Connection
import java.sql.DriverManager
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import javax.sql.ConnectionPoolDataSource
import javax.sql.DataSource

/**
* Class representing the underlying database to which connections are made and on which transaction tasks are performed.
* Class representing the underlying JDBC database to which connections are made and on which transaction tasks are performed.
*/
class Database private constructor(
private val resolvedVendor: String? = null,
val config: DatabaseConfig,
val connector: () -> ExposedConnection<*>
) {
/** Whether nested transaction blocks are configured to act like top-level transactions. */
var useNestedTransactions: Boolean = config.useNestedTransactions
@Deprecated("Use DatabaseConfig to define the useNestedTransactions", level = DeprecationLevel.ERROR)
@TestOnly
set

config: DatabaseConfig,
connector: () -> ExposedConnection<*>
) : DatabaseApi(config, connector) {
override fun toString(): String =
"ExposedDatabase[${hashCode()}]($resolvedVendor${config.explicitDialect?.let { ", dialect=$it" } ?: ""})"

internal fun <T> metadata(body: ExposedDatabaseMetadata.() -> T): T {
val transaction = TransactionManager.currentOrNull()
return if (transaction == null) {
val connection = connector()
try {
connection.metadata(body)
} finally {
connection.close()
}
} else {
transaction.connection.metadata(body)
}
}
override val url: String by lazy { metadata { url } }

/** The connection URL for the database. */
val url: String by lazy { metadata { url } }

/** The name of the database based on the name of the underlying JDBC driver. */
val vendor: String by lazy {
override val vendor: String by lazy {
resolvedVendor ?: metadata { databaseDialectName }
}

/** The name of the database as a [DatabaseDialect]. */
val dialect by lazy {
config.explicitDialect ?: dialects[vendor.lowercase()]?.invoke() ?: error("No dialect registered for $name. URL=$url")
}

/** The version number of the database as a [BigDecimal]. */
val version by lazy { metadata { version } }
override val version by lazy { metadata { version } }

/** Whether the version number of the database is equal to or greater than the provided [version]. */
fun isVersionCovers(version: BigDecimal) = this.version >= version

/** Whether the database supports ALTER TABLE with an add column clause. */
val supportsAlterTableWithAddColumn by lazy(
override val supportsAlterTableWithAddColumn by lazy(
LazyThreadSafetyMode.NONE
) { metadata { supportsAlterTableWithAddColumn } }

/** Whether the database supports ALTER TABLE with a drop column clause. */
val supportsAlterTableWithDropColumn by lazy(
override val supportsAlterTableWithDropColumn by lazy(
LazyThreadSafetyMode.NONE
) { metadata { supportsAlterTableWithDropColumn } }

/** Whether the database supports getting multiple result sets from a single execute. */
val supportsMultipleResultSets by lazy(LazyThreadSafetyMode.NONE) { metadata { supportsMultipleResultSets } }

/** The database-specific class responsible for parsing and processing identifier tokens in SQL syntax. */
val identifierManager by lazy { metadata { identifierManager } }

/** The default number of results that should be fetched when queries are executed. */
var defaultFetchSize: Int? = config.defaultFetchSize
private set
override val supportsMultipleResultSets by lazy(LazyThreadSafetyMode.NONE) { metadata { supportsMultipleResultSets } }

@Deprecated("Use DatabaseConfig to define the defaultFetchSize", level = DeprecationLevel.ERROR)
@TestOnly
fun defaultFetchSize(size: Int): Database {
defaultFetchSize = size
return this
}
override val identifierManager by lazy { metadata { identifierManager } }

/** Whether [Database.connect] was invoked with a [DataSource] argument. */
internal var connectsViaDataSource = false
Expand All @@ -110,8 +61,6 @@ class Database private constructor(
internal var dataSourceReadOnly: Boolean = false

companion object {
internal val dialects = ConcurrentHashMap<String, () -> DatabaseDialect>()

private val connectionInstanceImpl: DatabaseConnectionAutoRegistration =
ServiceLoader.load(DatabaseConnectionAutoRegistration::class.java, Database::class.java.classLoader).firstOrNull()
?: error("Can't load implementation for ${DatabaseConnectionAutoRegistration::class.simpleName}")
Expand Down Expand Up @@ -273,6 +222,7 @@ class Database private constructor(
* @param setupConnection Any setup that should be applied to each new connection.
* @param databaseConfig Configuration parameters for this [Database] instance.
* @param manager The [TransactionManager] responsible for new transactions that use this [Database] instance.
* @throws IllegalStateException If a corresponding database dialect cannot be resolved from the provided [url].
*/
fun connect(
url: String,
Expand Down Expand Up @@ -319,5 +269,5 @@ class Database private constructor(
interface DatabaseConnectionAutoRegistration : (Connection) -> ExposedConnection<*>

/** Returns the name of the database obtained from its connection URL. */
val Database.name: String
val DatabaseApi.name: String
get() = url.substringBefore('?').substringAfterLast('/')
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package org.jetbrains.exposed.sql

import org.jetbrains.exposed.sql.statements.Statement
import org.jetbrains.exposed.sql.statements.StatementIterator
import org.jetbrains.exposed.sql.statements.StatementType
import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi
import org.jetbrains.exposed.sql.statements.api.ResultApi
import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.sql.ResultSet

/**
* Represents the SQL query that obtains information about a statement execution plan.
Expand All @@ -16,11 +17,11 @@ open class ExplainQuery(
val analyze: Boolean,
val options: String?,
private val internalStatement: Statement<*>
) : Iterable<ExplainResultRow>, Statement<ResultSet>(StatementType.SHOW, emptyList()) {
) : Iterable<ExplainResultRow>, Statement<ResultApi>(StatementType.SHOW, emptyList()) {
private val transaction
get() = TransactionManager.current()

override fun PreparedStatementApi.executeInternal(transaction: Transaction): ResultSet = executeQuery()
override suspend fun PreparedStatementApi.executeInternal(transaction: Transaction): ResultApi = executeQuery()

override fun arguments(): Iterable<Iterable<Pair<IColumnType<*>, Any?>>> = internalStatement.arguments()

Expand All @@ -34,34 +35,18 @@ open class ExplainQuery(
return Iterable { resultIterator }.iterator()
}

private inner class ResultIterator(private val rs: ResultSet) : Iterator<ExplainResultRow> {
private val fieldIndex: Map<String, Int> = List(rs.metaData.columnCount) { i ->
rs.metaData.getColumnName(i + 1) to i
private inner class ResultIterator(
rs: ResultApi
) : StatementIterator<ResultApi, String, ExplainResultRow>(rs) {
override val fieldIndex = List(result.metadataColumnCount()) { i ->
result.metadataColumnName(i + 1) to i
}.toMap()

private var hasNext = false
set(value) {
field = value
if (!field) {
val statement = rs.statement
rs.close()
statement?.close()
transaction.openResultSetsCount--
}
}

init {
hasNext = rs.next()
hasNext = result.next()
}

override fun hasNext(): Boolean = hasNext

override operator fun next(): ExplainResultRow {
if (!hasNext) throw NoSuchElementException()
val result = ExplainResultRow.create(rs, fieldIndex)
hasNext = rs.next()
return result
}
override fun createResultRow(): ExplainResultRow = ExplainResultRow.create(result, fieldIndex)
}
}

Expand All @@ -77,8 +62,8 @@ class ExplainResultRow(
override fun toString(): String = fieldIndex.entries.joinToString { "${it.key}=${data[it.value]}" }

companion object {
/** Creates an [ExplainResultRow] storing all fields in [fieldIndex] with their values retrieved from a [ResultSet]. */
fun create(rs: ResultSet, fieldIndex: Map<String, Int>): ExplainResultRow {
/** Creates an [ExplainResultRow] storing all fields in [fieldIndex] with their values retrieved from a [ResultApi] object. */
fun create(rs: ResultApi, fieldIndex: Map<String, Int>): ExplainResultRow {
val fieldValues = arrayOfNulls<Any?>(fieldIndex.size)
fieldIndex.values.forEach { index ->
fieldValues[index] = rs.getObject(index + 1)
Expand Down
Loading
Loading