From 1fef8159165d65869bd9d20df9a133c540e9d79f Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Tue, 2 Jul 2024 11:57:36 -0700 Subject: [PATCH] simplify writing to iceberg table --- py-polars/polars/dataframe/frame.py | 25 +++++-------------------- py-polars/tests/unit/io/test_iceberg.py | 24 ++++++++++++++++++++---- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index cf753c80bfa2..7746df16d07f 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -3837,44 +3837,29 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]: def write_iceberg( self, - target: str | Path, + table: pyiceberg.table.Table, mode: Literal["append", "overwrite"], - ) -> pyiceberg.table.Table: + ) -> None: """ Write DataFrame to an Iceberg table. Parameters ---------- - target : str | Path - The target path or identifier for the Iceberg table. + table + The pyiceberg.table.Table object representing an Iceberg table. mode : {'append', 'overwrite'} How to handle existing data. - If 'append', will add new data. - If 'overwrite', will replace table with new data. - Returns - ------- - pyiceberg.Table - The Iceberg table object that was written to. """ - from pyiceberg.catalog.sql import SqlCatalog - - catalog = SqlCatalog( - "default", uri="sqlite:///:memory:", warehouse=f"file://{target}" - ) - catalog.create_namespace("default") data = self.to_arrow() - schema = data.schema - table = catalog.create_table( - "default.table", - schema=schema, - ) + if mode == "append": table.append(data) else: table.overwrite(data) - return table @overload def write_delta( diff --git a/py-polars/tests/unit/io/test_iceberg.py b/py-polars/tests/unit/io/test_iceberg.py index 36313ae1d8d3..f82e463b1ea5 100644 --- a/py-polars/tests/unit/io/test_iceberg.py +++ b/py-polars/tests/unit/io/test_iceberg.py @@ -178,9 +178,25 @@ def test_write_iceberg(tmp_path: Path) -> None: "ham": ["a", "b", "c", "d", "e"], } ) - iceberg_table = df.write_iceberg(tmp_path, mode="overwrite") - iceberg_path = iceberg_table.metadata_location - new_df = pl.scan_iceberg(iceberg_path).collect() - assert len(df) == len(new_df) + + from pyiceberg.catalog.sql import SqlCatalog + + catalog = SqlCatalog( + "default", uri="sqlite:///:memory:", warehouse=f"file://{tmp_path}" + ) + catalog.create_namespace("default") + table = catalog.create_table( + "default.table", + schema=df.to_arrow().schema, + ) + + df.write_iceberg(table, mode="overwrite") + new_df = pl.scan_iceberg(table).collect() assert df.schema == new_df.schema + assert len(df) == len(new_df) assert df.equals(new_df) + + df.write_iceberg(table, mode="append") + new_df = pl.scan_iceberg(table).collect() + assert df.schema == new_df.schema + assert 2 * len(df) == len(new_df)