Skip to content

Commit

Permalink
simplify writing to iceberg table
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinjqliu committed Jul 2, 2024
1 parent c14870a commit 1fef815
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
25 changes: 5 additions & 20 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3837,44 +3837,29 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:

def write_iceberg(
    self,
    table: pyiceberg.table.Table,
    mode: Literal["append", "overwrite"],
) -> None:
    """
    Write DataFrame to an Iceberg table.

    The frame is converted with ``self.to_arrow()`` and handed to the
    ``append`` or ``overwrite`` method of the supplied table object. The
    caller is responsible for creating (or loading) the table beforehand.

    Parameters
    ----------
    table
        The pyiceberg.table.Table object representing an Iceberg table.
    mode : {'append', 'overwrite'}
        How to handle existing data.

        - If 'append', will add new data.
        - If 'overwrite', will replace table with new data.

    Raises
    ------
    ValueError
        If `mode` is neither 'append' nor 'overwrite'.
    """
    # Validate before doing any work: an unrecognised mode (e.g. a typo)
    # must never fall through to the destructive overwrite path.
    if mode not in ("append", "overwrite"):
        msg = f"mode must be 'append' or 'overwrite', got {mode!r}"
        raise ValueError(msg)

    data = self.to_arrow()
    if mode == "append":
        table.append(data)
    else:
        table.overwrite(data)

@overload
def write_delta(
Expand Down
24 changes: 20 additions & 4 deletions py-polars/tests/unit/io/test_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,25 @@ def test_write_iceberg(tmp_path: Path) -> None:
"ham": ["a", "b", "c", "d", "e"],
}
)
iceberg_table = df.write_iceberg(tmp_path, mode="overwrite")
iceberg_path = iceberg_table.metadata_location
new_df = pl.scan_iceberg(iceberg_path).collect()
assert len(df) == len(new_df)

from pyiceberg.catalog.sql import SqlCatalog

catalog = SqlCatalog(
"default", uri="sqlite:///:memory:", warehouse=f"file://{tmp_path}"
)
catalog.create_namespace("default")
table = catalog.create_table(
"default.table",
schema=df.to_arrow().schema,
)

df.write_iceberg(table, mode="overwrite")
new_df = pl.scan_iceberg(table).collect()
assert df.schema == new_df.schema
assert len(df) == len(new_df)
assert df.equals(new_df)

df.write_iceberg(table, mode="append")
new_df = pl.scan_iceberg(table).collect()
assert df.schema == new_df.schema
assert 2 * len(df) == len(new_df)

0 comments on commit 1fef815

Please sign in to comment.