From a3fd2bc0341b1d42471a54828e11a2e96aa795bd Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Fri, 21 Feb 2025 22:56:31 +0000 Subject: [PATCH 1/2] feat: (Preview) Support aggregations over timedeltas --- bigframes/core/compile/aggregate_compiler.py | 14 ++++-- bigframes/core/rewrite/timedeltas.py | 31 +++++++++++-- bigframes/operations/aggregations.py | 25 ++++++++-- .../small/operations/test_timedeltas.py | 46 +++++++++++++++++++ 4 files changed, 104 insertions(+), 12 deletions(-) diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index 4ec0b270ed..87ab941551 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -231,7 +231,11 @@ def _( column: ibis_types.NumericColumn, window=None, ) -> ibis_types.NumericValue: - return _apply_window_if_present(column.quantile(op.q), window) + result = column.quantile(op.q) + if op.floor_result: + result = result.floor() # type:ignore + + return _apply_window_if_present(result, window) @compile_unary_agg.register @@ -242,7 +246,8 @@ def _( window=None, # order_by: typing.Sequence[ibis_types.Value] = [], ) -> ibis_types.NumericValue: - return _apply_window_if_present(column.mean(), window) + result = column.mean().floor() if op.floor_result else column.mean() + return _apply_window_if_present(result, window) @compile_unary_agg.register @@ -306,10 +311,11 @@ def _( @numeric_op def _( op: agg_ops.StdOp, - x: ibis_types.Column, + x: ibis_types.NumericColumn, window=None, ) -> ibis_types.Value: - return _apply_window_if_present(cast(ibis_types.NumericColumn, x).std(), window) + result = x.std().floor() if op.floor_result else x.std() + return _apply_window_if_present(result, window) @compile_unary_agg.register diff --git a/bigframes/core/rewrite/timedeltas.py b/bigframes/core/rewrite/timedeltas.py index bde1a4431c..79153baeca 100644 --- a/bigframes/core/rewrite/timedeltas.py +++ b/bigframes/core/rewrite/timedeltas.py @@ -70,6 +70,19 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod root.skip_reproject_unsafe, ) + if isinstance(root, nodes.AggregateNode): + updated_aggregations = tuple( + (_rewrite_aggregation(agg, root.child.schema), col_id) + for agg, col_id in root.aggregations + ) + return nodes.AggregateNode( + root.child, + updated_aggregations, + root.by_column_ids, + root.order_by, + root.dropna, + ) + return root @@ -196,17 +209,29 @@ def _rewrite_aggregation( ) -> ex.Aggregation: if not isinstance(aggregation, ex.UnaryAggregation): return aggregation - if not isinstance(aggregation.op, aggs.DiffOp): - return aggregation if isinstance(aggregation.arg, ex.DerefOp): input_type = schema.get_type(aggregation.arg.id.sql) else: input_type = aggregation.arg.dtype - if dtypes.is_datetime_like(input_type): + if isinstance(aggregation.op, aggs.DiffOp) and dtypes.is_datetime_like(input_type): return ex.UnaryAggregation( aggs.TimeSeriesDiffOp(aggregation.op.periods), aggregation.arg ) + if isinstance(aggregation.op, aggs.StdOp) and input_type is dtypes.TIMEDELTA_DTYPE: + return ex.UnaryAggregation(aggs.StdOp(floor_result=True), aggregation.arg) + + if isinstance(aggregation.op, aggs.MeanOp) and input_type is dtypes.TIMEDELTA_DTYPE: + return ex.UnaryAggregation(aggs.MeanOp(floor_result=True), aggregation.arg) + + if ( + isinstance(aggregation.op, aggs.QuantileOp) + and input_type is dtypes.TIMEDELTA_DTYPE + ): + return ex.UnaryAggregation( + aggs.QuantileOp(q=aggregation.op.q, floor_result=True), aggregation.arg + ) + return aggregation diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index e9d102b42d..b5c28913cb 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -142,13 +142,16 @@ class SumOp(UnaryAggregateOp): name: ClassVar[str] = "sum" def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: - if not dtypes.is_numeric(input_types[0]): - raise TypeError(f"Type {input_types[0]} is not numeric") - if pd.api.types.is_bool_dtype(input_types[0]): - return dtypes.INT_DTYPE - else: + if input_types[0] is dtypes.TIMEDELTA_DTYPE: + return dtypes.TIMEDELTA_DTYPE + + if dtypes.is_numeric(input_types[0]): + if pd.api.types.is_bool_dtype(input_types[0]): + return dtypes.INT_DTYPE return input_types[0] + raise TypeError(f"Type {input_types[0]} is not numeric or timedelta") + @dataclasses.dataclass(frozen=True) class MedianOp(UnaryAggregateOp): @@ -171,6 +174,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT @dataclasses.dataclass(frozen=True) class QuantileOp(UnaryAggregateOp): q: float + floor_result: bool = False @property def name(self): @@ -181,6 +185,8 @@ def order_independent(self) -> bool: return True def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + if input_types[0] is dtypes.TIMEDELTA_DTYPE: + return dtypes.TIMEDELTA_DTYPE return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0]) @@ -224,7 +230,11 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT class MeanOp(UnaryAggregateOp): name: ClassVar[str] = "mean" + floor_result: bool = False + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + if input_types[0] is dtypes.TIMEDELTA_DTYPE: + return dtypes.TIMEDELTA_DTYPE return signatures.UNARY_REAL_NUMERIC.output_type(input_types[0]) @@ -262,7 +272,12 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT class StdOp(UnaryAggregateOp): name: ClassVar[str] = "std" + floor_result: bool = False + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + if input_types[0] is dtypes.TIMEDELTA_DTYPE: + return dtypes.TIMEDELTA_DTYPE + return signatures.FixedOutputType( dtypes.is_numeric, dtypes.FLOAT_DTYPE, "numeric" ).output_type(input_types[0]) diff --git a/tests/system/small/operations/test_timedeltas.py b/tests/system/small/operations/test_timedeltas.py index 356000b3f6..723481b1d1 100644 --- a/tests/system/small/operations/test_timedeltas.py +++ b/tests/system/small/operations/test_timedeltas.py @@ -465,3 +465,49 @@ def test_timedelta_ordering(session): pandas.testing.assert_series_equal( actual_result, expected_result, check_index_type=False ) + + +def test_timedelta_cumsum(temporal_dfs): + bf_df, pd_df = temporal_dfs + + actual_result = bf_df["timedelta_col_1"].cumsum().to_pandas() + + expected_result = pd_df["timedelta_col_1"].cumsum() + _assert_series_equal(actual_result, expected_result) + + +@pytest.mark.parametrize( + "agg_func", + [ + pytest.param(lambda x: x.min(), id="min"), + pytest.param(lambda x: x.max(), id="max"), + pytest.param(lambda x: x.sum(), id="sum"), + pytest.param(lambda x: x.mean(), id="mean"), + pytest.param(lambda x: x.median(), id="median"), + pytest.param(lambda x: x.quantile(0.5), id="quantile"), + pytest.param(lambda x: x.std(), id="std"), + ], +) +def test_timedelta_agg__timedelta_result(temporal_dfs, agg_func): + bf_df, pd_df = temporal_dfs + + actual_result = agg_func(bf_df["timedelta_col_1"]) + + expected_result = agg_func(pd_df["timedelta_col_1"]).floor("us") + assert actual_result == expected_result + + +@pytest.mark.parametrize( + "agg_func", + [ + pytest.param(lambda x: x.count(), id="count"), + pytest.param(lambda x: x.nunique(), id="nunique"), + ], +) +def test_timedelta_agg__int_result(temporal_dfs, agg_func): + bf_df, pd_df = temporal_dfs + + actual_result = agg_func(bf_df["timedelta_col_1"]) + + expected_result = agg_func(pd_df["timedelta_col_1"]) + assert actual_result == expected_result From 0d7fdb1a4657d8347267611a0754bdedf9168221 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Mon, 24 Feb 2025 19:44:11 +0000 Subject: [PATCH 2/2] rename variable --- bigframes/core/compile/aggregate_compiler.py | 6 +++--- bigframes/core/rewrite/timedeltas.py | 11 ++++++++--- bigframes/operations/aggregations.py | 6 +++--- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index 87ab941551..a17b69815c 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -232,7 +232,7 @@ def _( window=None, ) -> ibis_types.NumericValue: result = column.quantile(op.q) - if op.floor_result: + if op.should_floor_result: result = result.floor() # type:ignore return _apply_window_if_present(result, window) @@ -246,7 +246,7 @@ def _( window=None, # order_by: typing.Sequence[ibis_types.Value] = [], ) -> ibis_types.NumericValue: - result = column.mean().floor() if op.floor_result else column.mean() + result = column.mean().floor() if op.should_floor_result else column.mean() return _apply_window_if_present(result, window) @@ -314,7 +314,7 @@ def _( x: ibis_types.NumericColumn, window=None, ) -> ibis_types.Value: - result = x.std().floor() if op.floor_result else x.std() + result = x.std().floor() if op.should_floor_result else x.std() return _apply_window_if_present(result, window) diff --git a/bigframes/core/rewrite/timedeltas.py b/bigframes/core/rewrite/timedeltas.py index 79153baeca..e21e0b6bf2 100644 --- a/bigframes/core/rewrite/timedeltas.py +++ b/bigframes/core/rewrite/timedeltas.py @@ -221,17 +221,22 @@ def _rewrite_aggregation( ) if isinstance(aggregation.op, aggs.StdOp) and input_type is dtypes.TIMEDELTA_DTYPE: - return ex.UnaryAggregation(aggs.StdOp(floor_result=True), aggregation.arg) + return ex.UnaryAggregation( + aggs.StdOp(should_floor_result=True), aggregation.arg + ) if isinstance(aggregation.op, aggs.MeanOp) and input_type is dtypes.TIMEDELTA_DTYPE: - return ex.UnaryAggregation(aggs.MeanOp(floor_result=True), aggregation.arg) + return ex.UnaryAggregation( + aggs.MeanOp(should_floor_result=True), aggregation.arg + ) if ( isinstance(aggregation.op, aggs.QuantileOp) and input_type is dtypes.TIMEDELTA_DTYPE ): return ex.UnaryAggregation( - aggs.QuantileOp(q=aggregation.op.q, floor_result=True), aggregation.arg + aggs.QuantileOp(q=aggregation.op.q, should_floor_result=True), + aggregation.arg, ) return aggregation diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index b5c28913cb..bf6016bb2e 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -174,7 +174,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT @dataclasses.dataclass(frozen=True) class QuantileOp(UnaryAggregateOp): q: float - floor_result: bool = False + should_floor_result: bool = False @property def name(self): @@ -230,7 +230,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT class MeanOp(UnaryAggregateOp): name: ClassVar[str] = "mean" - floor_result: bool = False + should_floor_result: bool = False def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: if input_types[0] is dtypes.TIMEDELTA_DTYPE: @@ -272,7 +272,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT class StdOp(UnaryAggregateOp): name: ClassVar[str] = "std" - floor_result: bool = False + should_floor_result: bool = False def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: if input_types[0] is dtypes.TIMEDELTA_DTYPE: