Skip to content

Commit 86b7e72

Browse files
authored
chore: support timestamp subtractions (#1346)
* chore: support timestamp subtractions * Fix format * use tree rewrites to dispatch timestamp_diff operator * add TODO for more node updates * polish the code and fix typos * fix comment * add rewrites to compile_raw and compile_peek_sql
1 parent b9bdca8 commit 86b7e72

File tree

13 files changed

+208
-8
lines changed

13 files changed

+208
-8
lines changed

bigframes/core/compile/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def compile_sql(
5858
# TODO: get rid of output_ids arg
5959
assert len(output_ids) == len(list(node.fields))
6060
node = set_output_names(node, output_ids)
61+
node = nodes.top_down(node, rewrites.rewrite_timedelta_ops)
6162
if ordered:
6263
node, limit = rewrites.pullup_limit_from_slice(node)
6364
node = nodes.bottom_up(node, rewrites.rewrite_slice)
@@ -81,6 +82,7 @@ def compile_sql(
8182
def compile_peek_sql(self, node: nodes.BigFrameNode, n_rows: int) -> str:
8283
ids = [id.sql for id in node.ids]
8384
node = nodes.bottom_up(node, rewrites.rewrite_slice)
85+
node = nodes.top_down(node, rewrites.rewrite_timedelta_ops)
8486
node, _ = rewrites.pull_up_order(
8587
node, order_root=False, ordered_joins=self.strict
8688
)
@@ -93,13 +95,15 @@ def compile_raw(
9395
str, typing.Sequence[google.cloud.bigquery.SchemaField], bf_ordering.RowOrdering
9496
]:
9597
node = nodes.bottom_up(node, rewrites.rewrite_slice)
98+
node = nodes.top_down(node, rewrites.rewrite_timedelta_ops)
9699
node, ordering = rewrites.pull_up_order(node, ordered_joins=self.strict)
97100
ir = self.compile_node(node)
98101
sql = ir.to_sql()
99102
return sql, node.schema.to_bigquery(), ordering
100103

101104
def _preprocess(self, node: nodes.BigFrameNode):
102105
node = nodes.bottom_up(node, rewrites.rewrite_slice)
106+
node = nodes.top_down(node, rewrites.rewrite_timedelta_ops)
103107
node, _ = rewrites.pull_up_order(
104108
node, order_root=False, ordered_joins=self.strict
105109
)

bigframes/core/compile/ibis_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
BIGFRAMES_TO_IBIS: Dict[bigframes.dtypes.Dtype, ibis_dtypes.DataType] = {
8080
pandas: ibis for ibis, pandas in BIDIRECTIONAL_MAPPINGS
8181
}
82-
BIGFRAMES_TO_IBIS.update({bigframes.dtypes.TIMEDETLA_DTYPE: ibis_dtypes.int64})
82+
BIGFRAMES_TO_IBIS.update({bigframes.dtypes.TIMEDELTA_DTYPE: ibis_dtypes.int64})
8383
IBIS_TO_BIGFRAMES: Dict[ibis_dtypes.DataType, bigframes.dtypes.Dtype] = {
8484
ibis: pandas for ibis, pandas in BIDIRECTIONAL_MAPPINGS
8585
}

bigframes/core/compile/scalar_op_compiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,11 @@ def unix_millis_op_impl(x: ibis_types.TimestampValue):
737737
return unix_millis(x)
738738

739739

740+
@scalar_op_compiler.register_binary_op(ops.timestamp_diff_op)
741+
def timestamp_diff_op_impl(x: ibis_types.TimestampValue, y: ibis_types.TimestampValue):
742+
return x.delta(y, "microsecond")
743+
744+
740745
@scalar_op_compiler.register_unary_op(ops.FloorDtOp, pass_op=True)
741746
def floor_dt_op_impl(x: ibis_types.Value, op: ops.FloorDtOp):
742747
supported_freqs = ["Y", "Q", "M", "W", "D", "h", "min", "s", "ms", "us", "ns"]

bigframes/core/rewrite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
from bigframes.core.rewrite.identifiers import remap_variables
1616
from bigframes.core.rewrite.implicit_align import try_row_join
1717
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
18+
from bigframes.core.rewrite.operators import rewrite_timedelta_ops
1819
from bigframes.core.rewrite.order import pull_up_order
1920
from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice
2021

2122
__all__ = [
2223
"legacy_join_as_projection",
2324
"try_row_join",
2425
"rewrite_slice",
26+
"rewrite_timedelta_ops",
2527
"pullup_limit_from_slice",
2628
"remap_variables",
2729
"pull_up_order",

bigframes/core/rewrite/operators.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import dataclasses
16+
import functools
17+
import typing
18+
19+
from bigframes import dtypes
20+
from bigframes import operations as ops
21+
from bigframes.core import expression as ex
22+
from bigframes.core import nodes, schema
23+
24+
25+
@dataclasses.dataclass
26+
class _TypedExpr:
27+
expr: ex.Expression
28+
dtype: dtypes.Dtype
29+
30+
31+
def rewrite_timedelta_ops(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
32+
"""
33+
Rewrites expressions to properly handle timedelta values, because this type does not exist
34+
in the SQL world.
35+
"""
36+
if isinstance(root, nodes.ProjectionNode):
37+
updated_assignments = tuple(
38+
(_rewrite_expressions(expr, root.schema).expr, column_id)
39+
for expr, column_id in root.assignments
40+
)
41+
root = nodes.ProjectionNode(root.child, updated_assignments)
42+
43+
# TODO(b/394354614): FilterByNode and OrderNode also contain expressions. Need to update them too.
44+
return root
45+
46+
47+
@functools.cache
48+
def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _TypedExpr:
49+
if isinstance(expr, ex.DerefOp):
50+
return _TypedExpr(expr, schema.get_type(expr.id.sql))
51+
52+
if isinstance(expr, ex.ScalarConstantExpression):
53+
return _TypedExpr(expr, expr.dtype)
54+
55+
if isinstance(expr, ex.OpExpression):
56+
updated_inputs = tuple(
57+
map(lambda x: _rewrite_expressions(x, schema), expr.inputs)
58+
)
59+
return _rewrite_op_expr(expr, updated_inputs)
60+
61+
raise AssertionError(f"Unexpected expression type: {type(expr)}")
62+
63+
64+
def _rewrite_op_expr(
65+
expr: ex.OpExpression, inputs: typing.Tuple[_TypedExpr, ...]
66+
) -> _TypedExpr:
67+
if isinstance(expr.op, ops.SubOp):
68+
return _rewrite_sub_op(inputs[0], inputs[1])
69+
70+
input_types = tuple(map(lambda x: x.dtype, inputs))
71+
return _TypedExpr(expr, expr.op.output_type(*input_types))
72+
73+
74+
def _rewrite_sub_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
75+
result_op: ops.BinaryOp = ops.sub_op
76+
if dtypes.is_datetime_like(left.dtype) and dtypes.is_datetime_like(right.dtype):
77+
result_op = ops.timestamp_diff_op
78+
79+
return _TypedExpr(
80+
result_op.as_expr(left.expr, right.expr),
81+
result_op.output_type(left.dtype, right.dtype),
82+
)

bigframes/dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
TIME_DTYPE = pd.ArrowDtype(pa.time64("us"))
5757
DATETIME_DTYPE = pd.ArrowDtype(pa.timestamp("us"))
5858
TIMESTAMP_DTYPE = pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
59-
TIMEDETLA_DTYPE = pd.ArrowDtype(pa.duration("us"))
59+
TIMEDELTA_DTYPE = pd.ArrowDtype(pa.duration("us"))
6060
NUMERIC_DTYPE = pd.ArrowDtype(pa.decimal128(38, 9))
6161
BIGNUMERIC_DTYPE = pd.ArrowDtype(pa.decimal256(76, 38))
6262
# No arrow equivalent

bigframes/operations/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
date_op,
5050
StrftimeOp,
5151
time_op,
52+
timestamp_diff_op,
5253
ToDatetimeOp,
5354
ToTimestampOp,
5455
UnixMicros,
@@ -125,6 +126,7 @@
125126
sinh_op,
126127
sqrt_op,
127128
sub_op,
129+
SubOp,
128130
tan_op,
129131
tanh_op,
130132
unsafe_pow_op,
@@ -246,6 +248,7 @@
246248
# Datetime ops
247249
"date_op",
248250
"time_op",
251+
"timestamp_diff_op",
249252
"ToDatetimeOp",
250253
"ToTimestampOp",
251254
"StrftimeOp",
@@ -283,6 +286,7 @@
283286
"sinh_op",
284287
"sqrt_op",
285288
"sub_op",
289+
"SubOp",
286290
"tan_op",
287291
"tanh_op",
288292
"unsafe_pow_op",

bigframes/operations/datetime_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,22 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
107107
if input_types[0] is not dtypes.TIMESTAMP_DTYPE:
108108
raise TypeError("expected timestamp input")
109109
return dtypes.INT_DTYPE
110+
111+
112+
@dataclasses.dataclass(frozen=True)
113+
class TimestampDiff(base_ops.BinaryOp):
114+
name: typing.ClassVar[str] = "timestamp_diff"
115+
116+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
117+
if input_types[0] is not input_types[1]:
118+
raise TypeError(
119+
f"two inputs have different types. left: {input_types[0]}, right: {input_types[1]}"
120+
)
121+
122+
if not dtypes.is_datetime_like(input_types[0]):
123+
raise TypeError("expected timestamp input")
124+
125+
return dtypes.TIMEDELTA_DTYPE
126+
127+
128+
timestamp_diff_op = TimestampDiff()

bigframes/operations/numeric_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,10 @@ def output_type(self, *input_types):
141141
):
142142
# Numeric subtraction
143143
return dtypes.coerce_to_common(left_type, right_type)
144-
# TODO: Add temporal addition once delta types supported
144+
145+
if dtypes.is_datetime_like(left_type) and dtypes.is_datetime_like(right_type):
146+
return dtypes.TIMEDELTA_DTYPE
147+
145148
raise TypeError(f"Cannot subtract dtypes {left_type} and {right_type}")
146149

147150

bigframes/operations/timedelta_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ class ToTimedeltaOp(base_ops.UnaryOp):
2828
def output_type(self, *input_types):
2929
if input_types[0] is not dtypes.INT_DTYPE:
3030
raise TypeError("expected integer input")
31-
return dtypes.TIMEDETLA_DTYPE
31+
return dtypes.TIMEDELTA_DTYPE

0 commit comments

Comments
 (0)