Skip to content

Commit

Permalink
repo-sync-2024-10-23T11:09:21+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
YanZhuangz committed Oct 23, 2024
1 parent 3b8bcc3 commit 609533b
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@
import org.secretflow.dataproxy.common.exceptions.DataproxyException;
import org.secretflow.dataproxy.common.model.datasource.conn.OdpsConnConfig;
import org.secretflow.dataproxy.common.model.datasource.location.OdpsTableInfo;
import org.secretflow.dataproxy.common.utils.JsonUtils;
import org.secretflow.dataproxy.manager.DataWriter;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/**
Expand All @@ -72,6 +74,7 @@ public class OdpsDataWriter implements DataWriter {

private boolean isPartitioned = false;

private TableSchema odpsTableSchema = null;
private TableTunnel.UploadSession uploadSession = null;
private RecordWriter recordWriter = null;

Expand Down Expand Up @@ -152,15 +155,21 @@ private void initOdps() throws OdpsException, IOException {
// init odps client
Odps odps = initOdpsClient(this.connConfig);
// Pre-processing
preProcessing(odps, connConfig.getProjectName(), tableInfo.tableName(), this.convertToPartitionSpec(tableInfo.partitionSpec()));
PartitionSpec convertPartitionSpec = this.convertToPartitionSpec(tableInfo.partitionSpec());
preProcessing(odps, connConfig.getProjectName(), tableInfo.tableName(), convertPartitionSpec);
// init upload session
TableTunnel tunnel = new TableTunnel(odps);

if (isPartitioned) {
if (tableInfo.partitionSpec() == null || tableInfo.partitionSpec().isEmpty()) {
throw DataproxyException.of(DataproxyErrorCode.INVALID_PARTITION_SPEC, "partitionSpec is empty");
}
PartitionSpec partitionSpec = new PartitionSpec(tableInfo.partitionSpec());
assert this.odpsTableSchema != null;
List<Column> partitionColumns = this.odpsTableSchema.getPartitionColumns();
PartitionSpec partitionSpec = new PartitionSpec();
for (Column partitionColumn : partitionColumns) {
partitionSpec.set(partitionColumn.getName(), convertPartitionSpec.get(partitionColumn.getName()));
}
uploadSession = tunnel.createUploadSession(connConfig.getProjectName(), tableInfo.tableName(), partitionSpec, overwrite);
} else {
uploadSession = tunnel.createUploadSession(connConfig.getProjectName(), tableInfo.tableName(), overwrite);
Expand Down Expand Up @@ -271,7 +280,10 @@ private void preProcessing(Odps odps, String projectName, String tableName, Part
} else {
log.info("odps table is exists, project: {}, table name: {}", projectName, tableName);
}
isPartitioned = odps.tables().get(projectName, tableName).isPartitioned();

Table table = odps.tables().get(projectName, tableName);
isPartitioned = table.isPartitioned();
this.setOdpsTableSchemaIfAbsent(table.getSchema());

if (isPartitioned) {
if (partitionSpec == null || partitionSpec.isEmpty()) {
Expand Down Expand Up @@ -334,8 +346,29 @@ private boolean createOdpsTable(Odps odps, String projectName, String tableName,
TableSchema tableSchema = convertToTableSchema(schema);
if (partitionSpec != null) {
// Infer partitioning field type as string.
partitionSpec.keys().forEach(key -> tableSchema.addPartitionColumn(Column.newBuilder(key, TypeInfoFactory.STRING).build()));
List<Column> tableSchemaColumns = tableSchema.getColumns();
List<Integer> partitionColumnIndexes = new ArrayList<>();
ArrayList<Column> partitionColumns = new ArrayList<>();

for (String key : partitionSpec.keys()) {
if (tableSchema.containsColumn(key)) {
log.info("tableSchemaColumns contains partition column: {}", key);
partitionColumnIndexes.add(tableSchema.getColumnIndex(key));
partitionColumns.add(tableSchema.getColumn(key));
} else {
log.info("tableSchemaColumns not contains partition column: {}", key);
partitionColumns.add(Column.newBuilder(key, TypeInfoFactory.STRING).build());
}
}

for (int i = 0; i < partitionColumnIndexes.size(); i++) {
tableSchemaColumns.remove(partitionColumnIndexes.get(i) - i);
}
log.info("tableSchemaColumns: {}, partitionColumnIndexes: {}", JsonUtils.toString(tableSchemaColumns), JsonUtils.toString(partitionColumnIndexes));
tableSchema.setColumns(tableSchemaColumns);
tableSchema.setPartitionColumns(partitionColumns);
}
log.info("create odps table schema: {}", JsonUtils.toString(tableSchema));
odps.tables().create(projectName, tableName, tableSchema, "", true, null, OdpsUtil.getSqlFlag(), null);
return true;
} catch (Exception e) {
Expand All @@ -355,6 +388,12 @@ private boolean createOdpsPartition(Odps odps, String projectName, String tableN
return false;
}

private void setOdpsTableSchemaIfAbsent(TableSchema tableSchema) {
if (odpsTableSchema == null) {
this.odpsTableSchema = tableSchema;
}
}

private TableSchema convertToTableSchema(Schema schema) {
List<Column> columns = schema.getFields().stream().map(this::convertToColumn).toList();
return TableSchema.builder().withColumns(columns).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@
import org.secretflow.dataproxy.common.model.datasource.location.JdbcLocationConfig;
import org.secretflow.dataproxy.common.utils.JsonUtils;
import org.secretflow.dataproxy.manager.DataWriter;
import org.secretflow.dataproxy.manager.connector.rdbms.adaptor.JdbcParameterBinder;

import java.io.IOException;
import java.sql.*;
import java.util.Arrays;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;

/**
Expand Down Expand Up @@ -90,7 +89,9 @@ protected void initialize(Schema schema) {
log.info("[JdbcDataWriter] preSql execute start, sql: {}", JsonUtils.toJSONString(preSqlList));

try (Connection conn = this.jdbcAssistant.getDatabaseConn(dataSource)) {
executePreWorkSqls(conn, preSqlList);
// do nothing
// Avoid SQL injection issues
// About to Delete
} catch (SQLException e) {
throw DataproxyException.of(DataproxyErrorCode.JDBC_CREATE_TABLE_FAILED, e.getMessage(), e);
}
Expand All @@ -103,61 +104,7 @@ protected void initialize(Schema schema) {

@Override
public void write(VectorSchemaRoot root) throws IOException {
ensureInitialized(root.getSchema());

// 每次直接发送,不积攒
final int rowCount = root.getRowCount();
int recordCount = 0;

try (Connection conn = this.jdbcAssistant.getDatabaseConn(dataSource)) {
boolean finished = false;

if (this.jdbcAssistant.supportBatchInsert()) {
try (PreparedStatement preparedStatement = conn.prepareStatement(this.stmt)) {
if (rowCount != 0) {
final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build();
while (binder.next()) {
preparedStatement.addBatch();
}
int[] recordCounts = preparedStatement.executeBatch();
recordCount = Arrays.stream(recordCounts).sum();
}
finished = true;
} catch (Exception e) {
log.warn("[JdbcDataWriter] prepare batch write error, then dp will try to generate integral insert sql, stmt:{}", this.stmt, e);
}
}

// 不支持prepare模式,需要构造完整insert语句
//insert into `default`.`test_table`(`int32`,`float64`,`string`) values(?,?,?)
if (!finished) {
String insertSql = null;
List<JDBCType> jdbcTypes = root.getFieldVectors().stream()
.map(vector -> this.jdbcAssistant.arrowTypeToJdbcType(vector.getField()))
.toList();

try (Statement statement = conn.createStatement()) {
// 数据逐行写入
for (int row = 0; row < root.getRowCount(); row++) {
String[] values = new String[root.getFieldVectors().size()];
for (int col = 0; col < root.getFieldVectors().size(); col++) {
values[col] = this.jdbcAssistant.serialize(jdbcTypes.get(col), root.getVector(col).getObject(row));
}

insertSql = String.format(this.stmt.replace("?", "%s"), (Object[]) values);
statement.execute(insertSql);
}
} catch (Exception e) {
log.error("[JdbcDataWriter] integral insert sql error, sql:{}", insertSql, e);
throw e;
}
}

log.info("[JdbcDataWriter] jdbc batch write success, record count:{}, table:{}", recordCount, this.composeTableName);
} catch (Exception e) {
log.error("[JdbcDataWriter] jdbc batch write failed, table:{}", this.composeTableName);
throw DataproxyException.of(DataproxyErrorCode.JDBC_INSERT_INTO_TABLE_FAILED, e);
}
throw DataproxyException.of(DataproxyErrorCode.JDBC_INSERT_INTO_TABLE_FAILED, "jdbc not support write");
}

@Override
Expand All @@ -179,15 +126,4 @@ public void close() throws Exception {
} catch (Exception ignored) {
}
}

void executePreWorkSqls(Connection conn, List<String> preWorkSqls) throws SQLException {
for (String sql : preWorkSqls) {
try (Statement statement = conn.createStatement()) {
statement.execute(sql);
} catch (SQLException e) {
log.error("[SinkJdbcHandler] 数据转移前预先执行SQL失败:{}", sql);
throw e;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,17 @@ public void getStreamReadData(CallContext context, Ticket ticket, ServerStreamLi
log.info("[getStreamReadData] parse command from ticket success, command:{}", JsonUtils.toJSONString(command));
try (ArrowReader arrowReader = dataProxyService.generateArrowReader(rootAllocator, (DatasetReadCommand) command.getCommandInfo())) {
listener.start(arrowReader.getVectorSchemaRoot());
while (arrowReader.loadNextBatch()) {
listener.putNext();

while (true) {
if (context.isCancelled()) {
log.warn("[getStreamReadData] get stream cancelled");
break;
}
if (arrowReader.loadNextBatch()) {
listener.putNext();
} else {
break;
}
}
listener.completed();
log.info("[getStreamReadData] get stream completed");
Expand Down
2 changes: 1 addition & 1 deletion dataproxy_sdk/python/dataproxy/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.


__version__ = "0.2.0.dev$$DATE$$"
__version__ = "0.2.0b0"
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.11.0</version>
<version>2.14.0</version>
</dependency>
<dependency>
<groupId>com.opencsv</groupId>
Expand Down

0 comments on commit 609533b

Please sign in to comment.