|  | 
|  | 1 | +package com.marklogic.spring.batch.item.reader; | 
|  | 2 | + | 
|  | 3 | +import org.slf4j.Logger; | 
|  | 4 | +import org.slf4j.LoggerFactory; | 
|  | 5 | +import org.springframework.batch.item.ExecutionContext; | 
|  | 6 | +import org.springframework.batch.item.ItemStreamException; | 
|  | 7 | +import org.springframework.batch.item.database.JdbcCursorItemReader; | 
|  | 8 | +import org.springframework.batch.item.support.AbstractItemStreamItemReader; | 
|  | 9 | +import org.springframework.dao.DataAccessException; | 
|  | 10 | +import org.springframework.jdbc.core.ColumnMapRowMapper; | 
|  | 11 | +import org.springframework.jdbc.core.ConnectionCallback; | 
|  | 12 | +import org.springframework.jdbc.core.JdbcTemplate; | 
|  | 13 | + | 
|  | 14 | +import javax.sql.DataSource; | 
|  | 15 | +import java.sql.Connection; | 
|  | 16 | +import java.sql.ResultSet; | 
|  | 17 | +import java.sql.SQLException; | 
|  | 18 | +import java.util.*; | 
|  | 19 | + | 
|  | 20 | +/** | 
|  | 21 | + * | 
|  | 22 | + * Spring Batch Reader that first queries for all the table names from the given DataSource, and then reads rows from | 
|  | 23 | + * every table. | 
|  | 24 | + * | 
|  | 25 | + * The excludeTableNames property can be used to exclude certain table names from processing. | 
|  | 26 | + * | 
|  | 27 | + * The tableQueries property can be used to specify a custom SELECT query for a particular table name. By default, | 
|  | 28 | + * "SELECT * FROM (table name)" is used. | 
|  | 29 | + */ | 
|  | 30 | +public class AllTablesItemReader extends AbstractItemStreamItemReader<Map<String, Object>> { | 
|  | 31 | + | 
|  | 32 | +    public final static String DEFAULT_TABLE_NAME_KEY = "_tableName"; | 
|  | 33 | + | 
|  | 34 | +    protected final Logger logger = LoggerFactory.getLogger(getClass()); | 
|  | 35 | + | 
|  | 36 | +    private DataSource dataSource; | 
|  | 37 | +    private List<String> tableNames; | 
|  | 38 | +    private Map<String, JdbcCursorItemReader> tableReaders; | 
|  | 39 | +    private int tableNameIndex = 0; | 
|  | 40 | +    private String tableNameKey = DEFAULT_TABLE_NAME_KEY; | 
|  | 41 | +    private String databaseVendor = ""; | 
|  | 42 | + | 
|  | 43 | +    // For ignoring certain table names | 
|  | 44 | +    private Set<String> excludeTableNames; | 
|  | 45 | + | 
|  | 46 | +    // For using a custom SQL query for a given table name | 
|  | 47 | +    private Map<String, String> tableQueries; | 
|  | 48 | + | 
|  | 49 | +    public AllTablesItemReader(DataSource dataSource) { | 
|  | 50 | +        this.dataSource = dataSource; | 
|  | 51 | +    } | 
|  | 52 | + | 
|  | 53 | +    public AllTablesItemReader(DataSource dataSource, String databaseVendor) { | 
|  | 54 | +        this.databaseVendor = databaseVendor; | 
|  | 55 | +        this.dataSource = dataSource; | 
|  | 56 | +    } | 
|  | 57 | + | 
|  | 58 | +    /** | 
|  | 59 | +     * Use the DataSource to get a list of all the tables. Then, create a JdbcCursorItemReader for every table with | 
|  | 60 | +     * a SQL query of "SELECT * FROM (table)". Put each of those into a List. Set a RowColumnRowMapper on each | 
|  | 61 | +     * JdbcCursorItemReader so that every row is read as a ColumnMap. | 
|  | 62 | +     */ | 
|  | 63 | +    @Override | 
|  | 64 | +    public void open(ExecutionContext executionContext) throws ItemStreamException { | 
|  | 65 | +        tableNames = getTableNames(); | 
|  | 66 | +        tableReaders = new HashMap<>(); | 
|  | 67 | +        for (String tableName : tableNames) { | 
|  | 68 | +            tableReaders.put(tableName, buildTableReader(tableName, executionContext)); | 
|  | 69 | +        } | 
|  | 70 | +    } | 
|  | 71 | + | 
|  | 72 | +    /** | 
|  | 73 | +     * Reads a row from the active JdbcCursorItemReader. If that returns null, move on to the next JdbcCursorItemReader. | 
|  | 74 | +     * If there are no more readers, then this method is all done. | 
|  | 75 | +     */ | 
|  | 76 | +    @Override | 
|  | 77 | +    public Map<String, Object> read() throws Exception { | 
|  | 78 | +        final String currentTableName = tableNames.get(tableNameIndex); | 
|  | 79 | +        JdbcCursorItemReader<Map<String, Object>> reader = tableReaders.get(currentTableName); | 
|  | 80 | +        Map<String, Object> result = reader.read(); | 
|  | 81 | +        if (result != null) { | 
|  | 82 | +            result.put("_tableName", currentTableName); | 
|  | 83 | +            return result; | 
|  | 84 | +        } | 
|  | 85 | + | 
|  | 86 | +        if (logger.isInfoEnabled()) { | 
|  | 87 | +            logger.info("Finished reading rows for query: " + reader.getSql()); | 
|  | 88 | +        } | 
|  | 89 | +        reader.close(); | 
|  | 90 | + | 
|  | 91 | +        // Bump up index - if we're at the end of the list, we're all done | 
|  | 92 | +        tableNameIndex++; | 
|  | 93 | +        if (tableNameIndex >= tableNames.size()) { | 
|  | 94 | +            return null; | 
|  | 95 | +        } | 
|  | 96 | + | 
|  | 97 | +        return read(); | 
|  | 98 | +    } | 
|  | 99 | + | 
|  | 100 | +    /** | 
|  | 101 | +     * Register a custom SQL query for selecting data from the given table name. By default, a query of | 
|  | 102 | +     * "SELECT * FROM (table name)" is used. | 
|  | 103 | +     * | 
|  | 104 | +     * @param tableName | 
|  | 105 | +     * @param sql | 
|  | 106 | +     */ | 
|  | 107 | +    public void addTableQuery(String tableName, String sql) { | 
|  | 108 | +        if (tableQueries == null) { | 
|  | 109 | +            tableQueries = new HashMap<>(); | 
|  | 110 | +        } | 
|  | 111 | +        tableQueries.put(tableName, sql); | 
|  | 112 | +    } | 
|  | 113 | + | 
|  | 114 | +    /** | 
|  | 115 | +     * @return a list of the table names, retrieved via the connection metadata object. The excludeTableNames | 
|  | 116 | +     * property can be used to ignore certain table names. | 
|  | 117 | +     */ | 
|  | 118 | +    protected List<String> getTableNames() { | 
|  | 119 | +        return new JdbcTemplate(dataSource).execute(new ConnectionCallback<List<String>>() { | 
|  | 120 | +            @Override | 
|  | 121 | +            public List<String> doInConnection(Connection con) throws SQLException, DataAccessException { | 
|  | 122 | +                ResultSet rs = con.getMetaData().getTables(null, null, "%", new String[]{"TABLE"}); | 
|  | 123 | +                List<String> list = new ArrayList<>(); | 
|  | 124 | +                while (rs.next()) { | 
|  | 125 | +                    String name = rs.getString("TABLE_NAME"); | 
|  | 126 | +                    if (excludeTableNames == null || !excludeTableNames.contains(name)) { | 
|  | 127 | +                        list.add(name); | 
|  | 128 | +                    } | 
|  | 129 | +                } | 
|  | 130 | +                return list; | 
|  | 131 | +            } | 
|  | 132 | +        }); | 
|  | 133 | +    } | 
|  | 134 | + | 
|  | 135 | +    /** | 
|  | 136 | +     * @param tableName | 
|  | 137 | +     * @param executionContext | 
|  | 138 | +     * @return a JdbcCursorItemReader for the given table name. Override this method to alert the SQL statement that's | 
|  | 139 | +     * used for a particular table. | 
|  | 140 | +     */ | 
|  | 141 | +    protected JdbcCursorItemReader<Map<String, Object>> buildTableReader(String tableName, ExecutionContext executionContext) { | 
|  | 142 | +        JdbcCursorItemReader<Map<String, Object>> reader = new JdbcCursorItemReader<>(); | 
|  | 143 | +        reader.setDataSource(dataSource); | 
|  | 144 | +        reader.setRowMapper(new ColumnMapRowMapper()); | 
|  | 145 | +        reader.setSql(getSqlQueryForTable(tableName)); | 
|  | 146 | +        reader.open(executionContext); | 
|  | 147 | +        return reader; | 
|  | 148 | +    } | 
|  | 149 | + | 
|  | 150 | +    /** | 
|  | 151 | +     * Uses the tableQueries property to see if there's a custom SQL query for the given table name. | 
|  | 152 | +     * | 
|  | 153 | +     * @param tableName | 
|  | 154 | +     * @return | 
|  | 155 | +     */ | 
|  | 156 | +    protected String getSqlQueryForTable(String tableName) { | 
|  | 157 | +        String sql = null; | 
|  | 158 | +        if (tableQueries != null) { | 
|  | 159 | +            sql = tableQueries.get(tableName); | 
|  | 160 | +        } | 
|  | 161 | +        if (tableName.contains(" ") && "MICROSOFT".equals(databaseVendor.toUpperCase())) { | 
|  | 162 | +            tableName = "[" + tableName + "]"; | 
|  | 163 | +        } | 
|  | 164 | +        return sql != null ? sql : "SELECT * FROM " + tableName; | 
|  | 165 | +    } | 
|  | 166 | + | 
|  | 167 | +    public void setExcludeTableNames(Set<String> excludeTableNames) { | 
|  | 168 | +        this.excludeTableNames = excludeTableNames; | 
|  | 169 | +    } | 
|  | 170 | + | 
|  | 171 | +    public void setTableQueries(Map<String, String> tableQueries) { | 
|  | 172 | +        this.tableQueries = tableQueries; | 
|  | 173 | +    } | 
|  | 174 | + | 
|  | 175 | +    public void setTableNameKey(String tableNameKey) { | 
|  | 176 | +        this.tableNameKey = tableNameKey; | 
|  | 177 | +    } | 
|  | 178 | + | 
|  | 179 | +    public void setDatabaseVendor(String databaseVendor) { | 
|  | 180 | +        this.databaseVendor = databaseVendor; | 
|  | 181 | +    } | 
|  | 182 | +} | 
0 commit comments