diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py
index c0cd3507c1a9..35fa975f5830 100644
--- a/py-polars/polars/dataframe/frame.py
+++ b/py-polars/polars/dataframe/frame.py
@@ -121,6 +121,7 @@
     import deltalake
     import jax
     import numpy.typing as npt
+    import pyiceberg
     import torch
     from great_tables import GT
     from hvplot.plotting.core import hvPlotTabularPolars
@@ -3834,6 +3835,36 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
             msg = f"unrecognised connection type {connection!r}"
             raise TypeError(msg)
 
+    def write_iceberg(self, target: str | Path) -> pyiceberg.table.Table:
+        """
+        Write DataFrame to an Iceberg table.
+
+        A table named ``default.table`` is created through a SQLite catalog
+        held in memory, using `target` as the warehouse location, and the
+        frame's data is written into it.
+
+        Parameters
+        ----------
+        target : str | Path
+            Location of the warehouse; it is interpolated into a ``file://`` URI.
+
+        Returns
+        -------
+        pyiceberg.table.Table
+            The Iceberg table object that was written to.
+        """
+        from pyiceberg.catalog.sql import SqlCatalog
+
+        catalog = SqlCatalog(
+            "default", uri="sqlite:///:memory:", warehouse=f"file://{target}"
+        )
+        catalog.create_namespace("default")
+
+        data = self.to_arrow()
+        table = catalog.create_table("default.table", schema=data.schema)
+        table.overwrite(data)
+        return table
+
     @overload
     def write_delta(
         self,
diff --git a/py-polars/tests/unit/io/test_iceberg.py b/py-polars/tests/unit/io/test_iceberg.py
index 0d80816552f5..7b47e28876ea 100644
--- a/py-polars/tests/unit/io/test_iceberg.py
+++ b/py-polars/tests/unit/io/test_iceberg.py
@@ -163,3 +163,24 @@ def test_parse_lteq(self) -> None:
 
         expr = _to_ast("(pa.compute.field('ts') <= '2023-08-08')")
         assert _convert_predicate(expr) == LessThanOrEqual("ts", "2023-08-08")
+
+
+@pytest.mark.slow()
+@pytest.mark.write_disk()
+@pytest.mark.filterwarnings(
+    "ignore:No preferred file implementation for scheme*:UserWarning"
+)
+def test_write_iceberg(tmp_path: Path) -> None:
+    df = pl.DataFrame(
+        {
+            "foo": [1, 2, 3, 4, 5],
+            "bar": [6, 7, 8, 9, 10],
+            "ham": ["a", "b", "c", "d", "e"],
+        }
+    )
+    iceberg_table = df.write_iceberg(tmp_path)
+    iceberg_path = iceberg_table.metadata_location
+    new_df = pl.scan_iceberg(iceberg_path).collect()
+    assert len(df) == len(new_df)
+    assert df.schema == new_df.schema
+    assert df.equals(new_df)