Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
275 changes: 252 additions & 23 deletions src/main/cpp/src/timezones.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
#include <cudf/table/table.hpp>
#include <cudf/transform.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/bit.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <cuda/std/array>
#include <cuda/std/functional>
#include <thrust/binary_search.h>

Expand Down Expand Up @@ -242,49 +244,276 @@ std::unique_ptr<column> convert_to_utc_with_multiple_timezones(
// =================== ORC timezones begin ===================
// ORC timezone uses java.util.TimeZone rules, which is different from java.time.ZoneId rules.

// ---- Calendar helpers for DST computation on GPU ----
// Ported from java.util.SimpleTimeZone.getOffset logic.

constexpr int32_t MS_PER_SECOND = 1000;
constexpr int32_t MS_PER_MINUTE = 60 * MS_PER_SECOND;
constexpr int32_t MS_PER_HOUR = 60 * MS_PER_MINUTE;
constexpr int64_t MS_PER_DAY = 24LL * MS_PER_HOUR;

// Cumulative days before each month (non-leap). Index 0 = Jan.
// __device__ storage is required because days_in_month / days_before_month index
// this with non-compile-time values, which addresses elements at runtime.
__device__ constexpr cuda::std::array<int32_t, 12> DAYS_BEFORE_MONTH = {
0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334};

__device__ static bool is_leap_year(int32_t year)
{
return (year % 4 == 0) && ((year % 100 != 0) || (year % 400 == 0));
}

__device__ static int32_t days_in_month(int32_t month, int32_t year)
{
int32_t d =
(month < 11) ? (DAYS_BEFORE_MONTH[month + 1] - DAYS_BEFORE_MONTH[month]) : 31; // December
if (month == 1 && is_leap_year(year)) { d++; }
return d;
}

__device__ static int32_t days_before_month(int32_t month, int32_t year)
{
int32_t d = DAYS_BEFORE_MONTH[month];
if (month > 1 && is_leap_year(year)) { d++; }
return d;
}

// Days from epoch (1970-01-01) to Jan 1 of the given year.
__device__ static int64_t days_from_epoch_to_year(int32_t year)
{
int32_t y = year - 1;
// Gregorian calendar days from epoch
return 365LL * (year - 1970) + (y / 4 - 492) - (y / 100 - 19) + (y / 400 - 4);
}

// DST rule mode constants (same as SimpleTimeZone)
enum dst_rule_mode : int32_t {
DOM_MODE = 0,
DOW_IN_MONTH_MODE = 1,
DOW_GE_DOM_MODE = 2,
DOW_LE_DOM_MODE = 3
};

// Time mode constants
enum dst_time_mode : int32_t { WALL_TIME = 0, STANDARD_TIME = 1, UTC_TIME = 2 };

/**
* @brief Compute the day-of-month when a DST rule triggers for the given year and month.
*
* Implements the same logic as SimpleTimeZone's rule decoding:
* - DOM_MODE: exact day of month
* - DOW_IN_MONTH_MODE: nth occurrence of dayOfWeek (negative = from end)
* - DOW_GE_DOM_MODE: first dayOfWeek on or after the given day
* - DOW_LE_DOM_MODE: last dayOfWeek on or before the given day
*/
__device__ static int32_t compute_rule_day(
int32_t rule_mode, int32_t rule_day, int32_t rule_dow, int32_t year, int32_t month)
{
int32_t month_len = days_in_month(month, year);

// Compute day-of-week of the 1st of the month
int64_t first_of_month_epoch_days =
days_from_epoch_to_year(year) + days_before_month(month, year);
int64_t dow_raw = ((first_of_month_epoch_days + 4) % 7);
if (dow_raw < 0) dow_raw += 7;
int32_t first_dow = static_cast<int32_t>(dow_raw) + 1; // 1=Sun..7=Sat

switch (rule_mode) {
case DOM_MODE: return rule_day;

case DOW_IN_MONTH_MODE: {
if (rule_day > 0) {
// nth occurrence: 1st=first week, etc.
int32_t diff = rule_dow - first_dow;
if (diff < 0) diff += 7;
return 1 + diff + (rule_day - 1) * 7;
} else {
// negative: from end of month. -1 = last occurrence.
int32_t last_dow = ((first_dow - 1 + (month_len - 1)) % 7) + 1;
int32_t diff = last_dow - rule_dow;
if (diff < 0) diff += 7;
return month_len - diff + (rule_day + 1) * 7;
}
}

case DOW_GE_DOM_MODE: {
// First rule_dow on or after rule_day
int64_t target_epoch = first_of_month_epoch_days + (rule_day - 1);
int64_t target_raw = ((target_epoch + 4) % 7);
if (target_raw < 0) target_raw += 7;
int32_t target_dow = static_cast<int32_t>(target_raw) + 1;
int32_t diff = rule_dow - target_dow;
if (diff < 0) diff += 7;
return rule_day + diff;
}

case DOW_LE_DOM_MODE: {
// Last rule_dow on or before rule_day
int64_t target_epoch = first_of_month_epoch_days + (rule_day - 1);
int64_t target_raw = ((target_epoch + 4) % 7);
if (target_raw < 0) target_raw += 7;
int32_t target_dow = static_cast<int32_t>(target_raw) + 1;
int32_t diff = target_dow - rule_dow;
if (diff < 0) diff += 7;
return rule_day - diff;
}

default: return rule_day;
}
}

/**
* @brief Compute the UTC millis of a DST transition for a given year.
*
* @param year The calendar year.
* @param rule_month 0-based month.
* @param rule_day Day parameter of the rule.
* @param rule_dow Day-of-week parameter.
* @param rule_time Time-of-day in ms.
* @param rule_time_mode WALL_TIME / STANDARD_TIME / UTC_TIME.
* @param rule_mode DOM / DOW_IN_MONTH / DOW_GE_DOM / DOW_LE_DOM.
* @param raw_offset_ms The timezone raw offset in ms.
* @param dst_savings_ms The DST savings in ms (needed for WALL_TIME conversion).
* @param is_start_rule True for DST start (to determine WALL_TIME adjustment).
*/
__device__ static int64_t compute_transition_utc_ms(int32_t year,
int32_t rule_month,
int32_t rule_day,
int32_t rule_dow,
int32_t rule_time,
int32_t rule_time_mode,
int32_t rule_mode,
int32_t raw_offset_ms,
int32_t dst_savings_ms,
bool is_start_rule)
{
int32_t actual_day = compute_rule_day(rule_mode, rule_day, rule_dow, year, rule_month);
Comment on lines +336 to +390

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 compute_rule_day can return a day outside [1, month_len]

DOW_GE_DOM_MODE returns rule_day + diff (0–6 added), so a rule_day of 27 in a 28-day month with diff = 3 yields 30 — already out of range. Similarly, DOW_IN_MONTH_MODE with a large positive occurrence can overshoot. Java's SimpleTimeZone validates its own rules before storing them, so inputs from the serialized DST rule table should be safe, but compute_transition_utc_ms ultimately uses the result as a day-of-month offset into the month without any clamp, so a corrupted/unexpected rule would silently produce a wrong UTC timestamp (rather than a detectable error).

int64_t epoch_days =
days_from_epoch_to_year(year) + days_before_month(rule_month, year) + (actual_day - 1);
int64_t utc_ms = epoch_days * MS_PER_DAY + rule_time;

// Convert from the specified time mode to UTC
switch (rule_time_mode) {
case WALL_TIME:
utc_ms -= raw_offset_ms;
// Wall time during DST-end means DST is still active, subtract savings.
// Wall time during DST-start means DST is not yet active.
if (!is_start_rule) { utc_ms -= dst_savings_ms; }
break;
case STANDARD_TIME: utc_ms -= raw_offset_ms; break;
case UTC_TIME: break;
}

return utc_ms;
}

// Lightweight year extraction from epoch millis (no month/day computation).
__device__ static int32_t millis_to_year(int64_t epoch_ms)
{
int64_t day_count =
(epoch_ms >= 0) ? epoch_ms / MS_PER_DAY : (epoch_ms - MS_PER_DAY + 1) / MS_PER_DAY;
int64_t days_since_1 = day_count + 719468;
int32_t era =
static_cast<int32_t>((days_since_1 >= 0 ? days_since_1 : days_since_1 - 146096) / 146097);
int32_t doe = static_cast<int32_t>(days_since_1 - static_cast<int64_t>(era) * 146097);
int32_t yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
int32_t year = yoe + era * 400;
int32_t doy = doe - (365 * yoe + yoe / 4 - yoe / 100 + yoe / 400);
int32_t mp = (5 * doy + 2) / 153;
int32_t month = mp < 10 ? mp + 2 : mp - 10;
if (month <= 1) { year++; }
return year;
}

/**
* @brief Get the transition index for the given time `time_ms` using binary search.
* Find the first transition that is greater or equal to `time_ms`,
* and return the corresponding offset.
* @brief Compute the total UTC offset (raw + DST) for a UTC timestamp using DST rules.
*
* @param begin the beginning of the transition array.
* @param end the end of the transition array.
* @param time_ms the input time in milliseconds to find the transition index for.
* @param offset_begin the beginning of the offset array.
* @param offset_end the end of the offset array.
* This is the GPU equivalent of java.util.SimpleTimeZone.getOffset(long).
* It computes the DST start and end transitions for the year containing the
Comment on lines 428 to +432

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 millis_to_year uses only raw_offset_ms; may derive wrong year near DST-active year boundaries

millis_to_year(utc_ms + raw_offset_ms) approximates the local year using the standard (non-DST) offset. When DST is active the true local time is utc_ms + raw_offset_ms + dst_savings. For a Southern-Hemisphere timezone (e.g., UTC+11, +1h DST) a UTC timestamp at Dec 31 ~22:xx could map to Jan 1 via the raw offset but only to Dec 31 if the extra DST hour is missing. The Northern-Hemisphere path is fine because DST periods never straddle Jan 1, but for Southern-Hemisphere transitions that span the year boundary the OR-based check (utc_ms >= dst_start || utc_ms < dst_end) can give a correct result even with the wrong year, so in practice this may not surface — but it is an undocumented fragility that should be either explained in a comment or resolved by iterating DST windows for year and year ± 1.

* timestamp, then checks if the timestamp falls within the DST window.
*
* @return the offset.
* Handles both Northern Hemisphere (start < end) and Southern Hemisphere
* (start > end, i.e., DST spans year boundary).
*/
__device__ static int32_t compute_dst_offset(int64_t utc_ms,
int32_t raw_offset_ms,
spark_rapids_jni::dst_rule const& rule)
{
if (!rule.has_dst) { return raw_offset_ms; }

int32_t year = millis_to_year(utc_ms + raw_offset_ms);

// Compute DST-on and DST-off transitions in UTC for this year
int64_t dst_start = compute_transition_utc_ms(year,
rule.start_month,
rule.start_day,
rule.start_dow,
rule.start_time,
rule.start_time_mode,
rule.start_mode,
raw_offset_ms,
rule.dst_savings,
true);

int64_t dst_end = compute_transition_utc_ms(year,
rule.end_month,
rule.end_day,
rule.end_dow,
rule.end_time,
rule.end_time_mode,
rule.end_mode,
raw_offset_ms,
rule.dst_savings,
false);

bool in_dst;
if (dst_start < dst_end) {
// Northern Hemisphere: DST is [start, end)
in_dst = (utc_ms >= dst_start && utc_ms < dst_end);
} else {
// Southern Hemisphere: DST is [start, year_end) ∪ [year_start, end)
in_dst = (utc_ms >= dst_start || utc_ms < dst_end);
}

return in_dst ? raw_offset_ms + rule.dst_savings : raw_offset_ms;
}

/**
* @brief Get the offset for a UTC time using the transition table + DST rule fallback.
*
* For timestamps within the transition table range, uses binary search.
* For timestamps beyond the table, uses DST rule computation.
* For timestamps before the first recorded transition, falls back to the
* historical initial offset to match java.util.TimeZone behavior.
*/
__device__ static int32_t get_transition_index(int64_t const* begin,
int64_t const* end,
int64_t time_ms,
int32_t const* offset_begin,
int32_t const* offset_end,
int32_t raw_offset)
int32_t initial_offset,
int32_t raw_offset,
spark_rapids_jni::dst_rule const& rule)
{
if (begin == end) {
// fixed offset timezone, no transitions
return raw_offset;
// No transition table. Use DST rule if available, else fixed offset.
return compute_dst_offset(time_ms, raw_offset, rule);
}

// upper_bound returns the first element strictly greater than time_ms, so
// *iter > time_ms and the index we want is iter - 1.
auto const iter = thrust::upper_bound(thrust::seq, begin, end, time_ms);
if (iter == end) {
// after the transition table, returns the raw offset
return raw_offset;
// Beyond the transition table -- use DST rule for future dates
return compute_dst_offset(time_ms, raw_offset, rule);
}

int32_t index = static_cast<int32_t>(cuda::std::distance(begin, iter));
if (*iter == time_ms) {
// find exact match, return the offset at that index
return offset_begin[index];
}

if (index == 0) {
// prior to the transition table, returns the raw offset
return raw_offset;
// Before the first recorded transition, java.util.TimeZone uses the
// historical offset in effect before that transition, not the future rule.
return initial_offset;
}

// return the offset at the previous index
return offset_begin[index - 1];
}

Expand Down