[SEDONA-646] Shapefile DataSource for DataFrame API #1553

Merged · 17 commits · Aug 23, 2024
1 change: 1 addition & 0 deletions .github/linters/codespell.txt
@@ -1,3 +1,4 @@
LOD
actualy
afterall
atmost
@@ -101,6 +101,28 @@ public static Geometry transformToGivenTarget(
}
}

/**
* Get the SRID of a CRS from a WKT string
*
* @param crsWKT WKT string for CRS
* @return SRID
*/
public static int wktCRSToSRID(String crsWKT) {
try {
CoordinateReferenceSystem crs = CRS.parseWKT(crsWKT);
int srid = crsToSRID(crs);
if (srid == 0) {
Integer epsgCode = CRS.lookupEpsgCode(crs, true);
if (epsgCode != null) {
srid = epsgCode;
}
}
return srid;
} catch (FactoryException e) {
throw new IllegalArgumentException("Cannot parse CRS WKT", e);
}
}

/**
* Get the SRID of a CRS. We use the EPSG code of the CRS if available.
*
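For orientation, a hedged usage sketch of the new `wktCRSToSRID` helper (the enclosing class name is an assumption; the extracted diff above does not show the file name):

```scala
// Hypothetical caller; the class name FunctionsGeoTools is an assumption.
// This WKT carries no AUTHORITY element, so crsToSRID returns 0 and the
// EPSG full-scan fallback (CRS.lookupEpsgCode) should resolve the code.
val wgs84Wkt =
  """GEOGCS["WGS 84",
    |  DATUM["WGS_1984", SPHEROID["WGS 84", 6378137, 298.257223563]],
    |  PRIMEM["Greenwich", 0],
    |  UNIT["degree", 0.0174532925199433]]""".stripMargin
val srid = FunctionsGeoTools.wktCRSToSRID(wgs84Wkt) // likely 4326
```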
46 changes: 0 additions & 46 deletions docs/api/sql/Constructor.md
@@ -1,49 +1,3 @@
## Read ESRI Shapefile

Introduction: Construct a DataFrame from a Shapefile

Since: `v1.0.0`

SparkSQL example:

```scala
var spatialRDD = new SpatialRDD[Geometry]
spatialRDD.rawSpatialRDD = ShapefileReader.readToGeometryRDD(sparkSession.sparkContext, shapefileInputLocation)
var rawSpatialDf = Adapter.toDf(spatialRDD,sparkSession)
rawSpatialDf.createOrReplaceTempView("rawSpatialDf")
var spatialDf = sparkSession.sql("""
| SELECT ST_GeomFromWKT(rddshape), _c1, _c2
| FROM rawSpatialDf
""".stripMargin)
spatialDf.show()
spatialDf.printSchema()
```

!!!note
The path to the shapefile is the path to the folder that contains the .shp file, not the path to the .shp file itself. The file extensions .shp, .shx, and .dbf must be lowercase. Assume you have a shapefile called ==myShapefile==; the path should be `XXX/myShapefile`. The file structure should be like this:
```
- shapefile1
- shapefile2
- myshapefile
- myshapefile.shp
- myshapefile.shx
- myshapefile.dbf
- myshapefile...
- ...
```

!!!warning
Please make sure you use ==ST_GeomFromWKT== to create a Geometry-type column; otherwise that column cannot be used in SedonaSQL.

If the file you are reading contains non-ASCII characters, you'll need to explicitly set the Spark config before initializing the SparkSession; then you can use `ShapefileReader.readToGeometryRDD`.

Example:

```scala
spark.driver.extraJavaOptions -Dsedona.global.charset=utf8
spark.executor.extraJavaOptions -Dsedona.global.charset=utf8
```

## ST_GeomCollFromText

Introduction: Constructs a GeometryCollection from the WKT with the given SRID. If SRID is not provided then it defaults to 0. It returns `null` if the WKT is not a `GEOMETRYCOLLECTION`.
91 changes: 89 additions & 2 deletions docs/tutorial/sql.md
@@ -459,9 +459,96 @@ root
|-- prop0: string (nullable = true)
```

-## Load Shapefile using SpatialRDD
+## Load Shapefile

-Shapefile can be loaded by SpatialRDD and converted to DataFrame using Adapter. Please read [Load SpatialRDD](rdd.md#create-a-generic-spatialrdd) and [DataFrame <-> RDD](#convert-between-dataframe-and-spatialrdd).
+Since `v1.7.0`, Sedona supports loading a Shapefile as a DataFrame.

=== "Scala/Java"

```scala
val df = sedona.read.format("shapefile").load("/path/to/shapefile")
```

=== "Java"

```java
Dataset<Row> df = sedona.read().format("shapefile").load("/path/to/shapefile");
```

=== "Python"

```python
df = sedona.read.format("shapefile").load("/path/to/shapefile")
```

The input path can be a directory containing one or more shapefiles, or the path to a `.shp` file.

- When the input path is a directory, all shapefiles directly under it are loaded. To also load shapefiles in subdirectories, specify `.option("recursiveFileLookup", "true")`, as shown in the sketch after this list.
- When the input path is a `.shp` file, that shapefile is loaded. Sedona looks for sibling files (`.dbf`, `.shx`, etc.) with the same main file name and loads them automatically.
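For example, a minimal sketch (the directory layout is hypothetical) that loads every shapefile under a root directory recursively:

```scala
// Hypothetical layout: /data/shapefiles/<region>/<name>.shp with sibling
// .dbf/.shx files. recursiveFileLookup is a standard Spark reader option
// that makes the scan descend into subdirectories.
val df = sedona.read
  .format("shapefile")
  .option("recursiveFileLookup", "true")
  .load("/data/shapefiles")
df.printSchema()
```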

The geometry column is named `geometry` by default. You can change its name using the `geometry.name` option. If one of the non-spatial attributes is named "geometry", `geometry.name` must be configured to avoid the conflict.

=== "Scala/Java"

```scala
val df = sedona.read.format("shapefile").option("geometry.name", "geom").load("/path/to/shapefile")
```

=== "Java"

```java
Dataset<Row> df = sedona.read().format("shapefile").option("geometry.name", "geom").load("/path/to/shapefile");
```

=== "Python"

```python
df = sedona.read.format("shapefile").option("geometry.name", "geom").load("/path/to/shapefile")
```

Each record in a shapefile has a unique record number, which is not loaded by default. To include the record number in the loaded DataFrame, set the `key.name` option to the name of the record number column:

=== "Scala/Java"

```scala
val df = sedona.read.format("shapefile").option("key.name", "FID").load("/path/to/shapefile")
```

=== "Java"

```java
Dataset<Row> df = sedona.read().format("shapefile").option("key.name", "FID").load("/path/to/shapefile");
```

=== "Python"

```python
df = sedona.read.format("shapefile").option("key.name", "FID").load("/path/to/shapefile")
```

The character encoding of string attributes is inferred from the `.cpg` file. If you see garbled values in string fields, you can manually specify the correct charset using the `charset` option. For example:

=== "Scala/Java"

```scala
val df = sedona.read.format("shapefile").option("charset", "UTF-8").load("/path/to/shapefile")
```

=== "Java"

```java
Dataset<Row> df = sedona.read().format("shapefile").option("charset", "UTF-8").load("/path/to/shapefile");
```

=== "Python"

```python
df = sedona.read.format("shapefile").option("charset", "UTF-8").load("/path/to/shapefile")
```
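The options above can also be combined in a single read; a sketch, with the option values taken from the examples above and a hypothetical path:

```scala
val df = sedona.read
  .format("shapefile")
  .option("geometry.name", "geom") // rename the geometry column
  .option("key.name", "FID")       // also load the per-record number
  .option("charset", "UTF-8")      // override the inferred string encoding
  .load("/path/to/shapefile")
```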

### (Deprecated) Loading Shapefile using SpatialRDD

If you are using a Sedona version earlier than `v1.7.0`, you can load shapefiles as a SpatialRDD and convert it to a DataFrame using the Adapter. Please read [Load SpatialRDD](rdd.md#create-a-generic-spatialrdd) and [DataFrame <-> RDD](#convert-between-dataframe-and-spatialrdd).
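A minimal sketch of that older path, mirroring the example removed from `Constructor.md` in this PR:

```scala
import org.apache.sedona.core.formatMapper.shapefileParser.ShapefileReader
import org.apache.sedona.sql.utils.Adapter

val spatialRDD = ShapefileReader.readToGeometryRDD(sedona.sparkContext, "/path/to/shapefile")
val df = Adapter.toDf(spatialRDD, sedona)
df.printSchema()
```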

## Load GeoParquet

2 changes: 1 addition & 1 deletion pom.xml
@@ -709,7 +709,7 @@
<activation>
<property>
<name>spark</name>
-<value>3.2.3</value>
+<value>3.2</value>
</property>
</activation>
<properties>
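With Maven property activation, this profile now matches any build invoked with exactly `-Dspark=3.2` (for example, `mvn clean install -Dspark=3.2`), instead of requiring the full patch version `-Dspark=3.2.3`.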
85 changes: 85 additions & 0 deletions python/tests/sql/test_shapefile.py
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest
import os.path
import datetime

from tests.test_base import TestBase
from tests.tools import tests_resource


class TestShapefile(TestBase):
def test_read_simple(self):
input_location = os.path.join(tests_resource, "shapefiles/polygon")
df = self.spark.read.format("shapefile").load(input_location)
assert df.count() == 10000
rows = df.take(100)
for row in rows:
assert len(row) == 1
assert row["geometry"].geom_type in ("Polygon", "MultiPolygon")

def test_read_osm_pois(self):
input_location = os.path.join(tests_resource, "shapefiles/gis_osm_pois_free_1/gis_osm_pois_free_1.shp")
df = self.spark.read.format("shapefile").load(input_location)
assert df.count() == 12873
rows = df.take(100)
for row in rows:
assert len(row) == 5
assert row["geometry"].geom_type == "Point"
assert isinstance(row['osm_id'], str)
assert isinstance(row['fclass'], str)
assert isinstance(row['name'], str)
assert isinstance(row['code'], int)

def test_customize_geom_and_key_columns(self):
input_location = os.path.join(tests_resource, "shapefiles/gis_osm_pois_free_1")
df = self.spark.read.format("shapefile").option("geometry.name", "geom").option("key.name", "fid").load(input_location)
assert df.count() == 12873
rows = df.take(100)
for row in rows:
assert len(row) == 6
assert row["geom"].geom_type == "Point"
assert isinstance(row['fid'], int)
assert isinstance(row['osm_id'], str)
assert isinstance(row['fclass'], str)
assert isinstance(row['name'], str)
assert isinstance(row['code'], int)

def test_read_multiple_shapefiles(self):
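        # This directory holds multiple shapefiles read together; the assertions
        # below allow fields to be null for records whose source file lacks a
        # value for that attribute.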
input_location = os.path.join(tests_resource, "shapefiles/datatypes")
df = self.spark.read.format("shapefile").load(input_location)
rows = df.collect()
assert len(rows) == 9
for row in rows:
id = row['id']
assert row['aInt'] == id
if id is not None:
assert row['aUnicode'] == "测试" + str(id)
if id < 10:
assert row['aDecimal'] * 10 == id * 10 + id
assert row['aDecimal2'] is None
assert row['aDate'] == datetime.date(2020 + id, id, id)
else:
assert row['aDecimal'] is None
assert row['aDecimal2'] * 100 == id * 100 + id
assert row['aDate'] is None
else:
assert row['aUnicode'] == ''
assert row['aDecimal'] is None
assert row['aDecimal2'] is None
assert row['aDate'] is None
@@ -137,13 +137,13 @@ public void parseFileHead(DataInputStream inputStream) throws IOException {
}

/**
- * draw raw byte array of effective record
+ * Parse the next record in the .dbf file
*
- * @param inputStream
- * @return
- * @throws IOException
+ * @param inputStream input stream of .dbf file
+ * @return a list of fields as their original representation in the dbf file
+ * @throws IOException if an I/O error occurs
*/
- public String parsePrimitiveRecord(DataInputStream inputStream) throws IOException {
+ public List<byte[]> parse(DataInputStream inputStream) throws IOException {
if (isDone()) {
return null;
}
@@ -160,50 +160,34 @@
byte[] primitiveBytes = new byte[recordLength];
inputStream.readFully(primitiveBytes);
numRecordRead++; // update number of record read
- return primitiveToAttributes(ByteBuffer.wrap(primitiveBytes));
+ return extractFieldBytes(ByteBuffer.wrap(primitiveBytes));
}

- /**
- * abstract attributes from primitive bytes according to field descriptors.
- *
- * @param inputStream
- * @return
- * @throws IOException
- */
- public String primitiveToAttributes(DataInputStream inputStream) throws IOException {
- byte[] delimiter = {'\t'};
- Text attributes = new Text();
- for (int i = 0; i < fieldDescriptors.size(); ++i) {
- FieldDescriptor descriptor = fieldDescriptors.get(i);
+ /** Extract attributes from primitive bytes according to field descriptors. */
+ private List<byte[]> extractFieldBytes(ByteBuffer buffer) {
+ int numFields = fieldDescriptors.size();
+ List<byte[]> fieldBytesList = new ArrayList<>(numFields);
+ for (FieldDescriptor descriptor : fieldDescriptors) {
byte[] fldBytes = new byte[descriptor.getFieldLength()];
- inputStream.readFully(fldBytes);
- // System.out.println(descriptor.getFiledName() + " " + new String(fldBytes));
- byte[] attr = new String(fldBytes).trim().getBytes();
- if (i > 0) {
- attributes.append(delimiter, 0, 1); // first attribute doesn't append '\t'
- }
- attributes.append(attr, 0, attr.length);
+ buffer.get(fldBytes, 0, fldBytes.length);
+ fieldBytesList.add(fldBytes);
}
- String attrs = attributes.toString();
- return attributes.toString();
+ return fieldBytesList;
}

/**
* abstract attributes from primitive bytes according to field descriptors.
*
- * @param buffer
- * @return
- * @throws IOException
+ * @param fieldBytesList a list of primitive bytes
+ * @return string attributes delimited by '\t'
*/
- public String primitiveToAttributes(ByteBuffer buffer) throws IOException {
+ public static String fieldBytesToString(List<byte[]> fieldBytesList) {
byte[] delimiter = {'\t'};
Text attributes = new Text();
- for (int i = 0; i < fieldDescriptors.size(); ++i) {
- FieldDescriptor descriptor = fieldDescriptors.get(i);
- byte[] fldBytes = new byte[descriptor.getFieldLength()];
- buffer.get(fldBytes, 0, fldBytes.length);
+ for (int i = 0; i < fieldBytesList.size(); ++i) {
+ byte[] fldBytes = fieldBytesList.get(i);
String charset = System.getProperty("sedona.global.charset", "default");
- Boolean utf8flag = charset.equalsIgnoreCase("utf8");
+ boolean utf8flag = charset.equalsIgnoreCase("utf8");
byte[] attr = utf8flag ? fldBytes : fastParse(fldBytes, 0, fldBytes.length).trim().getBytes();
if (i > 0) {
attributes.append(delimiter, 0, 1); // first attribute doesn't append '\t'
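For orientation, a hedged sketch of how the reshaped `.dbf` API composes (the no-arg constructor, stream setup, and file path are assumptions; the method signatures come from the diff above): `parse` now returns the raw bytes of each field, and the static `fieldBytesToString` reproduces the old tab-delimited string where that form is still needed.

```scala
import java.io.{DataInputStream, FileInputStream}

// Hypothetical driver; the constructor and path are assumptions.
val in = new DataInputStream(new FileInputStream("/path/to/file.dbf"))
val parser = new DbfParseUtil()
parser.parseFileHead(in)      // read the header and field descriptors
var fields = parser.parse(in) // raw bytes per field; null once all records are read
while (fields != null) {
  // fieldBytesToString re-creates the old '\t'-delimited representation
  println(DbfParseUtil.fieldBytesToString(fields))
  fields = parser.parse(in)
}
in.close()
```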