Skip to content

Commit

Permalink
Add function to join geometry to the nearest geometry
Browse files Browse the repository at this point in the history
  • Loading branch information
BjoernWaechter committed Nov 11, 2023
1 parent 82bf08e commit fc2dc30
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 2 deletions.
76 changes: 76 additions & 0 deletions osm_address/transform/geo_join.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import random

from pyspark.sql import DataFrame
from pyspark.sql.functions import expr
from sedona.sql.types import GeometryType

from osm_address.transform import remove_duplicate_rows
from osm_address.utils.schema import check_for_duplicate_columns


def join_point_in_multipolygon(
df_in: DataFrame,
Expand Down Expand Up @@ -37,3 +42,74 @@ def join_point_in_multipolygon(
)

return df_result


def join_nearest_geometry(
        df_in_1: DataFrame,
        df_in_2: DataFrame,
        column_name_1: str,
        column_name_2: str,
        epsg_1="epsg:4326",
        epsg_2="epsg:4326",
        epsg_meter_based="epsg:25832",
        partition_count=1000,
        distance_meter_column="distance",
        max_meter=5000,
        join_type="leftouter"

) -> DataFrame:
    """
    Join each geometry of ``df_in_1`` to the nearest geometry of ``df_in_2``.

    Both geometry columns are transformed into ``epsg_meter_based`` and the
    frames are joined on ``ST_Distance(...) <= max_meter``. The distance of
    every candidate pair is stored in ``distance_meter_column`` and the
    candidates are then reduced with ``remove_duplicate_rows`` (called with
    ``decision_max_first=False``; presumably this keeps, per row of
    ``df_in_1``, the candidate with the smallest distance - confirm against
    that helper).

    :param df_in_1: data frame whose rows are matched to a nearest geometry
    :param df_in_2: data frame providing the candidate geometries
    :param column_name_1: geometry column of ``df_in_1``
    :param column_name_2: geometry column of ``df_in_2``
    :param epsg_1: source CRS of ``column_name_1``
    :param epsg_2: source CRS of ``column_name_2``
    :param epsg_meter_based: CRS both geometries are transformed into before
        measuring; intended to be a CRS whose unit is metres so that
        ``max_meter`` and the resulting distance are metres
    :param partition_count: number of partitions ``df_in_1`` is repartitioned
        into before the join
    :param distance_meter_column: name of the distance column in the result
    :param max_meter: maximum distance for two geometries to be join candidates
    :param join_type: Spark join type; with the default ``"leftouter"`` rows of
        ``df_in_1`` without any candidate are kept and get NULL for the
        ``df_in_2`` columns and for the distance
    :return: joined data frame without the temporary helper columns

    NOTE(review): the temporary id is created on ``df_in_1`` only, so with a
    right or full outer join all unmatched ``df_in_2`` rows share a NULL id;
    how ``remove_duplicate_rows`` treats them should be verified.
    """
    check_for_duplicate_columns(
        df_1=df_in_1,
        df_2=df_in_2
    )

    # The temporary id column must not shadow a real column: withColumn would
    # silently overwrite it and the final drop would then remove it.
    # Names are compared lower-cased because Spark resolves column names
    # case-insensitively by default.
    taken_names = {name.lower() for name in df_in_1.columns + df_in_2.columns}
    unique_id_col = f"uniqueid_{random.randint(1, 999999)}"
    while unique_id_col in taken_names:
        unique_id_col = f"uniqueid_{random.randint(1, 999999)}"

    epsg_suffix = epsg_meter_based.replace(':', '_')
    epsg_col_1 = f"{column_name_1}_{epsg_suffix}"
    epsg_col_2 = f"{column_name_2}_{epsg_suffix}"

    # Project the first geometry and tag every input row with an id so the
    # join candidates can be grouped per original row afterwards.
    df_in_1_epsg = df_in_1.withColumn(
        epsg_col_1,
        expr(f"ST_Transform({column_name_1}, '{epsg_1}', '{epsg_meter_based}')")
    ).withColumn(
        unique_id_col,
        expr("monotonically_increasing_id()")
    )

    df_in_2_epsg = df_in_2.withColumn(
        epsg_col_2,
        expr(f"ST_Transform({column_name_2}, '{epsg_2}', '{epsg_meter_based}')")
    )

    # Candidate pairs: everything within max_meter of each other.
    df_max_distance = df_in_1_epsg.repartition(partition_count).join(
        other=df_in_2_epsg,
        on=expr(f"ST_Distance("
                f"{epsg_col_1}, "
                f"{epsg_col_2}) <= {max_meter}"),
        how=join_type
    )

    # Rows without a partner (outer joins) get an explicit NULL distance.
    df_dist_result = df_max_distance.withColumn(
        distance_meter_column,
        expr(
            f"CASE "
            f" WHEN {epsg_col_2} IS NOT NULL THEN "
            f" ST_Distance("
            f"{epsg_col_1}, "
            f"{epsg_col_2}) "
            f" ELSE NULL "
            f"END")
    )

    df_nearest_only = remove_duplicate_rows(
        df_input=df_dist_result,
        unique_col=unique_id_col,
        decision_col=distance_meter_column,
        decision_max_first=False
    ).drop(
        unique_id_col,
        epsg_col_1,
        epsg_col_2
    )

    return df_nearest_only
2 changes: 1 addition & 1 deletion osm_address/udf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def create_polygon(points, pos_prefix):
h[f"{pos_prefix}latitude"]
) for h in points]

