diff --git a/CHANGELOG.md b/CHANGELOG.md index 713ad9be..39dda5c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ nil ### Features -nil +- [650](https://github.com/Shopify/job-iteration/pull/650) Add support for batch enumeration over models with composite primary keys. ### Bug fixes diff --git a/lib/job-iteration/active_record_batch_enumerator.rb b/lib/job-iteration/active_record_batch_enumerator.rb index 81d4585a..95696184 100644 --- a/lib/job-iteration/active_record_batch_enumerator.rb +++ b/lib/job-iteration/active_record_batch_enumerator.rb @@ -1,5 +1,7 @@ # frozen_string_literal: true +require_relative "active_record_batch_enumerator/column_manager" + module JobIteration # Builds Batch Enumerator based on ActiveRecord Relation. # @see EnumeratorBuilder @@ -11,26 +13,15 @@ class ActiveRecordBatchEnumerator def initialize(relation, columns: nil, batch_size: 100, timezone: nil, cursor: nil) @batch_size = batch_size @timezone = timezone - @primary_key = "#{relation.table_name}.#{relation.primary_key}" - @columns = Array(columns&.map(&:to_s) || @primary_key) - @primary_key_index = @columns.index(@primary_key) || @columns.index(relation.primary_key) - @pluck_columns = if @primary_key_index - @columns - else - @columns.dup << @primary_key - end + @column_mgr = ColumnManager.new(relation: relation, columns: columns) @cursor = Array.wrap(cursor) @initial_cursor = @cursor - raise ArgumentError, "Must specify at least one column" if @columns.empty? - if relation.joins_values.present? && !@columns.all? { |column| column.to_s.include?(".") } - raise ArgumentError, "You need to specify fully-qualified columns if you join a table" - end if relation.arel.orders.present? || relation.arel.taken.present? raise JobIteration::ActiveRecordCursor::ConditionNotSupportedError end - @base_relation = relation.reorder(@columns.join(",")) + @base_relation = relation.reorder(@column_mgr.columns.join(",")) end def each @@ -53,7 +44,7 @@ def next_batch relation = relation.where(*conditions) end - cursor_values, ids = relation.uncached do + cursor_values, pkey_ids = relation.uncached do pluck_columns(relation) end @@ -62,25 +53,39 @@ def next_batch @cursor = @initial_cursor return end - # The primary key was plucked, but original cursor did not include it, so we should remove it - cursor.pop unless @primary_key_index - @cursor = Array.wrap(cursor) - # Yields relations by selecting the primary keys of records in the batch. - # Post.where(published: nil) results in an enumerator of relations like: - # Post.where(published: nil, ids: batch_of_ids) - @base_relation.where(@primary_key => ids) + @cursor = @column_mgr.remove_missing_pkey_values(cursor) + + filter_relation_with_primary_key(pkey_ids) end - def pluck_columns(relation) - if @pluck_columns.size == 1 # only the primary key - column_values = relation.pluck(*@pluck_columns) - return [column_values, column_values] + # Yields relations by selecting the primary keys of records in the batch. + # Post.where(published: nil) results in an enumerator of relations like: + # Post.where(published: nil, ids: batch_of_ids) + def filter_relation_with_primary_key(primary_key_values) + pkey = @column_mgr.primary_key + pkey_values = primary_key_values + + # If the primary key is only composed of a single column, simplify the + # query. This keeps us compatible with Rails prior to 7.1 where composite + # primary keys were introduced along with the syntax that allows you to + # query for multi-column values. + if pkey.size <= 1 + pkey = pkey.first + pkey_values = pkey_values.map(&:first) end - column_values = relation.pluck(*@pluck_columns) - primary_key_index = @primary_key_index || -1 - primary_key_values = column_values.map { |values| values[primary_key_index] } + @base_relation.where(pkey => pkey_values) + end + + def pluck_columns(relation) + column_values = relation.pluck(*@column_mgr.pluck_columns) + + # Pluck behaves differently when only one column is given. By using zip, + # we make the output consistent (at the cost of more object allocation). + column_values = column_values.zip if @column_mgr.pluck_columns.size == 1 + + primary_key_values = @column_mgr.pkey_values(column_values) serialize_column_values!(column_values) [column_values, primary_key_values] @@ -94,15 +99,15 @@ def cursor_value def conditions column_index = @cursor.size - 1 - column = @columns[column_index] - where_clause = if @columns.size == @cursor.size + column = @column_mgr.columns[column_index] + where_clause = if @column_mgr.columns.size == @cursor.size "#{column} > ?" else "#{column} >= ?" end while column_index > 0 column_index -= 1 - column = @columns[column_index] + column = @column_mgr.columns[column_index] where_clause = "#{column} > ? OR (#{column} = ? AND (#{where_clause}))" end ret = @cursor.reduce([where_clause]) { |params, value| params << value << value } diff --git a/lib/job-iteration/active_record_batch_enumerator/column_manager.rb b/lib/job-iteration/active_record_batch_enumerator/column_manager.rb new file mode 100644 index 00000000..d4f6ebc9 --- /dev/null +++ b/lib/job-iteration/active_record_batch_enumerator/column_manager.rb @@ -0,0 +1,138 @@ +# frozen_string_literal: true + +module JobIteration + class ActiveRecordBatchEnumerator + # Utility class for the batch enumerator that manages the columns that need + # to be plucked. It ensures primary key columns are plucked so that records + # in the batch can be queried for efficiently. + # + # @see ActiveRecordBatchEnumerator + class ColumnManager + # @param relation [ActiveRecord::Relation] - relation to manage columns for + # @param columns [Array, nil] - set of columns to select + def initialize(relation:, columns:) + @table_name = relation.table_name + @primary_key = Array(relation.primary_key) + @qualified_pkey_columns = @primary_key.map { |col| qualify_column(col) } + @columns = columns&.map(&:to_s) || @qualified_pkey_columns + + validate_columns!(relation) + initialize_pluck_columns_and_pkey_positions + end + + # @return [Array] + # The list of columns to be plucked. If no columns were specified, this + # list contains the fully qualified primary key column(s). + attr_reader :columns + + # @return [Array] + # The list of primary key columns for the relation. These columns are + # not qualified with the table name. + attr_reader :primary_key + + # @return [Array] + # The full set of columns to be plucked from the relation. This is a + # superset of `columns` and is guaranteed to contain all of the primary + # key columns on the relation. + attr_reader :pluck_columns + + # @param column_values [Array] + # List of rows where each row contains values as determined by + # `pluck_columns`. + # + # @return [Array] + # List where each item contains the primary key column values for the + # corresponding row. Values are guaranteed to be in the same order as + # the columns are listed in `primary_key`. + def pkey_values(column_values) + column_values.map do |values| + @qualified_pkey_columns.map do |pkey_column| + pkey_column_idx = @primary_key_index_map[pkey_column] + values[pkey_column_idx] + end + end + end + + # @param cursor [Array] + # A list of values for a single row, as determined by `pluck_columns`. + # + # @return [Array] + # The same values that were passed in, minus any primary key column + # values that do not appear in `columns`. + def remove_missing_pkey_values(cursor) + cursor.pop(@missing_pkey_count) + cursor + end + + private + + def qualify_column(column) + "#{@table_name}.#{column}" + end + + def validate_columns!(relation) + raise ArgumentError, "Must specify at least one column" if @columns.empty? + + if relation.joins_values.present? && !@columns.all? { |column| column.to_s.include?(".") } + raise ArgumentError, "You need to specify fully-qualified columns if you join a table" + end + end + + # This method is responsible for initializing several instance variables: + # + # * `@pluck_columns` [Array] - + # The set of columns to pluck. + # * `@missing_pkey_count` [Integer] - + # The number of primary keys that were missing from `@columns`. + # * `@primary_key_index_map` [Hash] - + # Hash mapping all primary key columns to their position in + # `@pluck_columns`. + def initialize_pluck_columns_and_pkey_positions + @pluck_columns = @columns.dup + initial_pkey_index_map = find_initial_primary_key_indices(@pluck_columns) + + missing_pkey_columns = initial_pkey_index_map.select { |_, idx| idx.nil? }.keys + missing_pkey_index_map = add_missing_pkey_columns!(missing_pkey_columns, @pluck_columns) + @missing_pkey_count = missing_pkey_index_map.size + + # Compute the location of each primary key column in `@pluck_columns`. + @primary_key_index_map = initial_pkey_index_map.merge(missing_pkey_index_map) + end + + # Figure out which primary key columns are already included in `columns` + # and track their position in the array. + # + # @param columns [Array] - list of columns + # + # @return [Hash] + # A hash containing all of the fully qualified primary key columns as + # its keys. Values are the position of each column in the `columns` + # array. A `nil` value indicates the column is not present in `columns`. + def find_initial_primary_key_indices(columns) + @primary_key.each_with_object({}) do |pkey_column, indices| + fully_qualified_pkey_column = qualify_column(pkey_column) + idx = columns.index(pkey_column) || columns.index(fully_qualified_pkey_column) + + indices[fully_qualified_pkey_column] = idx + end + end + + # Takes a set of primary key columns and adds them to `columns`. + # + # @note - mutates `columns` + # + # @param missing_columns [Array] - set of missing pkey columns + # @param columns [Array] - set of columns to pluck + # + # @return [Hash] + # A hash containing all of the values from `missing_columns` as its + # keys. Values are the position of those columns in `columns`. + def add_missing_pkey_columns!(missing_columns, columns) + missing_columns.each_with_object({}) do |pkey_column, indices| + indices[pkey_column] = columns.size + columns << pkey_column + end + end + end + end +end diff --git a/test/test_helper.rb b/test/test_helper.rb index f9ee061b..5eb84ef0 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -116,6 +116,12 @@ class TestCase def skip_until_version(version) skip("Deferred until version #{version}") if Gem::Version.new(JobIteration::VERSION) < Gem::Version.new(version) end + + def skip_until_active_record_version(version) + return if ActiveRecord.version >= Gem::Version.new(version) + + skip("Deferred until ActiveRecord version #{version}") + end end end @@ -150,6 +156,7 @@ def truncate_fixtures ActiveRecord::Base.connection.truncate(TravelRoute.table_name) ActiveRecord::Base.connection.truncate(Product.table_name) ActiveRecord::Base.connection.truncate(Comment.table_name) + ActiveRecord::Base.connection.truncate(Order.table_name) end def with_global_default_retry_backoff(backoff) diff --git a/test/unit/active_record_batch_enumerator_test.rb b/test/unit/active_record_batch_enumerator_test.rb index d16e76a0..269f5a32 100644 --- a/test/unit/active_record_batch_enumerator_test.rb +++ b/test/unit/active_record_batch_enumerator_test.rb @@ -5,6 +5,7 @@ module JobIteration class ActiveRecordBatchEnumeratorTest < IterationUnitTest SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%S.%N" + test "#each yields batches as relation with the last record's cursor position" do enum = build_enumerator product_batches = Product.order(:id).take(4).in_groups_of(2).map { |product| [product, product.last.id] } @@ -148,6 +149,92 @@ class ActiveRecordBatchEnumeratorTest < IterationUnitTest end end + test "(composite primary key) #each yields batches as relation with the last record's cursor position" do + skip_until_active_record_version("7.1") + seed_orders! + + enum = build_enumerator(relation: Order.all) + order_batches = Order.order(:id).take(4).in_groups_of(2).map { |order| [order, order.last.id] } + + enum.first(2).each_with_index do |(batch, cursor), index| + assert batch.is_a?(ActiveRecord::Relation) + assert_equal order_batches[index].first, batch + assert_equal order_batches[index].last, cursor + end + end + + test "(composite primary key) columns without a primary key column yields cursors without the unspecified value" do + skip_until_active_record_version("7.1") + seed_orders! + + enum = build_enumerator(relation: Order.all, columns: [:name, :shop_id]) + orders = Order.order(:name, :shop_id).take(2) + + cursor = [orders.last.name, orders.last.shop_id] + assert_equal([orders, cursor], enum.first) + end + + test "(composite primary key) cursor can be used to resume on multiple columns" do + skip_until_active_record_version("7.1") + seed_orders! + + enum = build_enumerator(relation: Order.all, columns: [:name, :id]) + orders = Order.order(:name, :id).take(2) + + cursor = [orders.last.name, orders.last.id_value] + assert_equal([orders, cursor], enum.first) + + enum = build_enumerator(relation: Order.all, columns: [:name, :id], cursor: cursor) + orders = Order.order(:name, :id).offset(2).take(2) + + cursor = [orders.last.name, orders.last.id_value] + assert_equal([orders, cursor], enum.first) + end + + test "(composite primary key) columns missing primary key column still queries for primary key values" do + skip_until_active_record_version("7.1") + + queries = track_queries do + enum = build_enumerator(relation: Order.all, columns: [:name]) + enum.first + end + assert_match(/\A\s?`orders`.`name`, `orders`.`shop_id`, `orders`.`id`\z/, queries.first[/SELECT (.*) FROM/, 1]) + end + + test "(composite primary key) columns with only one primary key column still queries for all primary key values" do + skip_until_active_record_version("7.1") + + queries = track_queries do + enum = build_enumerator(relation: Order.all, columns: ["orders.id", :name]) + enum.first + end + assert_match(/\A\s?`orders`.`id`, `orders`.`name`, `orders`.`shop_id`\z/, queries.first[/SELECT (.*) FROM/, 1]) + end + + test "(composite primary key) columns configured with primary key only queries primary key columns once" do + skip_until_active_record_version("7.1") + + queries = track_queries do + enum = build_enumerator(relation: Order.all, columns: [:name, :id, "orders.shop_id"]) + enum.first + end + assert_match(/\A\s?`orders`.`name`, `orders`.`id`, `orders`.`shop_id`\z/, queries.first[/SELECT (.*) FROM/, 1]) + end + + test "(composite primary key) one query performed per batch, plus an additional one for the empty cursor" do + skip_until_active_record_version("7.1") + seed_orders! + + enum = build_enumerator(relation: Order.all) + num_batches = 0 + queries = track_queries do + enum.each { num_batches += 1 } + end + + expected_num_queries = num_batches + 1 + assert_equal expected_num_queries, queries.size + end + private def build_enumerator(relation: Product.all, batch_size: 2, timezone: nil, columns: nil, cursor: nil) @@ -160,11 +247,25 @@ def build_enumerator(relation: Product.all, batch_size: 2, timezone: nil, column ) end + # Captures queries made against the database. Automatically filters out + # queries that populate model schemas, since those are just ActiveRecord + # doing its thing. def track_queries(&block) queries = [] - query_cb = ->(*, payload) { queries << payload[:sql] } + query_cb = ->(*, payload) { + return if /SHOW FULL FIELDS FROM `\w+`/.match?(payload[:sql]) + + queries << payload[:sql] + } ActiveSupport::Notifications.subscribed(query_cb, "sql.active_record", &block) queries end + + def seed_orders! + Order.create!(shop_id: 1, name: "T-shirt") + Order.create!(shop_id: 1, name: "Jeans") + Order.create!(shop_id: 2, name: "Ballcap") + Order.create!(shop_id: 3, name: "Jacket") + end end end