|
| 1 | +require 'database_cleaner-active_record' |
| 2 | + |
| 3 | +module DatabaseCleaner |
| 4 | + module ActiveRecord |
| 5 | + # SeededDeletion is a strategy that deletes all records from tables except those that existed before it was instantiated |
| 6 | + # This is useful for tests that need the seeded data to be present. |
| 7 | + class SeededDeletion < Deletion |
| 8 | + def initialize(opts = {}) |
| 9 | + super |
| 10 | + cache_max_seeded_ids |
| 11 | + end |
| 12 | + |
| 13 | + def self.table_max_id_cache |
| 14 | + @table_max_id_cache ||= {} |
| 15 | + end |
| 16 | + |
| 17 | + # Memoize the maximum ID for each class table with non-zero number of rows |
| 18 | + def self.table_max_id_cache=(table_id_hash) |
| 19 | + @table_max_id_cache ||= table_id_hash |
| 20 | + end |
| 21 | + |
| 22 | + def cache_max_seeded_ids |
| 23 | + self.class.table_max_id_cache = table_max_id_hash |
| 24 | + end |
| 25 | + |
| 26 | + delegate :table_max_id_cache, to: :class |
| 27 | + |
| 28 | + private |
| 29 | + def table_max_id_hash |
| 30 | + h = {} |
| 31 | + tables_to_clean(connection).each do |table_name| |
| 32 | + begin |
| 33 | + # Check if the table has an id column using SQL |
| 34 | + column_exists_query = "SELECT EXISTS ( |
| 35 | + SELECT 1 |
| 36 | + FROM information_schema.columns |
| 37 | + WHERE table_name = '#{table_name}' |
| 38 | + AND column_name = 'id' |
| 39 | + )" |
| 40 | + |
| 41 | + has_id_column = connection.select_value(column_exists_query) |
| 42 | + |
| 43 | + if has_id_column |
| 44 | + # Get the maximum id using SQL |
| 45 | + max_id = connection.select_value("SELECT MAX(id) FROM #{connection.quote_table_name(table_name)}") |
| 46 | + |
| 47 | + # store the table and max id if there is non-zero number of rows |
| 48 | + h[table_name] = max_id.to_i if max_id.to_i > 0 |
| 49 | + end |
| 50 | + rescue => e |
| 51 | + # Skip tables that don't exist or have other issues |
| 52 | + puts "Warning: Could not cache max ID for table #{table_name}: #{e.message}" |
| 53 | + end |
| 54 | + end |
| 55 | + h |
| 56 | + end |
| 57 | + |
| 58 | + # Override the delete_table to delete records with IDs greater than the cached max ID (new rows) |
| 59 | + def delete_table(connection, table_name) |
| 60 | + if table_max_id_cache.key?(table_name) |
| 61 | + # For seedable tables, only delete records with IDs greater than the cached max ID |
| 62 | + max_id = table_max_id_cache[table_name] |
| 63 | + connection.execute("DELETE FROM #{connection.quote_table_name(table_name)} WHERE id > '#{max_id}'") if max_id.kind_of?(Numeric) && max_id > 0 |
| 64 | + else |
| 65 | + # For all other tables, use the standard deletion from the parent class |
| 66 | + super |
| 67 | + end |
| 68 | + end |
| 69 | + end |
| 70 | + end |
| 71 | +end |
0 commit comments