Skip to content

Commit

Permalink
Fetch upstream (#1007)
Browse files Browse the repository at this point in the history
Co-authored-by: Furqaanahmed Khan <[email protected]>
  • Loading branch information
jiayuasu and furqaankhan committed Sep 7, 2023
1 parent f5b78be commit ecc2551
Show file tree
Hide file tree
Showing 12 changed files with 196 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@
import org.geotools.coverage.GridSampleDimension;
import org.geotools.coverage.grid.GridCoordinates2D;
import org.geotools.coverage.grid.GridCoverage2D;
import org.opengis.referencing.FactoryException;

import java.awt.image.Raster;
import java.awt.image.WritableRaster;
import java.util.Arrays;
import java.util.HashMap;

public class RasterBandAccessors {

Expand Down Expand Up @@ -134,6 +138,39 @@ public static double[] getSummaryStats(GridCoverage2D raster) {
// return getSummaryStats(raster, 1, excludeNoDataValue);
// }

public static GridCoverage2D getBand(GridCoverage2D rasterGeom, int[] bandIndexes) throws FactoryException {
Double noDataValue;
double[] metadata = RasterAccessors.metadata(rasterGeom);
int width = (int) metadata[2], height = (int) metadata[3];
GridCoverage2D resultRaster = RasterConstructors.makeEmptyRaster(bandIndexes.length, width, height,
metadata[0], metadata[1], metadata[4], metadata[5], metadata[6], metadata[7], (int) metadata[8]);

// Get band data that's required
int[] bandsDistinct = Arrays.stream(bandIndexes).distinct().toArray();
HashMap<Integer, double[]> bandData = new HashMap<>();
for (int curBand: bandsDistinct) {
RasterUtils.ensureBand(rasterGeom, curBand);
bandData.put(curBand - 1, MapAlgebra.bandAsArray(rasterGeom, curBand));
}

// Get Writable Raster from the resultRaster
WritableRaster wr = resultRaster.getRenderedImage().getData().createCompatibleWritableRaster();

GridSampleDimension[] sampleDimensionsOg = rasterGeom.getSampleDimensions();
GridSampleDimension[] sampleDimensionsResult = resultRaster.getSampleDimensions();
for (int i = 0; i < bandIndexes.length; i ++) {
sampleDimensionsResult[i] = sampleDimensionsOg[bandIndexes[i] - 1];
wr.setSamples(0, 0, width, height, i, bandData.get(bandIndexes[i] - 1));
noDataValue = RasterBandAccessors.getBandNoDataValue(rasterGeom, bandIndexes[i]);
GridSampleDimension sampleDimension = sampleDimensionsResult[i];
if (noDataValue != null) {
sampleDimensionsResult[i] = RasterUtils.createSampleDimensionWithNoDataValue(sampleDimension, noDataValue);
}
}

return RasterUtils.create(wr, resultRaster.getGridGeometry(), sampleDimensionsResult);
}

public static String getBandType(GridCoverage2D raster, int band) {
RasterUtils.ensureBand(raster, band);
GridSampleDimension bandSampleDimension = raster.getSampleDimension(band - 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opengis.referencing.FactoryException;

import java.io.IOException;
import java.util.Arrays;

import static org.junit.Assert.*;

Expand Down Expand Up @@ -246,6 +247,42 @@ public void testCountWithRaster() throws IOException {

}

@Test
public void testGetBand() throws FactoryException {
GridCoverage2D emptyRaster = RasterConstructors.makeEmptyRaster( 4, 5, 5, 3, -215, 2, -2, 2, 2, 0);
double[] values1 = new double[] {16, 0, 24, 33, 43, 49, 64, 0, 76, 77, 79, 89, 0, 116, 118, 125, 135, 0, 157, 190, 215, 229, 241, 248, 249};
emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values1, 3, 0d);
GridCoverage2D resultRaster = RasterBandAccessors.getBand(emptyRaster, new int[]{3,3,3});
int actual = RasterAccessors.numBands(resultRaster);
int expected = 3;
assertEquals(expected, actual);

double[] actualMetadata = Arrays.stream(RasterAccessors.metadata(resultRaster), 0, 9).toArray();
double[] expectedMetadata = Arrays.stream(RasterAccessors.metadata(emptyRaster), 0, 9).toArray();
assertArrayEquals(expectedMetadata, actualMetadata, 0.1d);

double[] actualBandValues = MapAlgebra.bandAsArray(resultRaster, 3);
double[] expectedBandValues = MapAlgebra.bandAsArray(emptyRaster, 3);
assertArrayEquals(expectedBandValues, actualBandValues, 0.1d);
}

@Test
public void testGetBandWithRaster() throws IOException, FactoryException {
GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif");
GridCoverage2D resultRaster = RasterBandAccessors.getBand(raster, new int[] {1,2,2,2,1});
int actual = RasterAccessors.numBands(resultRaster);
int expected = 5;
assertEquals(actual, expected);

double[] actualMetadata = Arrays.stream(RasterAccessors.metadata(resultRaster), 0, 9).toArray();
double[] expectedMetadata = Arrays.stream(RasterAccessors.metadata(raster), 0, 9).toArray();
assertArrayEquals(expectedMetadata, actualMetadata, 0.1d);

double[] actualBandValues = MapAlgebra.bandAsArray(raster, 2);
double[] expectedBandValues = MapAlgebra.bandAsArray(resultRaster, 2);
assertArrayEquals(expectedBandValues, actualBandValues, 0.1d);
}

@Test
public void testBandPixelType() throws FactoryException {
double[] values = new double[]{1.2, 1.1, 32.2, 43.2};
Expand Down
31 changes: 31 additions & 0 deletions docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,37 @@ Output: `3`

## Raster Band Accessors

### RS_Band

Introduction: Returns a new raster consisting 1 or more bands of an existing raster. It can build new rasters from
existing ones, export only selected bands from a multiband raster, or rearrange the order of bands in a raster dataset.

Format:

`RS_Band(raster: Raster, bands: ARRAY[Integer])`

Since: `v1.5.0`

Spark SQL Example:

```sql
SELECT RS_NumBands(
RS_Band(
RS_AddBandFromArray(
RS_MakeEmptyRaster(2, 5, 5, 3, -215, 2, -2, 2, 2, 0),
Array(16, 0, 24, 33, 43, 49, 64, 0, 76, 77, 79, 89, 0, 116, 118, 125, 135, 0, 157, 190, 215, 229, 241, 248, 249),
1, 0d
), Array(1,1,1)
)
)
```

Output:

```
3
```

### RS_BandNoDataValue

Introduction: Returns the no data value of the given band of the given raster. If no band is given, band 1 is assumed. The band parameter is 1-indexed. If there is no no data value associated with the given band, RS_BandNoDataValue returns null.
Expand Down
11 changes: 11 additions & 0 deletions examples/flink-sql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@
</dependency>
</dependencies>
<repositories>
<repository>
<id>maven-central</id>
<name>Maven Central Repository</name>
<url>https://repo.maven.apache.org/maven2/</url>
<snapshots>
<enabled>false</enabled>
</snapshots>
<releases>
<enabled>true</enabled>
</releases>
</repository>
<repository>
<id>maven2-repository.dev.java.net</id>
<name>Java.net repository</name>
Expand Down
1 change: 1 addition & 0 deletions examples/spark-rdd-colocation-mining/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ assemblyMergeStrategy in assembly := {
}

resolvers ++= Seq(
"Maven Central" at "https://repo.maven.apache.org/maven2/",
"Open Source Geospatial Foundation Repository" at "https://repo.osgeo.org/repository/release/",
"Apache Software Foundation Snapshots" at "https://repository.apache.org/content/groups/snapshots",
"Java.net repository" at "https://download.java.net/maven/2"
Expand Down
1 change: 1 addition & 0 deletions examples/spark-sql/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ assemblyMergeStrategy in assembly := {
}

resolvers ++= Seq(
"Maven Central" at "https://repo.maven.apache.org/maven2/",
"Open Source Geospatial Foundation Repository" at "https://repo.osgeo.org/repository/release/",
"Apache Software Foundation Snapshots" at "https://repository.apache.org/content/groups/snapshots",
"Java.net repository" at "https://download.java.net/maven/2"
Expand Down
1 change: 1 addition & 0 deletions examples/spark-viz/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ assemblyMergeStrategy in assembly := {
}

resolvers ++= Seq(
"Maven Central" at "https://repo.maven.apache.org/maven2/",
"Open Source Geospatial Foundation Repository" at "https://repo.osgeo.org/repository/release/",
"Apache Software Foundation Snapshots" at "https://repository.apache.org/content/groups/snapshots",
"Java.net repository" at "https://download.java.net/maven/2"
Expand Down
11 changes: 11 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,17 @@
</dependencies>
</dependencyManagement>
<repositories>
<repository>
<id>maven-central</id>
<name>Maven Central Repository</name>
<url>https://repo.maven.apache.org/maven2/</url>
<snapshots>
<enabled>false</enabled>
</snapshots>
<releases>
<enabled>true</enabled>
</releases>
</repository>
<repository>
<id>osgeo</id>
<name>OSGeo Release Repository</name>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ object Catalog {
function[RS_PixelAsPolygon](),
function[RS_PixelAsCentroid](),
function[RS_Count](),
function[RS_Band](),
function[RS_SummaryStats](),
function[RS_ConvexHull](),
function[RS_RasterToWorldCoordX](),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ object InferrableType {
new InferrableType[Array[Byte]] {}
implicit val longArrayInstance: InferrableType[Array[java.lang.Long]] =
new InferrableType[Array[java.lang.Long]] {}
implicit val intArrayInstance: InferrableType[Array[Int]] =
new InferrableType[Array[Int]] {}
implicit val javaIntArrayInstance: InferrableType[Array[java.lang.Integer]] =
new InferrableType[Array[java.lang.Integer]]
implicit val doubleArrayInstance: InferrableType[Array[Double]] =
new InferrableType[Array[Double]] {}
implicit val longInstance: InferrableType[Long] =
Expand Down Expand Up @@ -193,6 +197,10 @@ object InferredTypes {
StringType
} else if (t =:= typeOf[Array[Byte]]) {
BinaryType
} else if (t =:= typeOf[Array[Int]]) {
DataTypes.createArrayType(IntegerType)
} else if (t =:= typeOf[Array[java.lang.Integer]]) {
DataTypes.createArrayType(IntegerType)
} else if (t =:= typeOf[Array[java.lang.Long]]) {
DataTypes.createArrayType(LongType)
} else if (t =:= typeOf[Array[Double]]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@
package org.apache.spark.sql.sedona_sql.expressions.raster

import org.apache.sedona.common.raster.RasterBandAccessors
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.geotools.coverage.grid.GridCoverage2D

case class RS_BandNoDataValue(inputExpressions: Seq[Expression]) extends InferredExpression(inferrableFunction2(RasterBandAccessors.getBandNoDataValue), inferrableFunction1(RasterBandAccessors.getBandNoDataValue)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
Expand All @@ -45,6 +50,28 @@ case class RS_SummaryStats(inputExpressions: Seq[Expression]) extends InferredEx
}
}

case class RS_Band(inputExpressions: Seq[Expression]) extends InferredExpression(RasterBandAccessors.getBand _) {

override def evalWithoutSerialization(input: InternalRow): Any = {
val raster = inputExpressions.head.toRaster(input)
val intArray = inputExpressions(1).eval(input).asInstanceOf[ArrayData]
if (raster == null) {
null
} else {
val values = (0 until intArray.numElements()).map(i => intArray.getInt(i))
RasterBandAccessors.getBand(raster, values.toArray)
}
}

override def eval(input: InternalRow): Any = {
RasterUDT.serialize(evalWithoutSerialization(input).asInstanceOf[GridCoverage2D])
}

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}

case class RS_BandPixelType(inputExpressions: Seq[Expression]) extends InferredExpression(inferrableFunction2(RasterBandAccessors.getBandType), inferrableFunction1(RasterBandAccessors.getBandType)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,36 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assertEquals(expected, actual)
}

it("Passed RS_Band") {
val inputDf = Seq((Seq(16, 0, 24, 33, 43, 49, 64, 0, 76, 77, 79, 89, 0, 116, 118, 125, 135, 0, 157, 190, 215, 229, 241, 248, 249))).toDF("band")
val df = inputDf.selectExpr("RS_AddBandFromArray(RS_MakeEmptyRaster(2, 5, 5, 3, -215, 2, -2, 2, 2, 0), band, 1, 0d) as emptyRaster")
val resultDf = df.selectExpr("RS_Band(emptyRaster, array(1,1,1)) as raster")
val actual = resultDf.selectExpr("RS_NumBands(raster)").first().get(0)
val expected = 3
assertEquals(expected, actual)

val actualMetadata = resultDf.selectExpr("RS_Metadata(raster)").first().getSeq(0).slice(0, 9)
val expectedMetadata = df.selectExpr("RS_Metadata(emptyRaster)").first().getSeq(0).slice(0, 9)
assertEquals(expectedMetadata.toString(), actualMetadata.toString())

val actualBandValues = resultDf.selectExpr("RS_BandAsArray(raster, 1)").first().getSeq(0)
val expectedBandValues = df.selectExpr("RS_BandAsArray(emptyRaster, 1)").first().getSeq(0)
assertEquals(expectedBandValues.toString(), actualBandValues.toString())
}

it("Passed RS_Band with raster") {
val dfFile = sparkSession.read.format("binaryFile").load(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif")
val df = dfFile.selectExpr("RS_FromGeoTiff(content) as raster")
val resultDf = df.selectExpr("RS_Band(raster, array(1,2,2,2,1)) as resultRaster")
val actual = resultDf.selectExpr("RS_NumBands(resultRaster)").first().getInt(0)
val expected = 5
assertEquals(expected, actual)

val actualMetadata = resultDf.selectExpr("RS_Metadata(resultRaster)").first().getSeq(0).slice(0, 9)
val expectedMetadata = df.selectExpr("RS_Metadata(raster)").first().getSeq(0).slice(0, 9)
assertEquals(expectedMetadata.toString(), actualMetadata.toString())
}

it("Passed RS_SetValues with empty raster") {
var inputDf = Seq((Seq(1, 1, 1, 0, 0, 0, 1, 2, 3, 3, 5, 6, 7, 0, 0, 3, 0, 0, 3, 0, 0, 0, 0, 0, 0), Seq(11, 12, 13, 14, 15, 16, 17, 18, 19))).toDF("band","newValues")
var df = inputDf.selectExpr("RS_AddBandFromArray(RS_MakeEmptyRaster(1, 5, 5, 0, 0, 1, -1, 0, 0, 0), band, 1, 0d) as raster", "newValues")
Expand Down

0 comments on commit ecc2551

Please sign in to comment.