From 0eb39b3ccd7dee5d78a7d2862ea97eaf14503a42 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 9 Sep 2024 10:41:11 +0100 Subject: [PATCH] feat(python): add tooltip by default to charts --- py-polars/polars/dataframe/plotting.py | 73 +++++++++++--------------- py-polars/polars/series/plotting.py | 13 +++-- 2 files changed, 37 insertions(+), 49 deletions(-) diff --git a/py-polars/polars/dataframe/plotting.py b/py-polars/polars/dataframe/plotting.py index ed118e504656..263ce6c15f3d 100644 --- a/py-polars/polars/dataframe/plotting.py +++ b/py-polars/polars/dataframe/plotting.py @@ -6,15 +6,12 @@ import sys import altair as alt - from altair.typing import ( - ChannelColor, - ChannelOrder, - ChannelSize, - ChannelTooltip, - ChannelX, - ChannelY, - EncodeKwds, - ) + from altair.typing import ChannelColor as Color + from altair.typing import ChannelOrder as Order + from altair.typing import ChannelSize as Size + from altair.typing import ChannelX as X + from altair.typing import ChannelY as Y + from altair.typing import EncodeKwds from polars import DataFrame @@ -29,12 +26,15 @@ Encodings: TypeAlias = Dict[ str, - Union[ - ChannelX, ChannelY, ChannelColor, ChannelOrder, ChannelSize, ChannelTooltip - ], + Union[X, Y, Color, Order, Size], ] +def _add_tooltip(chart: alt.Chart) -> alt.Chart: + chart.mark = {"type": chart.mark, "tooltip": True} + return chart + + class DataFramePlot: """DataFrame.plot namespace.""" @@ -45,10 +45,9 @@ def __init__(self, df: DataFrame) -> None: def bar( self, - x: ChannelX | None = None, - y: ChannelY | None = None, - color: ChannelColor | None = None, - tooltip: ChannelTooltip | None = None, + x: X | None = None, + y: Y | None = None, + color: Color | None = None, /, **kwargs: Unpack[EncodeKwds], ) -> alt.Chart: @@ -77,8 +76,6 @@ def bar( Column with y-coordinates of bars. color Column to color bars by. - tooltip - Columns to show values of when hovering over bars with pointer. **kwargs Additional keyword arguments passed to Altair. @@ -102,17 +99,16 @@ def bar( encodings["y"] = y if color is not None: encodings["color"] = color - if tooltip is not None: - encodings["tooltip"] = tooltip - return self._chart.mark_bar().encode(**encodings, **kwargs).interactive() + return _add_tooltip( + self._chart.mark_bar().encode(**encodings, **kwargs).interactive() + ) def line( self, - x: ChannelX | None = None, - y: ChannelY | None = None, - color: ChannelColor | None = None, - order: ChannelOrder | None = None, - tooltip: ChannelTooltip | None = None, + x: X | None = None, + y: Y | None = None, + color: Color | None = None, + order: Order | None = None, /, **kwargs: Unpack[EncodeKwds], ) -> alt.Chart: @@ -142,8 +138,6 @@ def line( Column to color lines by. order Column to use for order of data points in lines. - tooltip - Columns to show values of when hovering over lines with pointer. **kwargs Additional keyword arguments passed to Altair. @@ -168,17 +162,16 @@ def line( encodings["color"] = color if order is not None: encodings["order"] = order - if tooltip is not None: - encodings["tooltip"] = tooltip - return self._chart.mark_line().encode(**encodings, **kwargs).interactive() + return _add_tooltip( + self._chart.mark_line().encode(**encodings, **kwargs).interactive() + ) def point( self, - x: ChannelX | None = None, - y: ChannelY | None = None, - color: ChannelColor | None = None, - size: ChannelSize | None = None, - tooltip: ChannelTooltip | None = None, + x: X | None = None, + y: Y | None = None, + color: Color | None = None, + size: Size | None = None, /, **kwargs: Unpack[EncodeKwds], ) -> alt.Chart: @@ -209,8 +202,6 @@ def point( Column to color points by. size Column which determines points' sizes. - tooltip - Columns to show values of when hovering over points with pointer. **kwargs Additional keyword arguments passed to Altair. @@ -234,9 +225,7 @@ def point( encodings["color"] = color if size is not None: encodings["size"] = size - if tooltip is not None: - encodings["tooltip"] = tooltip - return ( + return _add_tooltip( self._chart.mark_point() .encode( **encodings, @@ -253,4 +242,4 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]: if method is None: msg = "Altair has no method 'mark_{attr}'" raise AttributeError(msg) - return lambda **kwargs: method().encode(**kwargs).interactive() + return lambda **kwargs: _add_tooltip(method().encode(**kwargs).interactive()) diff --git a/py-polars/polars/series/plotting.py b/py-polars/polars/series/plotting.py index cb5c6c93a1e1..74c70933c6b9 100644 --- a/py-polars/polars/series/plotting.py +++ b/py-polars/polars/series/plotting.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Callable +from polars.dataframe.plotting import _add_tooltip from polars.dependencies import altair as alt if TYPE_CHECKING: @@ -62,7 +63,7 @@ def hist( if self._series_name == "count()": msg = "Cannot use `plot.hist` when Series name is `'count()'`" raise ValueError(msg) - return ( + return _add_tooltip( alt.Chart(self._df) .mark_bar() .encode(x=alt.X(f"{self._series_name}:Q", bin=True), y="count()", **kwargs) # type: ignore[misc] @@ -104,7 +105,7 @@ def kde( if self._series_name == "density": msg = "Cannot use `plot.kde` when Series name is `'density'`" raise ValueError(msg) - return ( + return _add_tooltip( alt.Chart(self._df) .transform_density(self._series_name, as_=[self._series_name, "density"]) .mark_area() @@ -147,7 +148,7 @@ def line( if self._series_name == "index": msg = "Cannot call `plot.line` when Series name is 'index'" raise ValueError(msg) - return ( + return _add_tooltip( alt.Chart(self._df.with_row_index()) .mark_line() .encode(x="index", y=self._series_name, **kwargs) # type: ignore[misc] @@ -165,8 +166,6 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]: if method is None: msg = "Altair has no method 'mark_{attr}'" raise AttributeError(msg) - return ( - lambda **kwargs: method() - .encode(x="index", y=self._series_name, **kwargs) - .interactive() + return lambda **kwargs: _add_tooltip( + method().encode(x="index", y=self._series_name, **kwargs).interactive() )