diff --git a/bigframes/core/compile/api.py b/bigframes/core/compile/api.py index 4e833411ae..86c8fca25a 100644 --- a/bigframes/core/compile/api.py +++ b/bigframes/core/compile/api.py @@ -18,6 +18,7 @@ import google.cloud.bigquery as bigquery import bigframes.core.compile.compiler as compiler +import bigframes.core.rewrite as rewrites if TYPE_CHECKING: import bigframes.core.nodes @@ -42,6 +43,7 @@ def compile_unordered( col_id_overrides: Mapping[str, str] = {}, ) -> str: """Compile node into sql where rows are unsorted, and no ordering information is preserved.""" + # TODO: Enable limit pullup, but only if not being used to write to clustered table. return self._compiler.compile_unordered_ir(node).to_sql( col_id_overrides=col_id_overrides ) @@ -53,8 +55,10 @@ def compile_ordered( col_id_overrides: Mapping[str, str] = {}, ) -> str: """Compile node into sql where rows are sorted with ORDER BY.""" - return self._compiler.compile_ordered_ir(node).to_sql( - col_id_overrides=col_id_overrides, ordered=True + # If we are ordering the query anyways, compiling the slice as a limit is probably a good idea. + new_node, limit = rewrites.pullup_limit_from_slice(node) + return self._compiler.compile_ordered_ir(new_node).to_sql( + col_id_overrides=col_id_overrides, ordered=True, limit=limit ) def compile_raw( diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index f4afdaa97c..d02a2c444c 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -943,8 +943,9 @@ def to_sql( self, col_id_overrides: typing.Mapping[str, str] = {}, ordered: bool = False, + limit: Optional[int] = None, ) -> str: - if ordered: + if ordered or limit: # Need to bake ordering expressions into the selected column in order for our ordering clause builder to work. baked_ir = self._bake_ordering() sql = ibis_bigquery.Backend().compile( @@ -969,7 +970,11 @@ def to_sql( order_by_clause = bigframes.core.sql.ordering_clause( baked_ir._ordering.all_ordering_columns ) - sql += f"{order_by_clause}\n" + sql += f"\n{order_by_clause}" + if limit is not None: + if not isinstance(limit, int): + raise TypeError(f"Limit param: {limit} must be an int.") + sql += f"\nLIMIT {limit}" else: sql = ibis_bigquery.Backend().compile( self._to_ibis_expr( diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 4aab9dc631..2e23f529e2 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -20,7 +20,7 @@ import functools import itertools import typing -from typing import Callable, Iterable, Optional, Sequence, Tuple +from typing import Callable, cast, Iterable, Optional, Sequence, Tuple import google.cloud.bigquery as bq @@ -30,6 +30,7 @@ import bigframes.core.identifiers as bfet_ids from bigframes.core.ordering import OrderingExpression import bigframes.core.schema as schemata +import bigframes.core.slices as slices import bigframes.core.window_spec as window import bigframes.dtypes import bigframes.operations.aggregations as agg_ops @@ -82,6 +83,11 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]: """Direct children of this node""" return tuple([]) + @property + @abc.abstractmethod + def row_count(self) -> typing.Optional[int]: + return None + @functools.cached_property def session(self): sessions = [] @@ -304,6 +310,26 @@ def variables_introduced(self) -> int: def relation_ops_created(self) -> int: return 2 + @property + def is_limit(self) -> bool: + """Returns whether this is equivalent to a ORDER BY ... LIMIT N.""" + # TODO: Handle tail case. + return ( + (not self.start) + and (self.step == 1) + and (self.stop is not None) + and (self.stop > 0) + ) + + @property + def row_count(self) -> typing.Optional[int]: + child_length = self.child.row_count + if child_length is None: + return None + return slices.slice_output_rows( + (self.start, self.stop, self.step), child_length + ) + @dataclass(frozen=True, eq=False) class JoinNode(BigFrameNode): @@ -351,6 +377,15 @@ def variables_introduced(self) -> int: def joins(self) -> bool: return True + @property + def row_count(self) -> Optional[int]: + if self.type == "cross": + if self.left_child.row_count is None or self.right_child.row_count is None: + return None + return self.left_child.row_count * self.right_child.row_count + + return None + def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] ) -> BigFrameNode: @@ -412,6 +447,16 @@ def variables_introduced(self) -> int: """Defines the number of variables generated by the current node. Used to estimate query planning complexity.""" return len(self.schema.items) + OVERHEAD_VARIABLES + @property + def row_count(self) -> Optional[int]: + sub_counts = [node.row_count for node in self.child_nodes] + total = 0 + for count in sub_counts: + if count is None: + return None + total += count + return total + def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] ) -> BigFrameNode: @@ -460,6 +505,10 @@ def variables_introduced(self) -> int: """Defines the number of variables generated by the current node. Used to estimate query planning complexity.""" return len(self.schema.items) + OVERHEAD_VARIABLES + @property + def row_count(self) -> Optional[int]: + return None + def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] ) -> BigFrameNode: @@ -484,7 +533,11 @@ def roots(self) -> typing.Set[BigFrameNode]: return {self} @property - def supports_fast_head(self) -> bool: + def fast_offsets(self) -> bool: + return False + + @property + def fast_ordered_limit(self) -> bool: return False def transform_children( @@ -492,11 +545,6 @@ def transform_children( ) -> BigFrameNode: return self - @property - def row_count(self) -> typing.Optional[int]: - """How many rows are in the data source. None means unknown.""" - return None - class ScanItem(typing.NamedTuple): id: bfet_ids.ColumnId @@ -528,7 +576,11 @@ def variables_introduced(self) -> int: return len(self.scan_list.items) + 1 @property - def supports_fast_head(self) -> bool: + def fast_offsets(self) -> bool: + return True + + @property + def fast_ordered_limit(self) -> bool: return True @property @@ -635,12 +687,27 @@ def relation_ops_created(self) -> int: return 3 @property - def supports_fast_head(self) -> bool: - # Fast head is only supported when row offsets are available. - # In the future, ORDER BY+LIMIT optimizations may allow fast head when - # clustered and/or partitioned on ordering key + def fast_offsets(self) -> bool: + # Fast head is only supported when row offsets are available or data is clustered over ordering key. return (self.source.ordering is not None) and self.source.ordering.is_sequential + @property + def fast_ordered_limit(self) -> bool: + if self.source.ordering is None: + return False + order_cols = self.source.ordering.all_ordering_columns + # monotonicity would probably be fine + if not all(col.scalar_expression.is_identity for col in order_cols): + return False + order_col_ids = tuple( + cast(ex.DerefOp, col.scalar_expression).id.name for col in order_cols + ) + cluster_col_ids = self.source.table.cluster_cols + if cluster_col_ids is None: + return False + + return order_col_ids == cluster_col_ids[: len(order_col_ids)] + @property def order_ambiguous(self) -> bool: return ( @@ -706,6 +773,10 @@ def relation_ops_created(self) -> int: def variables_introduced(self) -> int: return 1 + @property + def row_count(self) -> Optional[int]: + return self.child.row_count + def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: if self.col_id not in used_cols: return self.child.prune(used_cols) @@ -726,6 +797,10 @@ def row_preserving(self) -> bool: def variables_introduced(self) -> int: return 1 + @property + def row_count(self) -> Optional[int]: + return None + def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: consumed_ids = used_cols.union(self.predicate.column_references) pruned_child = self.child.prune(consumed_ids) @@ -749,6 +824,10 @@ def relation_ops_created(self) -> int: def explicitly_ordered(self) -> bool: return True + @property + def row_count(self) -> Optional[int]: + return self.child.row_count + def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: ordering_cols = itertools.chain.from_iterable( map(lambda x: x.referenced_columns, self.by) @@ -772,6 +851,10 @@ def relation_ops_created(self) -> int: # Doesnt directly create any relational operations return 0 + @property + def row_count(self) -> Optional[int]: + return self.child.row_count + @dataclass(frozen=True, eq=False) class SelectionNode(UnaryNode): @@ -798,6 +881,10 @@ def variables_introduced(self) -> int: def defines_namespace(self) -> bool: return True + @property + def row_count(self) -> Optional[int]: + return self.child.row_count + def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: pruned_selections = tuple( select for select in self.input_output_pairs if select[1] in used_cols @@ -842,6 +929,10 @@ def variables_introduced(self) -> int: new_vars = sum(1 for i in self.assignments if not i[0].is_identity) return new_vars + @property + def row_count(self) -> Optional[int]: + return self.child.row_count + def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: pruned_assignments = tuple(i for i in self.assignments if i[1] in used_cols) if len(pruned_assignments) == 0: @@ -877,6 +968,10 @@ def variables_introduced(self) -> int: def defines_namespace(self) -> bool: return True + @property + def row_count(self) -> Optional[int]: + return 1 + @dataclass(frozen=True, eq=False) class AggregateNode(UnaryNode): @@ -926,6 +1021,12 @@ def explicitly_ordered(self) -> bool: def defines_namespace(self) -> bool: return True + @property + def row_count(self) -> Optional[int]: + if not self.by_column_ids: + return 1 + return None + def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: by_ids = (ref.id for ref in self.by_column_ids) pruned_aggs = tuple(agg for agg in self.aggregations if agg[1] in used_cols) @@ -963,6 +1064,10 @@ def relation_ops_created(self) -> int: # Assume that if not reprojecting, that there is a sequence of window operations sharing the same window return 0 if self.skip_reproject_unsafe else 4 + @property + def row_count(self) -> Optional[int]: + return self.child.row_count + @functools.cached_property def added_field(self) -> Field: input_type = self.child.get_type(self.column_name.id) @@ -994,6 +1099,10 @@ def row_preserving(self) -> bool: def variables_introduced(self) -> int: return 1 + @property + def row_count(self) -> Optional[int]: + return None + # TODO: Explode should create a new column instead of overriding the existing one @dataclass(frozen=True, eq=False) @@ -1030,6 +1139,10 @@ def variables_introduced(self) -> int: def defines_namespace(self) -> bool: return True + @property + def row_count(self) -> Optional[int]: + return None + def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: # Cannot prune explode op return self.transform_children( diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite.py index d4e530fff3..9c0eb81450 100644 --- a/bigframes/core/rewrite.py +++ b/bigframes/core/rewrite.py @@ -24,7 +24,7 @@ import bigframes.core.join_def as join_defs import bigframes.core.nodes as nodes import bigframes.core.ordering as order -import bigframes.core.tree_properties as traversals +import bigframes.core.slices as slices import bigframes.operations as ops Selection = Tuple[Tuple[scalar_exprs.Expression, ids.ColumnId], ...] @@ -385,46 +385,71 @@ def common_selection_root( return None +def pullup_limit_from_slice( + root: nodes.BigFrameNode, +) -> Tuple[nodes.BigFrameNode, Optional[int]]: + """ + This is a BQ-sql specific optimization that can be helpful as ORDER BY LIMIT is more efficient than WHERE + ROW_NUMBER(). + + Only use this if writing to an unclustered table. Clustering is not compatible with ORDER BY. + """ + if isinstance(root, nodes.SliceNode): + # head case + # More cases could be handled, but this is by far the most important, as it is used by df.head(), df[:N] + if root.is_limit: + assert not root.start + assert root.step == 1 + assert root.stop is not None + limit = root.stop + new_root, prior_limit = pullup_limit_from_slice(root.child) + if (prior_limit is not None) and (prior_limit < limit): + limit = prior_limit + return new_root, limit + elif ( + isinstance(root, (nodes.SelectionNode, nodes.ProjectionNode)) + and root.row_preserving + ): + new_child, prior_limit = pullup_limit_from_slice(root.child) + if prior_limit is not None: + return root.transform_children(lambda _: new_child), prior_limit + # Most ops don't support pulling up slice, like filter, agg, join, etc. + return root, None + + def replace_slice_ops(root: nodes.BigFrameNode) -> nodes.BigFrameNode: # TODO: we want to pull up some slices into limit op if near root. if isinstance(root, nodes.SliceNode): root = root.transform_children(replace_slice_ops) - return convert_slice_to_filter(cast(nodes.SliceNode, root)) + return rewrite_slice(cast(nodes.SliceNode, root)) else: return root.transform_children(replace_slice_ops) -def get_simplified_slice(node: nodes.SliceNode): - """Attempts to simplify the slice.""" - row_count = traversals.row_count(node) - start, stop, step = node.start, node.stop, node.step +def rewrite_slice(node: nodes.SliceNode): + slice_def = (node.start, node.stop, node.step) + + # no-op (eg. df[::1]) + if slices.is_noop(slice_def, node.child.row_count): + return node.child - if start is None: - start = 0 if step > 0 else -1 - if row_count and step > 0: - if start and start < 0: - start = row_count + start - if stop and stop < 0: - stop = row_count + stop - return start, stop, step + # No filtering, just reverse (eg. df[::-1]) + if slices.is_reverse(slice_def, node.child.row_count): + return nodes.ReversedNode(node.child) + if node.child.row_count: + slice_def = slices.to_forward_offsets(slice_def, node.child.row_count) + return slice_as_filter(node.child, *slice_def) -def convert_slice_to_filter(node: nodes.SliceNode): - start, stop, step = get_simplified_slice(node) - # no-op (eg. df[::1]) +def slice_as_filter( + node: nodes.BigFrameNode, start: Optional[int], stop: Optional[int], step: int +) -> nodes.BigFrameNode: if ( - ((start == 0) or (start is None)) - and ((stop is None) or (stop == -1)) - and (step == 1) + ((start is None) or (start >= 0)) + and ((stop is None) or (stop >= 0)) + and (step > 0) ): - return node.child - # No filtering, just reverse (eg. df[::-1]) - if ((start is None) or (start == -1)) and (not stop) and (step == -1): - return nodes.ReversedNode(node.child) - # if start/stop/step are all non-negative, and do a simple predicate on forward offsets - if ((start is None) or (start >= 0)) and ((stop is None) or (stop >= 0)): - node_w_offset = add_offsets(node.child) + node_w_offset = add_offsets(node) predicate = convert_simple_slice( scalar_exprs.DerefOp(node_w_offset.col_id), start or 0, stop, step ) @@ -433,17 +458,18 @@ def convert_slice_to_filter(node: nodes.SliceNode): # fallback cases, generate both forward and backward offsets if step < 0: - forward_offsets = add_offsets(node.child) + forward_offsets = add_offsets(node) reversed_offsets = add_offsets(nodes.ReversedNode(forward_offsets)) dual_indexed = reversed_offsets else: - reversed_offsets = add_offsets(nodes.ReversedNode(node.child)) + reversed_offsets = add_offsets(nodes.ReversedNode(node)) forward_offsets = add_offsets(nodes.ReversedNode(reversed_offsets)) dual_indexed = forward_offsets + default_start = 0 if step >= 0 else -1 predicate = convert_complex_slice( scalar_exprs.DerefOp(forward_offsets.col_id), scalar_exprs.DerefOp(reversed_offsets.col_id), - start, + start if (start is not None) else default_start, stop, step, ) @@ -505,7 +531,7 @@ def convert_complex_slice( if start or ((start is not None) and step < 0): if start > 0 and step > 0: start_cond = ops.ge_op.as_expr(forward_offsets, scalar_exprs.const(start)) - elif start > 0 and step < 0: + elif start >= 0 and step < 0: start_cond = ops.le_op.as_expr(forward_offsets, scalar_exprs.const(start)) elif start < 0 and step > 0: start_cond = ops.le_op.as_expr( diff --git a/bigframes/core/slices.py b/bigframes/core/slices.py new file mode 100644 index 0000000000..97f90d3349 --- /dev/null +++ b/bigframes/core/slices.py @@ -0,0 +1,106 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + + +def to_forward_offsets( + slice: tuple[Optional[int], Optional[int], Optional[int]], input_rows: int +) -> tuple[int, Optional[int], int]: + """Redefine the slice to use forward offsets for start and stop indices.""" + step = slice[2] or 1 + stop = slice[1] + start = slice[0] + + # normalize start to positive number + if start is None: + start = 0 if (step > 0) else (input_rows - 1) + elif start < 0: + start = max(0, input_rows + start) + else: + start = min(start, input_rows) + + if stop is None: + stop = None + elif stop < 0: + stop = max(0, input_rows + stop) + else: + stop = min(stop, input_rows) + + return (start, stop, step) + + +def remove_unused_parts( + slice: tuple[Optional[int], Optional[int], Optional[int]], input_rows: int +) -> tuple[Optional[int], Optional[int], Optional[int]]: + """Makes a slice component null if it doesn't impact slice semantics.""" + start, stop, step = slice + is_forward = (step is None) or (step > 0) + if start is not None: + if is_forward and ((start == 0) or (start <= -input_rows)): + start = None + elif (not is_forward) and ((start == -1) or (start >= (input_rows - 1))): + start = None + if stop is not None: + if is_forward and (stop >= input_rows): + stop = None + elif (not is_forward) and (stop <= (-input_rows - 1)): + stop = None + if step == 1: + step = None + return start, stop, step + + +def slice_output_rows( + slice: tuple[Optional[int], Optional[int], Optional[int]], input_size: int +) -> int: + """Given input_size, returns the number of rows returned after the slice operation.""" + slice = to_forward_offsets(slice, input_size) + start, stop, step = slice + + if step > 0: + if stop is None: + stop = input_size + length = max(0, (stop - start + step - 1) // step) + else: + if stop is None: + stop = -1 + length = max(0, (start - stop - step - 1) // -step) + return length + + +def is_noop( + slice_def: tuple[Optional[int], Optional[int], Optional[int]], + input_size: Optional[int], +) -> bool: + """Returns true iff the slice op is a no-op returning the input array.""" + if input_size: + start, stop, step = remove_unused_parts(slice_def, input_size) + else: + start, stop, step = slice_def + return (not start) and (stop is None) and ((step is None) or (step == 1)) + + +def is_reverse( + slice_def: tuple[Optional[int], Optional[int], Optional[int]], + input_size: Optional[int], +) -> bool: + """Returns true iff the slice op is a pure reverse op, equivalent to df[::-1]""" + if input_size: + start, stop, step = remove_unused_parts(slice_def, input_size) + else: + start, stop, step = slice_def + return (start is None) and (stop is None) and (step == -1) diff --git a/bigframes/core/tree_properties.py b/bigframes/core/tree_properties.py index 3e61b830a9..0a4339ee06 100644 --- a/bigframes/core/tree_properties.py +++ b/bigframes/core/tree_properties.py @@ -42,32 +42,35 @@ def can_fast_peek(node: nodes.BigFrameNode) -> bool: def can_fast_head(node: nodes.BigFrameNode) -> bool: """Can get head fast if can push head operator down to leafs and operators preserve rows.""" + # To do fast head operation: + # (1) the underlying data must be arranged/indexed according to the logical ordering + # (2) transformations must support pushing down LIMIT or a filter on row numbers + return has_fast_offset_address(node) or has_fast_offset_address(node) + + +def has_fast_orderby_limit(node: nodes.BigFrameNode) -> bool: + """True iff ORDER BY LIMIT can be performed without a large full table scan.""" + # TODO: In theory compatible with some Slice nodes, potentially by adding OFFSET + if isinstance(node, nodes.LeafNode): + return node.fast_ordered_limit + if isinstance(node, (nodes.ProjectionNode, nodes.SelectionNode)): + return has_fast_orderby_limit(node.child) + return False + + +def has_fast_offset_address(node: nodes.BigFrameNode) -> bool: + """True iff specific offsets can be scanned without a large full table scan.""" + # TODO: In theory can push offset lookups through slice operators by translating indices if isinstance(node, nodes.LeafNode): - return node.supports_fast_head + return node.fast_offsets if isinstance(node, (nodes.ProjectionNode, nodes.SelectionNode)): - return can_fast_head(node.child) + return has_fast_offset_address(node.child) return False def row_count(node: nodes.BigFrameNode) -> Optional[int]: """Determine row count from local metadata, return None if unknown.""" - if isinstance(node, nodes.LeafNode): - return node.row_count - if isinstance(node, nodes.AggregateNode): - if len(node.by_column_ids) == 0: - return 1 - return None - if isinstance(node, nodes.ConcatNode): - sub_counts = list(map(row_count, node.child_nodes)) - total = 0 - for count in sub_counts: - if count is None: - return None - total += count - return total - if isinstance(node, nodes.UnaryNode) and node.row_preserving: - return row_count(node.child) - return None + return node.row_count # Replace modified_cost(node) = cost(apply_cache(node)) diff --git a/bigframes/session/executor.py b/bigframes/session/executor.py index 1b58d1a993..170f0ac086 100644 --- a/bigframes/session/executor.py +++ b/bigframes/session/executor.py @@ -40,7 +40,6 @@ import bigframes.core import bigframes.core.compile -import bigframes.core.expression as ex import bigframes.core.guid import bigframes.core.identifiers import bigframes.core.nodes as nodes @@ -49,7 +48,6 @@ import bigframes.core.tree_properties as tree_properties import bigframes.features import bigframes.formatting_helpers as formatting_helpers -import bigframes.operations as ops import bigframes.session._io.bigquery as bq_io import bigframes.session.metrics import bigframes.session.planner @@ -127,7 +125,7 @@ def to_sql( col_id_overrides = dict(col_id_overrides) col_id_overrides[internal_offset_col] = offset_column node = ( - self._sub_cache_subtrees(array_value.node) + self.replace_cached_subtrees(array_value.node) if enable_cache else array_value.node ) @@ -207,6 +205,9 @@ def export_gbq( """ Export the ArrayValue to an existing BigQuery table. """ + if bigframes.options.compute.enable_multi_query_execution: + self._simplify_with_caching(array_value) + dispositions = { "fail": bigquery.WriteDisposition.WRITE_EMPTY, "replace": bigquery.WriteDisposition.WRITE_TRUNCATE, @@ -278,7 +279,7 @@ def peek( """ A 'peek' efficiently accesses a small number of rows in the dataframe. """ - plan = self._sub_cache_subtrees(array_value.node) + plan = self.replace_cached_subtrees(array_value.node) if not tree_properties.can_fast_peek(plan): warnings.warn("Peeking this value cannot be done efficiently.") @@ -313,7 +314,7 @@ def head( # No user-provided ordering, so just get any N rows, its faster! return self.peek(array_value, n_rows) - plan = self._sub_cache_subtrees(array_value.node) + plan = self.replace_cached_subtrees(array_value.node) if not tree_properties.can_fast_head(plan): # If can't get head fast, we are going to need to execute the whole query # Will want to do this in a way such that the result is reusable, but the first @@ -321,7 +322,7 @@ def head( # This currently requires clustering on offsets. self._cache_with_offsets(array_value) # Get a new optimized plan after caching - plan = self._sub_cache_subtrees(array_value.node) + plan = self.replace_cached_subtrees(array_value.node) assert tree_properties.can_fast_head(plan) head_plan = generate_head_plan(plan, n_rows) @@ -346,7 +347,7 @@ def get_row_count(self, array_value: bigframes.core.ArrayValue) -> int: if count is not None: return count else: - row_count_plan = self._sub_cache_subtrees( + row_count_plan = self.replace_cached_subtrees( generate_row_count_plan(array_value.node) ) sql = self.compiler.compile_unordered(row_count_plan) @@ -358,7 +359,7 @@ def _local_get_row_count( ) -> Optional[int]: # optimized plan has cache materializations which will have row count metadata # that is more likely to be usable than original leaf nodes. - plan = self._sub_cache_subtrees(array_value.node) + plan = self.replace_cached_subtrees(array_value.node) return tree_properties.row_count(plan) # Helpers @@ -423,13 +424,7 @@ def _wait_on_job( self.metrics.count_job_stats(query_job) return results_iterator - def _sub_cache_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode: - """ - Takes the original expression tree and applies optimizations to accelerate execution. - - At present, the only optimization is to replace subtress with cached previous materializations. - """ - # Apply any rewrites *after* applying cache, as cache is sensitive to exact tree structure + def replace_cached_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode: return tree_properties.replace_nodes(node, (dict(self._cached_executions))) def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue): @@ -440,7 +435,7 @@ def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue): # Once rewriting is available, will want to rewrite before # evaluating execution cost. return tree_properties.is_trivially_executable( - self._sub_cache_subtrees(array_value.node) + self.replace_cached_subtrees(array_value.node) ) def _cache_with_cluster_cols( @@ -449,7 +444,7 @@ def _cache_with_cluster_cols( """Executes the query and uses the resulting table to rewrite future executions.""" sql, schema, ordering_info = self.compiler.compile_raw( - self._sub_cache_subtrees(array_value.node) + self.replace_cached_subtrees(array_value.node) ) tmp_table = self._sql_as_cached_temp_table( sql, @@ -466,7 +461,9 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue): """Executes the query and uses the resulting table to rewrite future executions.""" offset_column = bigframes.core.guid.generate_guid("bigframes_offsets") w_offsets, offset_column = array_value.promote_offsets() - sql = self.compiler.compile_unordered(self._sub_cache_subtrees(w_offsets.node)) + sql = self.compiler.compile_unordered( + self.replace_cached_subtrees(w_offsets.node) + ) tmp_table = self._sql_as_cached_temp_table( sql, @@ -502,7 +499,7 @@ def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue): """Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces.""" # Apply existing caching first for _ in range(MAX_SUBTREE_FACTORINGS): - node_with_cache = self._sub_cache_subtrees(array_value.node) + node_with_cache = self.replace_cached_subtrees(array_value.node) if node_with_cache.planning_complexity < QUERY_COMPLEXITY_LIMIT: return @@ -559,7 +556,7 @@ def _validate_result_schema( ): actual_schema = tuple(bq_schema) ibis_schema = bigframes.core.compile.test_only_ibis_inferred_schema( - self._sub_cache_subtrees(array_value.node) + self.replace_cached_subtrees(array_value.node) ) internal_schema = array_value.schema if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable: @@ -575,20 +572,7 @@ def _validate_result_schema( def generate_head_plan(node: nodes.BigFrameNode, n: int): - offsets_id = bigframes.core.guid.generate_guid("offsets_") - plan_w_offsets = nodes.PromoteOffsetsNode( - node, bigframes.core.identifiers.ColumnId(offsets_id) - ) - predicate = ops.lt_op.as_expr(ex.deref(offsets_id), ex.const(n)) - plan_w_head = nodes.FilterNode(plan_w_offsets, predicate) - # Finally, drop the offsets column - return nodes.SelectionNode( - plan_w_head, - tuple( - (ex.deref(i), bigframes.core.identifiers.ColumnId(i)) - for i in node.schema.names - ), - ) + return nodes.SliceNode(node, start=None, stop=n) def generate_row_count_plan(node: nodes.BigFrameNode): diff --git a/tests/unit/core/test_rewrite.py b/tests/unit/core/test_rewrite.py new file mode 100644 index 0000000000..0965238fcd --- /dev/null +++ b/tests/unit/core/test_rewrite.py @@ -0,0 +1,57 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest.mock as mock + +import google.cloud.bigquery + +import bigframes.core as core +import bigframes.core.nodes as nodes +import bigframes.core.rewrite as rewrites +import bigframes.core.schema + +TABLE_REF = google.cloud.bigquery.TableReference.from_string("project.dataset.table") +SCHEMA = ( + google.cloud.bigquery.SchemaField("col_a", "INTEGER"), + google.cloud.bigquery.SchemaField("col_b", "INTEGER"), +) +TABLE = google.cloud.bigquery.Table( + table_ref=TABLE_REF, + schema=SCHEMA, +) +FAKE_SESSION = mock.create_autospec(bigframes.Session, instance=True) +type(FAKE_SESSION)._strictly_ordered = mock.PropertyMock(return_value=True) +LEAF = core.ArrayValue.from_table( + session=FAKE_SESSION, + table=TABLE, + schema=bigframes.core.schema.ArraySchema.from_bq_table(TABLE), +).node + + +def test_rewrite_noop_slice(): + slice = nodes.SliceNode(LEAF, None, None) + result = rewrites.rewrite_slice(slice) + assert result == LEAF + + +def test_rewrite_reverse_slice(): + slice = nodes.SliceNode(LEAF, None, None, -1) + result = rewrites.rewrite_slice(slice) + assert result == nodes.ReversedNode(LEAF) + + +def test_rewrite_filter_slice(): + slice = nodes.SliceNode(LEAF, None, 2) + result = rewrites.rewrite_slice(slice) + assert list(result.fields) == list(LEAF.fields) + assert isinstance(result.child, nodes.FilterNode) diff --git a/tests/unit/core/test_slices.py b/tests/unit/core/test_slices.py new file mode 100644 index 0000000000..745db45eab --- /dev/null +++ b/tests/unit/core/test_slices.py @@ -0,0 +1,61 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import bigframes.core.slices as slices + + +@pytest.mark.parametrize( + ["slice", "input_rows", "expected"], + [ + ((1, 2, 3), 3, 1), + ((-3, 400, None), 401, 2), + ((5, 505, None), 300, 295), + ((1, 10, 4), 10, 3), + ((1, 9, 4), 10, 2), + ((-1, -10, -4), 10, 3), + ((-1, -10, 4), 10, 0), + ((99, 100, 1), 9, 0), + ], +) +def test_slice_row_count(slice, input_rows, expected): + assert expected == slices.slice_output_rows(slice, input_rows) + + +@pytest.mark.parametrize( + ["slice", "input_rows", "expected"], + [ + ((1, 2, 3), 3, (1, 2, 3)), + ((-3, 400, None), 401, (-3, 400, None)), + ((5, 505, None), 300, (5, None, None)), + ((99, 100, 1), 9, (99, None, None)), + ], +) +def test_remove_unused_parts(slice, input_rows, expected): + assert expected == slices.remove_unused_parts(slice, input_rows) + + +@pytest.mark.parametrize( + ["slice", "input_rows", "expected"], + [ + ((1, 2, 3), 3, (1, 2, 3)), + ((-3, 400, None), 401, (398, 400, 1)), + ((5, 505, None), 300, (5, 300, 1)), + ((None, None, None), 300, (0, None, 1)), + ((None, None, -1), 300, (299, None, -1)), + ], +) +def test_to_forward_offsets(slice, input_rows, expected): + assert expected == slices.to_forward_offsets(slice, input_rows)