Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ nil

### Features

nil
- [650](https://github.com/Shopify/job-iteration/pull/650) Add support for batch enumeration over models with composite primary keys.

### Bug fixes

Expand Down
67 changes: 36 additions & 31 deletions lib/job-iteration/active_record_batch_enumerator.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# frozen_string_literal: true

require_relative "active_record_batch_enumerator/column_manager"

module JobIteration
# Builds Batch Enumerator based on ActiveRecord Relation.
# @see EnumeratorBuilder
Expand All @@ -11,26 +13,15 @@ class ActiveRecordBatchEnumerator
def initialize(relation, columns: nil, batch_size: 100, timezone: nil, cursor: nil)
@batch_size = batch_size
@timezone = timezone
@primary_key = "#{relation.table_name}.#{relation.primary_key}"
@columns = Array(columns&.map(&:to_s) || @primary_key)
@primary_key_index = @columns.index(@primary_key) || @columns.index(relation.primary_key)
@pluck_columns = if @primary_key_index
@columns
else
@columns.dup << @primary_key
end
@column_mgr = ColumnManager.new(relation: relation, columns: columns)
@cursor = Array.wrap(cursor)
@initial_cursor = @cursor
raise ArgumentError, "Must specify at least one column" if @columns.empty?
if relation.joins_values.present? && [email protected]? { |column| column.to_s.include?(".") }
raise ArgumentError, "You need to specify fully-qualified columns if you join a table"
end

if relation.arel.orders.present? || relation.arel.taken.present?
raise JobIteration::ActiveRecordCursor::ConditionNotSupportedError
end

@base_relation = relation.reorder(@columns.join(","))
@base_relation = relation.reorder(@column_mgr.columns.join(","))
end

def each
Expand All @@ -53,7 +44,7 @@ def next_batch
relation = relation.where(*conditions)
end

cursor_values, ids = relation.uncached do
cursor_values, pkey_ids = relation.uncached do
pluck_columns(relation)
end

Expand All @@ -62,25 +53,39 @@ def next_batch
@cursor = @initial_cursor
return
end
# The primary key was plucked, but original cursor did not include it, so we should remove it
cursor.pop unless @primary_key_index
@cursor = Array.wrap(cursor)

# Yields relations by selecting the primary keys of records in the batch.
# Post.where(published: nil) results in an enumerator of relations like:
# Post.where(published: nil, ids: batch_of_ids)
@base_relation.where(@primary_key => ids)
@cursor = @column_mgr.remove_missing_pkey_values(cursor)

filter_relation_with_primary_key(pkey_ids)
end

def pluck_columns(relation)
if @pluck_columns.size == 1 # only the primary key
column_values = relation.pluck(*@pluck_columns)
return [column_values, column_values]
# Yields relations by selecting the primary keys of records in the batch.
# Post.where(published: nil) results in an enumerator of relations like:
# Post.where(published: nil, ids: batch_of_ids)
def filter_relation_with_primary_key(primary_key_values)
pkey = @column_mgr.primary_key
pkey_values = primary_key_values

# If the primary key is only composed of a single column, simplify the
# query. This keeps us compatible with Rails prior to 7.1 where composite
# primary keys were introduced along with the syntax that allows you to
# query for multi-column values.
if pkey.size <= 1
pkey = pkey.first
pkey_values = pkey_values.map(&:first)
end

column_values = relation.pluck(*@pluck_columns)
primary_key_index = @primary_key_index || -1
primary_key_values = column_values.map { |values| values[primary_key_index] }
@base_relation.where(pkey => pkey_values)
end

def pluck_columns(relation)
column_values = relation.pluck(*@column_mgr.pluck_columns)

# Pluck behaves differently when only one column is given. By using zip,
# we make the output consistent (at the cost of more object allocation).
column_values = column_values.zip if @column_mgr.pluck_columns.size == 1

primary_key_values = @column_mgr.pkey_values(column_values)

serialize_column_values!(column_values)
[column_values, primary_key_values]
Expand All @@ -94,15 +99,15 @@ def cursor_value

def conditions
column_index = @cursor.size - 1
column = @columns[column_index]
where_clause = if @columns.size == @cursor.size
column = @column_mgr.columns[column_index]
where_clause = if @column_mgr.columns.size == @cursor.size
"#{column} > ?"
else
"#{column} >= ?"
end
while column_index > 0
column_index -= 1
column = @columns[column_index]
column = @column_mgr.columns[column_index]
where_clause = "#{column} > ? OR (#{column} = ? AND (#{where_clause}))"
end
ret = @cursor.reduce([where_clause]) { |params, value| params << value << value }
Expand Down
138 changes: 138 additions & 0 deletions lib/job-iteration/active_record_batch_enumerator/column_manager.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# frozen_string_literal: true

module JobIteration
class ActiveRecordBatchEnumerator
# Utility class for the batch enumerator that manages the columns that need
# to be plucked. It ensures primary key columns are plucked so that records
# in the batch can be queried for efficiently.
#
# @see ActiveRecordBatchEnumerator
class ColumnManager
# @param relation [ActiveRecord::Relation] - relation to manage columns for
# @param columns [Array<String,Symbol>, nil] - set of columns to select
def initialize(relation:, columns:)
@table_name = relation.table_name
@primary_key = Array(relation.primary_key)
@qualified_pkey_columns = @primary_key.map { |col| qualify_column(col) }
@columns = columns&.map(&:to_s) || @qualified_pkey_columns

validate_columns!(relation)
initialize_pluck_columns_and_pkey_positions
end

# @return [Array<String>]
# The list of columns to be plucked. If no columns were specified, this
# list contains the fully qualified primary key column(s).
attr_reader :columns

# @return [Array<String>]
# The list of primary key columns for the relation. These columns are
# not qualified with the table name.
attr_reader :primary_key

# @return [Array<String>]
# The full set of columns to be plucked from the relation. This is a
# superset of `columns` and is guaranteed to contain all of the primary
# key columns on the relation.
attr_reader :pluck_columns

# @param column_values [Array<Array>]
# List of rows where each row contains values as determined by
# `pluck_columns`.
#
# @return [Array<Array>]
# List where each item contains the primary key column values for the
# corresponding row. Values are guaranteed to be in the same order as
# the columns are listed in `primary_key`.
def pkey_values(column_values)
column_values.map do |values|
@qualified_pkey_columns.map do |pkey_column|
pkey_column_idx = @primary_key_index_map[pkey_column]
values[pkey_column_idx]
end
end
end

# @param cursor [Array]
# A list of values for a single row, as determined by `pluck_columns`.
#
# @return [Array]
# The same values that were passed in, minus any primary key column
# values that do not appear in `columns`.
def remove_missing_pkey_values(cursor)
cursor.pop(@missing_pkey_count)
cursor
end

private

def qualify_column(column)
"#{@table_name}.#{column}"
end

def validate_columns!(relation)
raise ArgumentError, "Must specify at least one column" if @columns.empty?

if relation.joins_values.present? && [email protected]? { |column| column.to_s.include?(".") }
raise ArgumentError, "You need to specify fully-qualified columns if you join a table"
end
end

# This method is responsible for initializing several instance variables:
#
# * `@pluck_columns` [Array<String>] -
# The set of columns to pluck.
# * `@missing_pkey_count` [Integer] -
# The number of primary keys that were missing from `@columns`.
# * `@primary_key_index_map` [Hash<String:Integer>] -
# Hash mapping all primary key columns to their position in
# `@pluck_columns`.
def initialize_pluck_columns_and_pkey_positions
@pluck_columns = @columns.dup
initial_pkey_index_map = find_initial_primary_key_indices(@pluck_columns)

missing_pkey_columns = initial_pkey_index_map.select { |_, idx| idx.nil? }.keys
missing_pkey_index_map = add_missing_pkey_columns!(missing_pkey_columns, @pluck_columns)
@missing_pkey_count = missing_pkey_index_map.size

# Compute the location of each primary key column in `@pluck_columns`.
@primary_key_index_map = initial_pkey_index_map.merge(missing_pkey_index_map)
end

# Figure out which primary key columns are already included in `columns`
# and track their position in the array.
#
# @param columns [Array<String>] - list of columns
#
# @return [Hash<String:Integer,nil>]
# A hash containing all of the fully qualified primary key columns as
# its keys. Values are the position of each column in the `columns`
# array. A `nil` value indicates the column is not present in `columns`.
def find_initial_primary_key_indices(columns)
@primary_key.each_with_object({}) do |pkey_column, indices|
fully_qualified_pkey_column = qualify_column(pkey_column)
idx = columns.index(pkey_column) || columns.index(fully_qualified_pkey_column)

indices[fully_qualified_pkey_column] = idx
end
end

# Takes a set of primary key columns and adds them to `columns`.
#
# @note - mutates `columns`
#
# @param missing_columns [Array<String>] - set of missing pkey columns
# @param columns [Array<String>] - set of columns to pluck
#
# @return [Hash<String:Integer>]
# A hash containing all of the values from `missing_columns` as its
# keys. Values are the position of those columns in `columns`.
def add_missing_pkey_columns!(missing_columns, columns)
missing_columns.each_with_object({}) do |pkey_column, indices|
indices[pkey_column] = columns.size
columns << pkey_column
end
end
end
end
end
7 changes: 7 additions & 0 deletions test/test_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ class TestCase
def skip_until_version(version)
skip("Deferred until version #{version}") if Gem::Version.new(JobIteration::VERSION) < Gem::Version.new(version)
end

def skip_until_active_record_version(version)
return if ActiveRecord.version >= Gem::Version.new(version)

skip("Deferred until ActiveRecord version #{version}")
end
end
end

Expand Down Expand Up @@ -150,6 +156,7 @@ def truncate_fixtures
ActiveRecord::Base.connection.truncate(TravelRoute.table_name)
ActiveRecord::Base.connection.truncate(Product.table_name)
ActiveRecord::Base.connection.truncate(Comment.table_name)
ActiveRecord::Base.connection.truncate(Order.table_name)
end

def with_global_default_retry_backoff(backoff)
Expand Down
Loading