Skip to content

Commit

Permalink
[SEDONA-319] Make result of RS_BandAsArray serializable (#899)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kontinuation committed Jul 12, 2023
1 parent c764200 commit 3a1e8c6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,25 @@
*/
package org.apache.sedona.common.raster;

import com.sun.media.imageioimpl.common.BogusColorSpace;
import org.geotools.coverage.CoverageFactoryFinder;
import org.geotools.coverage.GridSampleDimension;
import org.geotools.coverage.grid.GridCoverage2D;
import org.geotools.coverage.grid.GridCoverageFactory;

import javax.media.jai.RasterFactory;

import java.awt.Point;
import java.awt.Transparency;
import java.awt.color.ColorSpace;
import java.awt.image.BufferedImage;
import java.awt.image.ColorModel;
import java.awt.image.ComponentColorModel;
import java.awt.image.DataBuffer;
import java.awt.image.Raster;
import java.awt.image.RenderedImage;
import java.awt.image.WritableRaster;
import java.util.Arrays;

public class MapAlgebra
{
Expand Down Expand Up @@ -118,9 +127,14 @@ private static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverag
wr.setPixel(i, j, copiedPixels);
}
}
// Create a new GridCoverage2D with the copied image
GridCoverageFactory gridCoverageFactory = CoverageFactoryFinder.getGridCoverageFactory(null);
return gridCoverageFactory.create(gridCoverage2D.getName(), wr, gridCoverage2D.getEnvelope());
// Add a sample dimension for newly added band
int numBand = wr.getNumBands();
GridSampleDimension[] originalSampleDimensions = gridCoverage2D.getSampleDimensions();
GridSampleDimension[] sampleDimensions = new GridSampleDimension[numBand];
System.arraycopy(originalSampleDimensions, 0, sampleDimensions, 0, originalSampleDimensions.length);
sampleDimensions[numBand - 1] = new GridSampleDimension("band" + numBand);
// Construct a GridCoverage2D with the copied image.
return createCompatibleGridCoverage2D(gridCoverage2D, wr, sampleDimensions);
}

private static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, double[] bandValues) {
Expand All @@ -141,7 +155,20 @@ private static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCovera
}
}
// Create a new GridCoverage2D with the copied image
return createCompatibleGridCoverage2D(gridCoverage2D, wr, gridCoverage2D.getSampleDimensions());
}

private static GridCoverage2D createCompatibleGridCoverage2D(GridCoverage2D gridCoverage2D, WritableRaster wr, GridSampleDimension[] bands) {
int rasterDataType = wr.getDataBuffer().getDataType();
int numBand = wr.getNumBands();
final ColorSpace cs = new BogusColorSpace(numBand);
final int[] nBits = new int[numBand];
Arrays.fill(nBits, DataBuffer.getDataTypeSize(rasterDataType));
ColorModel colorModel =
new ComponentColorModel(cs, nBits, false, true, Transparency.OPAQUE, rasterDataType);
final RenderedImage image = new BufferedImage(colorModel, wr, false, null);
GridCoverageFactory gridCoverageFactory = CoverageFactoryFinder.getGridCoverageFactory(null);
return gridCoverageFactory.create(gridCoverage2D.getName(), wr, gridCoverage2D.getEnvelope());
return gridCoverageFactory.create(gridCoverage2D.getName(), image, gridCoverage2D.getCoordinateReferenceSystem(),
gridCoverage2D.getGridGeometry().getGridToCRS(), bands, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -452,5 +452,16 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assert(df.selectExpr("RS_Intersects(raster, ST_SetSRID(ST_Point(33.81798,-117.47993), 4326))").first().getBoolean(0))
assert(!df.selectExpr("RS_Intersects(raster, ST_SetSRID(ST_Point(33.97896,-117.27868), 4326))").first().getBoolean(0))
}

it("Passed RS_AddBandFromArray collect generated raster") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")
df = df.selectExpr("RS_FromGeoTiff(content) as raster", "RS_BandAsArray(RS_FromGeoTiff(content), 1) as band")
df = df.selectExpr("RS_AddBandFromArray(raster, band, 1) as raster", "band")
var raster = df.collect().head.getAs[GridCoverage2D](0)
assert(raster.getNumSampleDimensions == 1)
df = df.selectExpr("RS_AddBandFromArray(raster, band, 2) as raster", "band")
raster = df.collect().head.getAs[GridCoverage2D](0)
assert(raster.getNumSampleDimensions == 2)
}
}
}

0 comments on commit 3a1e8c6

Please sign in to comment.