diff --git a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java index 0a24cce24d..67e8ad189f 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java +++ b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java @@ -18,16 +18,25 @@ */ package org.apache.sedona.common.raster; +import com.sun.media.imageioimpl.common.BogusColorSpace; import org.geotools.coverage.CoverageFactoryFinder; +import org.geotools.coverage.GridSampleDimension; import org.geotools.coverage.grid.GridCoverage2D; import org.geotools.coverage.grid.GridCoverageFactory; import javax.media.jai.RasterFactory; import java.awt.Point; +import java.awt.Transparency; +import java.awt.color.ColorSpace; +import java.awt.image.BufferedImage; +import java.awt.image.ColorModel; +import java.awt.image.ComponentColorModel; +import java.awt.image.DataBuffer; import java.awt.image.Raster; import java.awt.image.RenderedImage; import java.awt.image.WritableRaster; +import java.util.Arrays; public class MapAlgebra { @@ -118,9 +127,14 @@ private static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverag wr.setPixel(i, j, copiedPixels); } } - // Create a new GridCoverage2D with the copied image - GridCoverageFactory gridCoverageFactory = CoverageFactoryFinder.getGridCoverageFactory(null); - return gridCoverageFactory.create(gridCoverage2D.getName(), wr, gridCoverage2D.getEnvelope()); + // Add a sample dimension for newly added band + int numBand = wr.getNumBands(); + GridSampleDimension[] originalSampleDimensions = gridCoverage2D.getSampleDimensions(); + GridSampleDimension[] sampleDimensions = new GridSampleDimension[numBand]; + System.arraycopy(originalSampleDimensions, 0, sampleDimensions, 0, originalSampleDimensions.length); + sampleDimensions[numBand - 1] = new GridSampleDimension("band" + numBand); + // Construct a GridCoverage2D with the copied image. + return createCompatibleGridCoverage2D(gridCoverage2D, wr, sampleDimensions); } private static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, double[] bandValues) { @@ -141,7 +155,20 @@ private static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCovera } } // Create a new GridCoverage2D with the copied image + return createCompatibleGridCoverage2D(gridCoverage2D, wr, gridCoverage2D.getSampleDimensions()); + } + + private static GridCoverage2D createCompatibleGridCoverage2D(GridCoverage2D gridCoverage2D, WritableRaster wr, GridSampleDimension[] bands) { + int rasterDataType = wr.getDataBuffer().getDataType(); + int numBand = wr.getNumBands(); + final ColorSpace cs = new BogusColorSpace(numBand); + final int[] nBits = new int[numBand]; + Arrays.fill(nBits, DataBuffer.getDataTypeSize(rasterDataType)); + ColorModel colorModel = + new ComponentColorModel(cs, nBits, false, true, Transparency.OPAQUE, rasterDataType); + final RenderedImage image = new BufferedImage(colorModel, wr, false, null); GridCoverageFactory gridCoverageFactory = CoverageFactoryFinder.getGridCoverageFactory(null); - return gridCoverageFactory.create(gridCoverage2D.getName(), wr, gridCoverage2D.getEnvelope()); + return gridCoverageFactory.create(gridCoverage2D.getName(), image, gridCoverage2D.getCoordinateReferenceSystem(), + gridCoverage2D.getGridGeometry().getGridToCRS(), bands, null, null); } } diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala index 8ef24399e8..7728b772d0 100644 --- a/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala +++ b/sql/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala @@ -452,5 +452,16 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen assert(df.selectExpr("RS_Intersects(raster, ST_SetSRID(ST_Point(33.81798,-117.47993), 4326))").first().getBoolean(0)) assert(!df.selectExpr("RS_Intersects(raster, ST_SetSRID(ST_Point(33.97896,-117.27868), 4326))").first().getBoolean(0)) } + + it("Passed RS_AddBandFromArray collect generated raster") { + var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff") + df = df.selectExpr("RS_FromGeoTiff(content) as raster", "RS_BandAsArray(RS_FromGeoTiff(content), 1) as band") + df = df.selectExpr("RS_AddBandFromArray(raster, band, 1) as raster", "band") + var raster = df.collect().head.getAs[GridCoverage2D](0) + assert(raster.getNumSampleDimensions == 1) + df = df.selectExpr("RS_AddBandFromArray(raster, band, 2) as raster", "band") + raster = df.collect().head.getAs[GridCoverage2D](0) + assert(raster.getNumSampleDimensions == 2) + } } }