if len(raw_points) >= 4:
if len(raw_points) >= 4 and raw_points[0] == raw_points[-1]:
return Polygon(shell=raw_points)
elif len(raw_points) == 1:
return Point(raw_points[0])
Expand Down
63 changes: 62 additions & 1 deletion tests/unit/osm_address/transform/test_geo_join.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pytest
from pyspark.sql.types import StructField, StructType, StringType
from pyspark.sql.functions import expr
from pyspark.sql.types import StructField, StructType, StringType, Row
from sedona.sql.types import GeometryType

from osm_address.transform import (create_border_polygons,
join_point_in_multipolygon)
from osm_address.transform.geo_join import join_nearest_geometry
from osm_address.transform.get_points import get_points_from_nodes_and_ways


class TestGeoJoin:
Expand Down Expand Up @@ -158,3 +161,61 @@ def test_join_point_in_poly_missing_column(self, test_context):
)
assert "'region' in data frame df_polygon is of type" in str(excinfo)

def test_nearest_point_to_point(self, test_context):
    """
    Join every hospital point to its nearest pharmacy point within 6000 m
    (left outer join) and compare ids and rounded distances with the
    expected rows, including one hospital without any pharmacy in range.
    """
    # helper columns of the point extraction that are not needed for the join
    helper_columns = ("longitude", "latitude", "tags")

    hospitals = get_points_from_nodes_and_ways(
        osm_data=test_context.osm_data,
        osm_filter="element_at(tags, 'amenity') = 'hospital'",
        point_column="hospital_point",
        id_column="hospital_id",
        centroids_only=True
    )
    hospitals = hospitals.drop(*helper_columns)

    # NOTE(review): the additional column is named "hospital" although it
    # holds the pharmacy's name tag - kept as in the original test.
    pharmacies = get_points_from_nodes_and_ways(
        osm_data=test_context.osm_data,
        osm_filter="element_at(tags, 'amenity') = 'pharmacy'",
        additional_columns={"hospital": "element_at(tags, 'name')"},
        point_column="pharmacy_point",
        id_column="pharmacy_id",
        centroids_only=True
    )
    pharmacies = pharmacies.drop(*helper_columns)

    nearest = join_nearest_geometry(
        df_in_1=hospitals,
        df_in_2=pharmacies,
        column_name_1="hospital_point",
        column_name_2="pharmacy_point",
        partition_count=4,
        distance_meter_column="distance",
        max_meter=6000,
        join_type="leftouter"
    )

    # round the distance so the comparison does not depend on float noise
    rounded = nearest.withColumn("distance_rounded", expr("ROUND(distance)"))
    ordered = rounded.orderBy("hospital_id")
    actual_rows = ordered.select(
        "hospital_id",
        "pharmacy_id",
        "distance_rounded"
    ).collect()

    expected_rows = [
        Row(hospital_id='N2050364490', pharmacy_id='N2050466659', distance_rounded=812.0),
        Row(hospital_id='N4942543822', pharmacy_id='N690708548', distance_rounded=441.0),
        Row(hospital_id='N522787974', pharmacy_id='N590463655', distance_rounded=1167.0),
        Row(hospital_id='N5262330928', pharmacy_id='N4984739294', distance_rounded=1626.0),
        Row(hospital_id='N666793601', pharmacy_id='N4984739294', distance_rounded=5434.0),
        Row(hospital_id='N666793602', pharmacy_id='N4984739294', distance_rounded=4127.0),
        Row(hospital_id='N666793607', pharmacy_id='N690708520', distance_rounded=5929.0),
        Row(hospital_id='N666793610', pharmacy_id=None, distance_rounded=None),
        Row(hospital_id='W194554955', pharmacy_id='N5723978577', distance_rounded=310.0)
    ]

    assert actual_rows == expected_rows

0 comments on commit fc2dc30

Please sign in to comment.