Skip to content

Commit

Permalink
[SEDONA-326] Improve raster algebra functions: RS_Array and RS_Multip…
Browse files Browse the repository at this point in the history
…lyFactor (#907)
  • Loading branch information
Kontinuation committed Jul 18, 2023
1 parent d4d0110 commit d1f8b0e
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,6 @@ public static Geometry polygonFromEnvelope(double minX, double minY, double maxX
}

public static Geometry geomFromGeoHash(String geoHash, Integer precision) {
System.out.println(geoHash);
System.out.println(precision);
try {
return GeoHashDecoder.decode(geoHash, precision);
} catch (GeoHashDecoder.InvalidGeoHashException e) {
Expand Down
2 changes: 1 addition & 1 deletion docs/api/sql/Raster-loader.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ Output:

Introduction: Create an array of the given length filled with the given value

Format: `RS_Array(length:Int, value: Decimal)`
Format: `RS_Array(length:Int, value: Double)`

Since: `v1.1.0`

Expand Down
4 changes: 3 additions & 1 deletion docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ val multiplyDF = spark.sql("select RS_Multiply(band1, band2) as multiplyBands fr

Introduction: Multiply a spectral band in a GeoTIFF image by a factor

Format: `RS_MultiplyFactor (Band1: Array[Double], Factor: Int)`
Format: `RS_MultiplyFactor (Band1: Array[Double], Factor: Double)`

Since: `v1.1.0`

Expand All @@ -528,6 +528,8 @@ val multiplyFactorDF = spark.sql("select RS_MultiplyFactor(band1, 2) as multiply

```

Before `v1.5.0`, this function only accepted an integer as the factor.

### RS_Normalize

Introduction: Normalize the values in the array to the range [0, 255]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.expressions.UserDataGeneratator
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -126,14 +127,14 @@ case class RS_GetBand(inputExpressions: Seq[Expression])
}

case class RS_Array(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback with UserDataGeneratator {
extends Expression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
override def nullable: Boolean = false

override def eval(inputRow: InternalRow): Any = {
// This is an expression which takes two input expressions: the array length and the fill value
assert(inputExpressions.length == 2)
val len =inputExpressions(0).eval(inputRow).asInstanceOf[Int]
val num = inputExpressions(1).eval(inputRow).asInstanceOf[Decimal].toDouble
val num = inputExpressions(1).eval(inputRow).asInstanceOf[Double]
val result = createarray(len, num)
new GenericArrayData(result)
}
Expand All @@ -148,6 +149,8 @@ case class RS_Array(inputExpressions: Seq[Expression])
result
}

override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, DoubleType)

override def dataType: DataType = ArrayType(DoubleType)

override def children: Seq[Expression] = inputExpressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.sedona.common.raster.{MapAlgebra, Serde}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
Expand Down Expand Up @@ -352,30 +353,31 @@ case class RS_Count(inputExpressions: Seq[Expression])

// Multiply all values of a band by a factor
case class RS_MultiplyFactor(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback with UserDataGeneratator {
extends Expression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
assert(inputExpressions.length == 2)

override def nullable: Boolean = false

override def eval(inputRow: InternalRow): Any = {
val band = inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData].toDoubleArray()
val target = inputExpressions(1).eval(inputRow).asInstanceOf[Int]
new GenericArrayData(multiply(band, target))
val factor = inputExpressions(1).eval(inputRow).asInstanceOf[Double]
new GenericArrayData(multiply(band, factor))

}

private def multiply(band: Array[Double], target: Int):Array[Double] = {
private def multiply(band: Array[Double], factor: Double):Array[Double] = {

var result = new Array[Double](band.length)
for(i<-0 until band.length) {

result(i) = band(i)*target
result(i) = band(i) * factor

}
result

}

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType(DoubleType), DoubleType)

override def dataType: DataType = ArrayType(DoubleType)

override def children: Seq[Expression] = inputExpressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assert(inputDf.first().getAs[mutable.WrappedArray[Double]](0) == expectedDF.first().getAs[mutable.WrappedArray[Double]](0))
}

it("Passed RS_MultiplyFactor with double factor") {
val inputDf = Seq((Seq(200.0, 400.0, 600.0))).toDF("Band")
val expectedDF = Seq((Seq(20.0, 40.0, 60.0))).toDF("multiply")
val actualDF = inputDf.selectExpr("RS_MultiplyFactor(Band, 0.1) as multiply")
assert(actualDF.first().getAs[mutable.WrappedArray[Double]](0) == expectedDF.first().getAs[mutable.WrappedArray[Double]](0))
}
}

describe("Should pass basic statistical tests") {
Expand Down Expand Up @@ -202,6 +208,13 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
df = df.selectExpr("RS_Normalize(Band) as normalizedBand")
assert(df.first().getAs[mutable.WrappedArray[Double]](0)(1) == 255)
}

it("should pass RS_Array") {
val df = sparkSession.sql("SELECT RS_Array(6, 1e-6) as band")
val result = df.first().getAs[mutable.WrappedArray[Double]](0)
assert(result.length == 6)
assert(result sameElements Array.fill[Double](6)(1e-6))
}
}

describe("Should pass all transformation tests") {
Expand Down

0 comments on commit d1f8b0e

Please sign in to comment.