[SEDONA-327] Refactored raster UDFs to extend InferredExpression (#909)
Kontinuation committed Jul 20, 2023
1 parent b9b26cc commit 51c869f
Showing 13 changed files with 98 additions and 359 deletions.
@@ -93,15 +93,9 @@ public static GridCoverage2D makeEmptyRaster(int numBand, int widthInPixel, int
} else {
crs = CRS.decode("EPSG:" + srid);
}
// If scaleY is not defined, use scaleX
// MAX_VALUE is used to indicate that the scaleY is not defined
double actualScaleY = scaleY;
if (scaleY == Integer.MAX_VALUE) {
actualScaleY = scaleX;
}
// Create a new empty raster
WritableRaster raster = RasterFactory.createBandedRaster(DataBuffer.TYPE_DOUBLE, widthInPixel, heightInPixel, numBand, null);
MathTransform transform = new AffineTransform2D(scaleX, skewY, skewX, -actualScaleY, upperLeftX + scaleX / 2, upperLeftY - actualScaleY / 2);
MathTransform transform = new AffineTransform2D(scaleX, skewY, skewX, -scaleY, upperLeftX + scaleX / 2, upperLeftY - scaleY / 2);
GridGeometry2D gridGeometry = new GridGeometry2D(new GridEnvelope2D(0, 0, widthInPixel, heightInPixel), transform, crs);
ReferencedEnvelope referencedEnvelope = new ReferencedEnvelope(gridGeometry.getEnvelope2D());
// Create a new coverage
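For context, the removed comment explained the old Integer.MAX_VALUE sentinel for an undefined scaleY; with that handling lifted out of this method, the transform now uses scaleY directly. A minimal Scala sketch of the grid-to-world mapping the transform encodes, assuming the usual AffineTransform2D(m00, m10, m01, m11, m02, m12) argument order and cell-center anchoring (the method name is illustrative, not part of this commit):

// Illustrative sketch: maps pixel (col, row), anchored at cell centers, to world
// coordinates; (upperLeftX, upperLeftY) is the corner of the raster extent.
def gridToWorld(col: Int, row: Int,
                scaleX: Double, scaleY: Double,
                skewX: Double, skewY: Double,
                upperLeftX: Double, upperLeftY: Double): (Double, Double) = {
  val worldX = upperLeftX + scaleX / 2 + col * scaleX + row * skewX
  val worldY = upperLeftY - scaleY / 2 + col * skewY - row * scaleY
  (worldX, worldY)
}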
@@ -35,7 +35,7 @@

public class RasterOutputs
{
public static byte[] asGeoTiff(GridCoverage2D raster, String compressionType, float compressionQuality) {
public static byte[] asGeoTiff(GridCoverage2D raster, String compressionType, double compressionQuality) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
GridCoverageWriter writer;
try {
@@ -52,7 +52,7 @@ public static byte[] asGeoTiff(GridCoverage2D raster, String compressionType, fl
params.setCompressionType(compressionType);
// Should be a value between 0 and 1
// 0 means max compression, 1 means no compression
params.setCompressionQuality(compressionQuality);
params.setCompressionQuality((float) compressionQuality);
defaultParams.parameter(AbstractGridFormat.GEOTOOLS_WRITE_PARAMS.getName().toString()).setValue(params);
}
GeneralParameterValue[] wps = defaultParams.values().toArray(new GeneralParameterValue[0]);
@@ -67,6 +67,10 @@ public static byte[] asGeoTiff(GridCoverage2D raster, String compressionType, fl
return out.toByteArray();
}

public static byte[] asGeoTiff(GridCoverage2D raster) {
return asGeoTiff(raster, null, -1);
}

public static byte[] asArcGrid(GridCoverage2D raster, int sourceBand) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
GridCoverageWriter writer;
@@ -93,4 +97,8 @@ public static byte[] asArcGrid(GridCoverage2D raster, int sourceBand) {
}
return out.toByteArray();
}

public static byte[] asArcGrid(GridCoverage2D raster) {
return asArcGrid(raster, -1);
}
}
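The new zero-argument overloads simply delegate with sentinel defaults: a null compression type and -1 for the compression quality / source band. A minimal usage sketch in Scala, assuming RasterOutputs lives in org.apache.sedona.common.raster alongside the other common raster classes and that a GridCoverage2D is obtained elsewhere:

import org.apache.sedona.common.raster.RasterOutputs
import org.geotools.coverage.grid.GridCoverage2D

// Sketch only: `raster` can come from any raster constructor or reader.
def encode(raster: GridCoverage2D): (Array[Byte], Array[Byte]) = {
  val geoTiff = RasterOutputs.asGeoTiff(raster) // same as asGeoTiff(raster, null, -1)
  val arcGrid = RasterOutputs.asArcGrid(raster) // same as asArcGrid(raster, -1)
  (geoTiff, arcGrid)
}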
4 changes: 4 additions & 0 deletions python-adapter/pom.xml
@@ -111,6 +111,10 @@
<groupId>org.geotools</groupId>
<artifactId>gt-epsg-hsql</artifactId>
</dependency>
<dependency>
<groupId>org.geotools</groupId>
<artifactId>gt-coverage</artifactId>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
@@ -192,7 +192,7 @@ object Catalog {
function[RS_BandAsArray](),
function[RS_FromArcInfoAsciiGrid](),
function[RS_FromGeoTiff](),
function[RS_MakeEmptyRaster](java.lang.Integer.MAX_VALUE, 0.0, 0.0, 0),
function[RS_MakeEmptyRaster](),
function[RS_Envelope](),
function[RS_NumBands](),
function[RS_Metadata](),
@@ -22,11 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits._
import org.geotools.coverage.grid.GridCoverage2D

import scala.reflect.runtime.universe.TypeTag
import scala.reflect.runtime.universe.Type
@@ -76,6 +78,8 @@ sealed class InferrableType[T: TypeTag]
object InferrableType {
implicit val geometryInstance: InferrableType[Geometry] =
new InferrableType[Geometry] {}
implicit val gridCoverage2DInstance: InferrableType[GridCoverage2D] =
new InferrableType[GridCoverage2D] {}
implicit val geometryArrayInstance: InferrableType[Array[Geometry]] =
new InferrableType[Array[Geometry]] {}
implicit val javaDoubleInstance: InferrableType[java.lang.Double] =
@@ -96,6 +100,8 @@ object InferrableType {
new InferrableType[Array[Byte]] {}
implicit val longArrayInstance: InferrableType[Array[java.lang.Long]] =
new InferrableType[Array[java.lang.Long]] {}
implicit val doubleArrayInstance: InferrableType[Array[Double]] =
new InferrableType[Array[Double]] {}
}

object InferredTypes {
@@ -104,6 +110,10 @@
expr => input => expr.toGeometry(input)
} else if (t =:= typeOf[Array[Geometry]]) {
expr => input => expr.toGeometryArray(input)
} else if (t =:= typeOf[GridCoverage2D]) {
expr => input => expr.toRaster(input)
} else if (t =:= typeOf[Array[Double]]) {
expr => input => expr.eval(input).asInstanceOf[ArrayData].toDoubleArray()
} else if (t =:= typeOf[String]) {
expr => input => expr.asString(input)
} else {
@@ -119,14 +129,22 @@
} else {
null
}
} else if (t =:= typeOf[GridCoverage2D]) {
output => {
if (output != null) {
output.asInstanceOf[GridCoverage2D].serialize
} else {
null
}
}
} else if (t =:= typeOf[String]) {
output =>
if (output != null) {
UTF8String.fromString(output.asInstanceOf[String])
} else {
null
}
} else if (t =:= typeOf[Array[java.lang.Long]]) {
} else if (t =:= typeOf[Array[java.lang.Long]] || t =:= typeOf[Array[Double]]) {
output =>
if (output != null) {
ArrayData.toArrayData(output)
@@ -157,6 +175,8 @@ object InferredTypes {
GeometryUDT
} else if (t =:= typeOf[Array[Geometry]]) {
DataTypes.createArrayType(GeometryUDT)
} else if (t =:= typeOf[GridCoverage2D]) {
RasterUDT
} else if (t =:= typeOf[java.lang.Double]) {
DoubleType
} else if (t =:= typeOf[java.lang.Integer]) {
@@ -171,6 +191,8 @@
BinaryType
} else if (t =:= typeOf[Array[java.lang.Long]]) {
DataTypes.createArrayType(LongType)
} else if (t =:= typeOf[Array[Double]]) {
DataTypes.createArrayType(DoubleType)
} else if (t =:= typeOf[Option[Boolean]]) {
BooleanType
} else {
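Taken together, the additions above give InferredExpression end-to-end support for GridCoverage2D arguments and results (extracted via toRaster, serialized through the raster implicits, mapped to RasterUDT) and for Array[Double] results (mapped to ArrayType(DoubleType)). A hypothetical sketch of what this enables; RS_ExampleMinMax and its helper are not part of this commit:

import org.apache.sedona.common.raster.MapAlgebra
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.geotools.coverage.grid.GridCoverage2D

// Hypothetical helper, not in this commit: min and max of one band.
object ExampleFunctions {
  def bandMinMax(raster: GridCoverage2D, band: Int): Array[Double] = {
    val values = MapAlgebra.bandAsArray(raster, band)
    if (values == null || values.isEmpty) null else Array(values.min, values.max)
  }
}

// Hypothetical UDF: argument extraction (RasterUDT -> GridCoverage2D), the result
// type (ArrayType(DoubleType)), and null handling are all inferred from the signature.
case class RS_ExampleMinMax(inputExpressions: Seq[Expression])
  extends InferredExpression(ExampleFunctions.bandMinMax _) {
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
    copy(inputExpressions = newChildren)
}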
@@ -18,17 +18,14 @@
*/
package org.apache.spark.sql.sedona_sql.expressions.raster

import org.apache.sedona.common.raster.{MapAlgebra, Serde}
import org.apache.sedona.common.raster.MapAlgebra
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
import org.apache.spark.sql.sedona_sql.expressions.{SerdeAware, UserDataGeneratator}
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.{InferredExpression, UserDataGeneratator}
import org.apache.spark.sql.types._
import org.geotools.coverage.grid.GridCoverage2D

/// Calculate Normalized Difference between two bands
case class RS_NormalizedDifference(inputExpressions: Seq[Expression])
@@ -807,61 +804,15 @@ case class RS_Append(inputExpressions: Seq[Expression])
}
}

case class RS_AddBandFromArray(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback
with ExpectsInputTypes with SerdeAware {

override def nullable: Boolean = true

override def eval(input: InternalRow): Any = {
Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
}

override def dataType: DataType = RasterUDT

override def children: Seq[Expression] = inputExpressions

case class RS_AddBandFromArray(inputExpressions: Seq[Expression])
extends InferredExpression(inferrableFunction3(MapAlgebra.addBandFromArray)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}

override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, ArrayType, IntegerType)

override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
val raster = inputExpressions(0).toRaster(input)
if (raster == null) {
return null
}
val band = inputExpressions(1).eval(input).asInstanceOf[ArrayData].toDoubleArray()
val bandIndex = inputExpressions(2).eval(input).asInstanceOf[Int]
MapAlgebra.addBandFromArray(raster, band, bandIndex)
}
}

case class RS_BandAsArray(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback
with ExpectsInputTypes {

override def nullable: Boolean = true

override def eval(input: InternalRow): Any = {
val raster = inputExpressions(0).toRaster(input)
if (raster == null) {
return null
}
val bandIndex = inputExpressions(1).eval(input).asInstanceOf[Int]
val band = MapAlgebra.bandAsArray(raster, bandIndex)
if (band == null) {
return null
}
new GenericArrayData(band)
}

override def dataType: DataType = ArrayType(DoubleType)

override def children: Seq[Expression] = inputExpressions

case class RS_BandAsArray(inputExpressions: Seq[Expression]) extends InferredExpression(MapAlgebra.bandAsArray _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}

override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType)
}
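At the SQL level the refactored expressions behave as before. A usage sketch, assuming a SparkSession with the Sedona catalog registered, a table named rasters with a binary content column holding GeoTIFF bytes, and 1-based band indexes (table and column names are illustrative):

import org.apache.spark.sql.{DataFrame, SparkSession}

// Illustrative only: round-trips band 1 through the refactored
// RS_BandAsArray / RS_AddBandFromArray expressions.
def roundTripBand(spark: SparkSession): DataFrame =
  spark.sql(
    """SELECT RS_AddBandFromArray(raster, RS_BandAsArray(raster, 1), 1) AS raster
      |FROM (SELECT RS_FromGeoTiff(content) AS raster FROM rasters) t""".stripMargin)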
@@ -25,34 +25,15 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
import org.apache.spark.sql.sedona_sql.expressions.implicits.InputExpressionEnhancer
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, DoubleType, IntegerType}
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression

case class RS_Value(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes {

override def nullable: Boolean = true

override def dataType: DataType = DoubleType

override def eval(input: InternalRow): Any = {
val raster = inputExpressions.head.toRaster(input)
val geom = inputExpressions(1).toGeometry(input)
val band = inputExpressions(2).eval(input).asInstanceOf[Int]
if (raster == null || geom == null) {
null
} else {
PixelFunctions.value(raster, geom, band)
}
}

override def children: Seq[Expression] = inputExpressions

case class RS_Value(inputExpressions: Seq[Expression]) extends InferredExpression(PixelFunctions.value _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}

override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, GeometryUDT, IntegerType)
}

case class RS_Values(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes {
@@ -82,4 +63,4 @@ case class RS_Values(inputExpressions: Seq[Expression]) extends Expression with
}

override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, ArrayType(GeometryUDT), IntegerType)
}
}