From 4fe5411cf1f0e13e9d2ece41218b4d4fdcab5ad4 Mon Sep 17 00:00:00 2001 From: SuperMarz Date: Fri, 11 Oct 2024 15:57:32 +0800 Subject: [PATCH 1/2] repo-sync-2024-10-11T15:57:26+0800 --- .circleci/config.yml | 1 - .../common/exceptions/DataproxyErrorCode.java | 10 +- .../manager/connector/odps/OdpsConnector.java | 4 +- .../connector/odps/OdpsDataWriter.java | 132 +++++++++-- .../connector/odps/OdpsSplitArrowReader.java | 207 ++++++++++++++---- .../manager/connector/odps/OdpsUtil.java | 10 + .../impl/DataProxyServiceDirectImpl.java | 7 +- dataproxy_sdk/cc/data_proxy_file.cc | 1 + .../python/dataproxy/dp_file_adapter.py | 8 + pom.xml | 8 + 10 files changed, 327 insertions(+), 61 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index cc3be21..ae98d8f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -78,4 +78,3 @@ workflows: .bazelrc sdk-build-and-run true WORKSPACE sdk-build-and-run true - lint - diff --git a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/exceptions/DataproxyErrorCode.java b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/exceptions/DataproxyErrorCode.java index ccb7114..b19a2b6 100644 --- a/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/exceptions/DataproxyErrorCode.java +++ b/dataproxy-common/src/main/java/org/secretflow/dataproxy/common/exceptions/DataproxyErrorCode.java @@ -84,7 +84,15 @@ public enum DataproxyErrorCode { // odps 异常 ODPS_CREATE_TABLE_FAILED(ErrorLevels.ERROR, ErrorTypes.BIZ, "600", "Create ODPS table failed"), - ODPS_ERROR(ErrorLevels.ERROR, ErrorTypes.BIZ, "601", "ODPS error"), + ODPS_CREATE_PARTITION_FAILED(ErrorLevels.ERROR, ErrorTypes.BIZ, "601", "Create ODPS partition failed"), + ODPS_ERROR(ErrorLevels.ERROR, ErrorTypes.BIZ, "602", "ODPS error"), + ODPS_TABLE_ALREADY_EXISTS(ErrorLevels.ERROR, ErrorTypes.BIZ, "603", "ODPS table already exists"), + ODPS_TABLE_NOT_EXISTS(ErrorLevels.ERROR, ErrorTypes.BIZ, "604", "ODPS table does not exist"), + ODPS_PARTITION_ALREADY_EXISTS(ErrorLevels.ERROR, ErrorTypes.BIZ, "605", "ODPS partition already exists"), + ODPS_PARTITION_NOT_EXISTS(ErrorLevels.ERROR, ErrorTypes.BIZ, "606", "ODPS partition does not exist"), + ODPS_TABLE_NOT_EMPTY(ErrorLevels.ERROR, ErrorTypes.BIZ, "607", "ODPS table is not empty"), + ODPS_TABLE_NOT_SUPPORT_PARTITION(ErrorLevels.ERROR, ErrorTypes.BIZ, "608", "ODPS table does not support partitions"), + //============================= 第三方错误【900-999】================================== diff --git a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsConnector.java b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsConnector.java index 6e12feb..39753d6 100644 --- a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsConnector.java +++ b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsConnector.java @@ -15,7 +15,7 @@ */ package org.secretflow.dataproxy.manager.connector.odps; -import com.aliyun.odps.tunnel.TunnelException; +import com.aliyun.odps.OdpsException; import org.apache.arrow.memory.BufferAllocator; import org.secretflow.dataproxy.common.model.InferSchemaResult; import org.secretflow.dataproxy.common.model.command.DatasetReadCommand; @@ -89,7 +89,7 @@ public DataWriter buildWriter(DatasetWriteCommand writeCommand) { if (Objects.equals(DatasetFormatTypeEnum.TABLE, writeCommand.getFormatConfig().getType())) { try { return new OdpsDataWriter(config, locationConfig, writeCommand.getSchema()); - } catch 
(TunnelException | IOException e) { + } catch (IOException | OdpsException e) { throw new RuntimeException(e); } } diff --git a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java index 0f10919..aeab788 100644 --- a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java +++ b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsDataWriter.java @@ -18,22 +18,25 @@ import com.aliyun.odps.Column; import com.aliyun.odps.Odps; +import com.aliyun.odps.OdpsException; import com.aliyun.odps.OdpsType; import com.aliyun.odps.PartitionSpec; +import com.aliyun.odps.Table; import com.aliyun.odps.TableSchema; import com.aliyun.odps.data.Record; import com.aliyun.odps.data.RecordWriter; import com.aliyun.odps.tunnel.TableTunnel; -import com.aliyun.odps.tunnel.TunnelException; import com.aliyun.odps.type.TypeInfo; import com.aliyun.odps.type.TypeInfoFactory; import lombok.extern.slf4j.Slf4j; import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.Float8Vector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -67,12 +70,12 @@ public class OdpsDataWriter implements DataWriter { private final boolean overwrite = true; - private boolean isTemporarilyCreatedTable = false; + private boolean isPartitioned = false; private TableTunnel.UploadSession uploadSession = null; private RecordWriter recordWriter = null; - public OdpsDataWriter(OdpsConnConfig connConfig, OdpsTableInfo tableInfo, Schema schema) throws TunnelException, IOException { + public OdpsDataWriter(OdpsConnConfig connConfig, OdpsTableInfo tableInfo, Schema schema) throws OdpsException, IOException { this.connConfig = connConfig; this.tableInfo = tableInfo; this.schema = schema; @@ -96,7 +99,8 @@ record = uploadSession.newRecord(); for (int columnIndex = 0; columnIndex < columnCount; columnIndex++) { log.debug("column: {}, type: {}", columnIndex, root.getFieldVectors().get(columnIndex).getField().getType()); - columnName = root.getVector(columnIndex).getField().getName(); + // odps column name is lower case + columnName = root.getVector(columnIndex).getField().getName().toLowerCase(); if (tableSchema.containsColumn(columnName)) { this.setRecordValue(record, tableSchema.getColumnIndex(columnName), this.getValue(root.getFieldVectors().get(columnIndex), rowIndex)); @@ -144,21 +148,25 @@ private Odps initOdpsClient(OdpsConnConfig odpsConnConfig) { return OdpsUtil.buildOdps(odpsConnConfig); } - private void initOdps() throws TunnelException, IOException { + private void initOdps() throws OdpsException, IOException { // init odps client Odps odps = initOdpsClient(this.connConfig); // Pre-processing - preProcessing(odps, connConfig.getProjectName(), tableInfo.tableName()); + preProcessing(odps, connConfig.getProjectName(), tableInfo.tableName(), this.convertToPartitionSpec(tableInfo.partitionSpec())); // init upload session TableTunnel tunnel = new TableTunnel(odps); - if (tableInfo.partitionSpec() != null && !tableInfo.partitionSpec().isEmpty() && 
!isTemporarilyCreatedTable) { + + if (isPartitioned) { + if (tableInfo.partitionSpec() == null || tableInfo.partitionSpec().isEmpty()) { + throw DataproxyException.of(DataproxyErrorCode.INVALID_PARTITION_SPEC, "partitionSpec is empty"); + } PartitionSpec partitionSpec = new PartitionSpec(tableInfo.partitionSpec()); uploadSession = tunnel.createUploadSession(connConfig.getProjectName(), tableInfo.tableName(), partitionSpec, overwrite); } else { uploadSession = tunnel.createUploadSession(connConfig.getProjectName(), tableInfo.tableName(), overwrite); } - recordWriter = uploadSession.openRecordWriter(0); + recordWriter = uploadSession.openRecordWriter(0, true); } /** @@ -173,6 +181,7 @@ private void initOdps() throws TunnelException, IOException { private void setRecordValue(Record record, int columnIndex, Object value) { if (value == null) { record.set(columnIndex, null); + log.warn("table name: {} record set null value. index: {}", tableInfo.tableName(), columnIndex); return; } @@ -184,8 +193,11 @@ private void setRecordValue(Record record, int columnIndex, Object value) { case STRING -> record.setString(columnIndex, String.valueOf(value)); case FLOAT -> record.set(columnIndex, Float.parseFloat(String.valueOf(value))); case DOUBLE -> record.set(columnIndex, Double.parseDouble(String.valueOf(value))); + case TINYINT -> record.set(columnIndex, Byte.parseByte(String.valueOf(value))); + case SMALLINT -> record.set(columnIndex, Short.parseShort(String.valueOf(value))); case BIGINT -> record.set(columnIndex, Long.parseLong(String.valueOf(value))); case INT -> record.set(columnIndex, Integer.parseInt(String.valueOf(value))); + case BOOLEAN -> record.setBoolean(columnIndex, (Boolean) value); default -> record.set(columnIndex, value); } } @@ -205,23 +217,32 @@ private Object getValue(FieldVector fieldVector, int index) { switch (arrowTypeID) { case Int -> { - if (fieldVector instanceof IntVector || fieldVector instanceof BigIntVector || fieldVector instanceof SmallIntVector) { + if (fieldVector instanceof IntVector || fieldVector instanceof BigIntVector || fieldVector instanceof SmallIntVector || fieldVector instanceof TinyIntVector) { return fieldVector.getObject(index); } + log.warn("Type INT is not IntVector or BigIntVector or SmallIntVector or TinyIntVector, value is: {}", fieldVector.getObject(index).toString()); } case FloatingPoint -> { if (fieldVector instanceof Float4Vector | fieldVector instanceof Float8Vector) { return fieldVector.getObject(index); } + log.warn("Type FloatingPoint is not Float4Vector or Float8Vector, value is: {}", fieldVector.getObject(index).toString()); } case Utf8 -> { if (fieldVector instanceof VarCharVector vector) { return new String(vector.get(index), StandardCharsets.UTF_8); } + log.warn("Type Utf8 is not VarCharVector, value is: {}", fieldVector.getObject(index).toString()); } case Null -> { return null; } + case Bool -> { + if (fieldVector instanceof BitVector vector) { + return vector.get(index) == 1; + } + log.warn("Type BOOL is not BitVector, value is: {}", fieldVector.getObject(index).toString()); + } default -> { log.warn("Not implemented type: {}, will use default function", arrowTypeID); return fieldVector.getObject(index); @@ -239,16 +260,35 @@ private Object getValue(FieldVector fieldVector, int index) { * @param projectName project name * @param tableName table name */ - private void preProcessing(Odps odps, String projectName, String tableName) { + private void preProcessing(Odps odps, String projectName, String tableName, PartitionSpec 
partitionSpec) throws OdpsException { if (!isExistsTable(odps, projectName, tableName)) { - boolean odpsTable = createOdpsTable(odps, projectName, tableName, schema); + boolean odpsTable = createOdpsTable(odps, projectName, tableName, schema, partitionSpec); if (!odpsTable) { throw DataproxyException.of(DataproxyErrorCode.ODPS_CREATE_TABLE_FAILED); } - isTemporarilyCreatedTable = true; + log.info("odps table is not exists, create table successful, project: {}, table name: {}", projectName, tableName); + } else { + log.info("odps table is exists, project: {}, table name: {}", projectName, tableName); + } + isPartitioned = odps.tables().get(projectName, tableName).isPartitioned(); + + if (isPartitioned) { + if (partitionSpec == null || partitionSpec.isEmpty()) { + throw DataproxyException.of(DataproxyErrorCode.INVALID_PARTITION_SPEC, "partitionSpec is empty"); + } + + if (!isExistsPartition(odps, projectName, tableName, partitionSpec)) { + boolean odpsPartition = createOdpsPartition(odps, projectName, tableName, partitionSpec); + if (!odpsPartition) { + throw DataproxyException.of(DataproxyErrorCode.ODPS_CREATE_PARTITION_FAILED); + } + log.info("odps partition is not exists, create partition successful, project: {}, table name: {}, PartitionSpec: {}", projectName, tableName, partitionSpec); + } else { + log.info("odps partition is exists, project: {}, table name: {}, PartitionSpec: {}", projectName, tableName, partitionSpec); + } + } - log.info("odps table is exists or create table successful, project: {}, table name: {}", projectName, tableName); } /** @@ -268,9 +308,35 @@ private boolean isExistsTable(Odps odps, String projectName, String tableName) { return false; } - private boolean createOdpsTable(Odps odps, String projectName, String tableName, Schema schema) { + private boolean isExistsPartition(Odps odps, String projectName, String tableName, PartitionSpec partitionSpec) throws OdpsException { + Table table = odps.tables().get(projectName, tableName); + + if (table == null) { + log.warn("table is null, projectName:{}, tableName:{}", projectName, tableName); + throw DataproxyException.of(DataproxyErrorCode.ODPS_TABLE_NOT_EXISTS); + } + + return table.hasPartition(partitionSpec); + } + + /** + * create odps table + * + * @param odps odps client + * @param projectName project name + * @param tableName table name + * @param schema schema + * @param partitionSpec partition spec + * @return true or false + */ + private boolean createOdpsTable(Odps odps, String projectName, String tableName, Schema schema, PartitionSpec partitionSpec) { try { - odps.tables().create(projectName, tableName, convertToTableSchema(schema), true); + TableSchema tableSchema = convertToTableSchema(schema); + if (partitionSpec != null) { + // Infer partitioning field type as string. 
+ partitionSpec.keys().forEach(key -> tableSchema.addPartitionColumn(Column.newBuilder(key, TypeInfoFactory.STRING).build())); + } + odps.tables().create(projectName, tableName, tableSchema, "", true, null, OdpsUtil.getSqlFlag(), null); return true; } catch (Exception e) { log.error("create odps table error, projectName:{}, tableName:{}", projectName, tableName, e); @@ -278,11 +344,36 @@ private boolean createOdpsTable(Odps odps, String projectName, String tableName, return false; } + private boolean createOdpsPartition(Odps odps, String projectName, String tableName, PartitionSpec partitionSpec) { + try { + Table table = odps.tables().get(projectName, tableName); + table.createPartition(partitionSpec, true); + return true; + } catch (Exception e) { + log.error("create odps partition error, projectName:{}, tableName:{}", projectName, tableName, e); + } + return false; + } + private TableSchema convertToTableSchema(Schema schema) { List columns = schema.getFields().stream().map(this::convertToColumn).toList(); return TableSchema.builder().withColumns(columns).build(); } + /** + * convert partition spec + * + * @param partitionSpec partition spec + * @return partition spec + * @throws IllegalArgumentException if partitionSpec is invalid + */ + private PartitionSpec convertToPartitionSpec(String partitionSpec) { + if (partitionSpec == null || partitionSpec.isEmpty()) { + return null; + } + return new PartitionSpec(partitionSpec); + } + private Column convertToColumn(Field field) { return Column.newBuilder(field.getName(), convertToType(field.getType())).build(); } @@ -304,7 +395,13 @@ private TypeInfo convertToType(ArrowType type) { }; } case Int -> { - return TypeInfoFactory.INT; + return switch (((ArrowType.Int) type).getBitWidth()) { + case 8 -> TypeInfoFactory.TINYINT; + case 16 -> TypeInfoFactory.SMALLINT; + case 32 -> TypeInfoFactory.INT; + case 64 -> TypeInfoFactory.BIGINT; + default -> TypeInfoFactory.UNKNOWN; + }; } case Time -> { return TypeInfoFactory.TIMESTAMP; @@ -312,6 +409,9 @@ private TypeInfo convertToType(ArrowType type) { case Date -> { return TypeInfoFactory.DATE; } + case Bool -> { + return TypeInfoFactory.BOOLEAN; + } default -> { log.warn("Not implemented type: {}", arrowTypeID); return TypeInfoFactory.UNKNOWN; diff --git a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsSplitArrowReader.java b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsSplitArrowReader.java index 3773757..a1ee214 100644 --- a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsSplitArrowReader.java +++ b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsSplitArrowReader.java @@ -15,21 +15,30 @@ */ package org.secretflow.dataproxy.manager.connector.odps; +import com.aliyun.odps.Column; import com.aliyun.odps.Instance; import com.aliyun.odps.Odps; import com.aliyun.odps.OdpsException; -import com.aliyun.odps.TableSchema; +import com.aliyun.odps.PartitionSpec; +import com.aliyun.odps.data.ArrayRecord; import com.aliyun.odps.data.Record; -import com.aliyun.odps.data.ResultSet; import com.aliyun.odps.task.SQLTask; +import com.aliyun.odps.tunnel.InstanceTunnel; +import com.aliyun.odps.tunnel.TunnelException; +import com.aliyun.odps.tunnel.io.TunnelRecordReader; import lombok.extern.slf4j.Slf4j; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; import 
org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.FixedWidthVector; import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.Float8Vector; import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VariableWidthVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -43,8 +52,17 @@ import org.secretflow.dataproxy.manager.SplitReader; import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; import java.util.List; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; import java.util.regex.Pattern; +import java.util.stream.Collectors; /** * odps Table Split Reader @@ -61,14 +79,22 @@ public class OdpsSplitArrowReader extends ArrowReader implements SplitReader, Au private final Schema schema; - private TableSchema tableSchema; + private final int batchSize = 10000; - private final int batchSize = 1000; + private boolean partitioned = false; - private ResultSet resultSet; + private InstanceTunnel.DownloadSession downloadSession; + + private int currentIndex = 0; + + private final Set columns = new HashSet<>(); private final Pattern columnOrValuePattern = Pattern.compile("^[\\u00b7A-Za-z0-9\\u4e00-\\u9fa5\\-_,.]*$"); + private final ExecutorService executorService = Executors.newSingleThreadExecutor(); + + private final LinkedBlockingQueue recordQueue = new LinkedBlockingQueue<>(batchSize); + protected OdpsSplitArrowReader(BufferAllocator allocator, OdpsConnConfig odpsConnConfig, OdpsTableInfo tableInfo, Schema schema) { super(allocator); this.odpsConnConfig = odpsConnConfig; @@ -80,29 +106,82 @@ protected OdpsSplitArrowReader(BufferAllocator allocator, OdpsConnConfig odpsCon public boolean loadNextBatch() throws IOException { VectorSchemaRoot root = getVectorSchemaRoot(); root.clear(); + long resultCount = downloadSession.getRecordCount(); + log.info("Load next batch start, recordCount: {}", resultCount); - ValueVectorUtility.preAllocate(root, batchSize); - Record next; - - int recordCount = 0; - if (!resultSet.hasNext()) { + if (currentIndex >= resultCount) { return false; } - while (resultSet.hasNext()) { - next = resultSet.next(); - if (next != null) { - ValueVectorUtility.ensureCapacity(root, recordCount + 1); - toArrowVector(next, root, recordCount); + int recordCount = 0; + + try (TunnelRecordReader records = downloadSession.openRecordReader(currentIndex, batchSize, true)) { + + Record firstRecord = records.read(); + if (firstRecord != null) { + ValueVectorUtility.preAllocate(root, batchSize); + root.setRowCount(batchSize); + + root.getFieldVectors().forEach(fieldVector -> { + if (fieldVector instanceof FixedWidthVector baseFixedWidthVector) { + baseFixedWidthVector.allocateNew(batchSize); + } else if (fieldVector instanceof VariableWidthVector baseVariableWidthVector){ + baseVariableWidthVector.allocateNew(batchSize * 32); + } + }); + + + Future submitFuture = executorService.submit(() -> { + try { + int takeRecordCount = 0; + + for(;;) { + + Record record = recordQueue.take(); + + if (record instanceof ArrayRecord && record.getColumns().length == 0) { + log.info("recordQueue take 
record take Count: {}", takeRecordCount); + break; + } + + ValueVectorUtility.ensureCapacity(root, takeRecordCount + 1); + this.toArrowVector(record, root, takeRecordCount); + takeRecordCount++; + + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + columns.addAll(Arrays.stream(firstRecord.getColumns()).map(Column::getName).collect(Collectors.toSet())); + + recordQueue.put(firstRecord); recordCount++; - } + // Iterating with #read() reuses the previously returned record's storage, so an asynchronous consumer would read no data; #clone() could be used instead, with little performance difference + for (Record record : records) { + try { + recordQueue.put(record); + recordCount++; + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } - if (recordCount == batchSize) { - root.setRowCount(recordCount); - return true; + recordQueue.put(new ArrayRecord(new Column[0])); + log.info("recordQueue put record Count: {}", recordCount); + + submitFuture.get(); + currentIndex += batchSize; + } else { + log.warn("Read first record is null, maybe it has been read."); } + + } catch (TunnelException | ExecutionException | InterruptedException e) { + throw new RuntimeException(e); } root.setRowCount(recordCount); + log.info("Load next batch success, recordCount: {}", recordCount); return true; } @@ -113,7 +192,7 @@ public long bytesRead() { @Override protected void closeReadSource() throws IOException { - + executorService.shutdownNow(); } @Override @@ -127,23 +206,21 @@ public ArrowReader startRead() { Odps odps = OdpsUtil.buildOdps(odpsConnConfig); String sql = ""; try { + partitioned = odps.tables().get(odpsConnConfig.getProjectName(), tableInfo.tableName()).isPartitioned(); + sql = this.buildSql(tableInfo.tableName(), tableInfo.fields(), tableInfo.partitionSpec()); - log.debug("SQLTask run sql: {}", sql); + Instance instance = SQLTask.run(odps, odpsConnConfig.getProjectName(), sql, OdpsUtil.getSqlFlag(), null); - Instance instance = SQLTask.run(odps, sql); + log.info("SQLTask run start, sql: {}", sql); // 等待任务完成 instance.waitForSuccess(); + log.info("SQLTask run success, sql: {}", sql); - resultSet = SQLTask.getResultSet(instance); - - tableSchema = resultSet.getTableSchema(); + downloadSession = new InstanceTunnel(odps).createDownloadSession(odps.getDefaultProject(), instance.getId(), false); } catch (OdpsException e) { log.error("SQLTask run error, sql: {}", sql, e); throw DataproxyException.of(DataproxyErrorCode.ODPS_ERROR, e.getMessage(), e); - } catch (IOException e) { - log.error("startRead error, sql: {}", sql, e); - throw new RuntimeException(e); } return this; @@ -155,6 +232,32 @@ private String buildSql(String tableName, List fields, String whereClaus throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid tableName:" + tableName); } + // Non-partitioned tables no longer get a where clause appended + if (!partitioned) { + whereClause = ""; + } + // TODO: adjust the condition-check logic + if (!whereClause.isEmpty()) { + String[] groups = whereClause.split("[,/]"); + if (groups.length > 1) { + final PartitionSpec partitionSpec = new PartitionSpec(whereClause); + + for (String key : partitionSpec.keys()) { + if (!columnOrValuePattern.matcher(key).matches()) { + throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid partition key:" + key); + } + if (!columnOrValuePattern.matcher(partitionSpec.get(key)).matches()) { + throw DataproxyException.of(DataproxyErrorCode.PARAMS_UNRELIABLE, "Invalid partition value:" + partitionSpec.get(key)); + } + } + + List<String> list = partitionSpec.keys().stream().map(k -> k + "='" + partitionSpec.get(k) + "'").toList(); + whereClause = String.join(" and ", list); 
+ } + } + + log.info("whereClause: {}", whereClause); + return "select " + String.join(",", fields) + " from " + tableName + (whereClause.isEmpty() ? "" : " where " + whereClause) + ";"; } @@ -164,26 +267,39 @@ private void toArrowVector(Record record, VectorSchemaRoot root, int rowIndex) { for (Field field : schema.getFields()) { vector = root.getVector(field); if (vector != null) { - columnName = field.getName(); - if (tableSchema.containsColumn(columnName)) { - setValue(vector.getField().getType(), vector, rowIndex, record, columnName); - vector.setValueCount(rowIndex + 1); + // Column names returned by ODPS are lower case, so normalize the field name here for compatibility + columnName = field.getName().toLowerCase(); + + if (this.hasColumn(columnName)) { + this.setValue(vector.getField().getType(), vector, rowIndex, record, columnName); } } } } + private boolean hasColumn(String columnName) { + return columns.contains(columnName); + } + private void setValue(ArrowType type, FieldVector vector, int rowIndex, Record record, String columnName) { - log.debug("columnName: {} type ID: {}, value: {}", columnName, type.getTypeID(), record.get(columnName)); - if (record.get(columnName) == null) { + Object columnValue = record.get(columnName); + log.debug("columnName: {} type ID: {}, index:{}, value: {}", columnName, type.getTypeID(), rowIndex, columnValue); + + if (columnValue == null) { + vector.setNull(rowIndex); +// log.warn("set null, columnName: {} type ID: {}, index:{}, value: {}", columnName, type.getTypeID(), rowIndex, record); return; } switch (type.getTypeID()) { case Int -> { - if (vector instanceof IntVector intVector) { - intVector.setSafe(rowIndex, Integer.parseInt(record.get(columnName).toString())); + if (vector instanceof SmallIntVector smallIntVector) { + smallIntVector.set(rowIndex, Short.parseShort(columnValue.toString())); + } else if (vector instanceof IntVector intVector) { + intVector.set(rowIndex, Integer.parseInt(columnValue.toString())); } else if (vector instanceof BigIntVector bigIntVector) { - bigIntVector.setSafe(rowIndex, Long.parseLong(record.get(columnName).toString())); + bigIntVector.set(rowIndex, Long.parseLong(columnValue.toString())); + } else if (vector instanceof TinyIntVector tinyIntVector) { + tinyIntVector.set(rowIndex, Byte.parseByte(columnValue.toString())); } else { log.warn("Unsupported type: {}", type); } } @@ -198,13 +314,26 @@ private void setValue(ArrowType type, FieldVector vector, int rowIndex, Record r } case FloatingPoint -> { if (vector instanceof Float4Vector floatVector) { - floatVector.setSafe(rowIndex, Float.parseFloat(record.get(columnName).toString())); + floatVector.set(rowIndex, Float.parseFloat(columnValue.toString())); } else if (vector instanceof Float8Vector doubleVector) { - doubleVector.setSafe(rowIndex, Double.parseDouble(record.get(columnName).toString())); + doubleVector.set(rowIndex, Double.parseDouble(columnValue.toString())); } else { log.warn("Unsupported type: {}", type); } } + case Bool -> { + if (vector instanceof BitVector bitVector) { + + // Record#getBoolean already yields a Java Boolean; store it in the BitVector as 1/0 + bitVector.set(rowIndex, record.getBoolean(columnName) ? 
1 : 0); + } else { + log.warn("ArrowType ID is Bool: Unsupported type: {}", vector.getClass()); + } + } default -> throw new IllegalArgumentException("Unsupported type: " + type); } diff --git a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsUtil.java b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsUtil.java index c0d69fb..d69b713 100644 --- a/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsUtil.java +++ b/dataproxy-manager/src/main/java/org/secretflow/dataproxy/manager/connector/odps/OdpsUtil.java @@ -20,6 +20,10 @@ import com.aliyun.odps.account.AliyunAccount; import org.secretflow.dataproxy.common.model.datasource.conn.OdpsConnConfig; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + /** * odps util * @@ -36,4 +40,10 @@ public static Odps buildOdps(OdpsConnConfig odpsConnConfig) { return odps; } + + public static Map getSqlFlag() { + HashMap hints = new LinkedHashMap<>(); + hints.put("odps.sql.type.system.odps2", "true"); + return hints; + } } diff --git a/dataproxy-service/src/main/java/org/secretflow/dataproxy/service/impl/DataProxyServiceDirectImpl.java b/dataproxy-service/src/main/java/org/secretflow/dataproxy/service/impl/DataProxyServiceDirectImpl.java index 2383660..b35c553 100644 --- a/dataproxy-service/src/main/java/org/secretflow/dataproxy/service/impl/DataProxyServiceDirectImpl.java +++ b/dataproxy-service/src/main/java/org/secretflow/dataproxy/service/impl/DataProxyServiceDirectImpl.java @@ -162,15 +162,18 @@ public void datasetWrite(DatasetWriteCommand writeCommand, FlightStream flightSt writeCommand.setSchema(batch.getSchema()); } + int batchSize = 0; + try (DataWriter dataWriter = connector.buildWriter(writeCommand)) { while (flightStream.next()) { dataWriter.write(batch); // 调用写回调 writeCallback.ack(batch); - log.info("[datasetWrite] 数据块存储成功"); + log.info("[datasetWrite] dataset batch write is successful"); + batchSize += batch.getRowCount(); } dataWriter.flush(); - log.info("[datasetWrite] dataset write over"); + log.info("[datasetWrite] dataset write over, total size: {}", batchSize); } } catch (DataproxyException e) { log.error("[datasetWrite] dataset write error, cmd: {}", JsonUtils.toJSONString(writeCommand), e); diff --git a/dataproxy_sdk/cc/data_proxy_file.cc b/dataproxy_sdk/cc/data_proxy_file.cc index 785f233..b09aa38 100644 --- a/dataproxy_sdk/cc/data_proxy_file.cc +++ b/dataproxy_sdk/cc/data_proxy_file.cc @@ -129,6 +129,7 @@ class DataProxyFile::Impl { ASSIGN_DP_OR_THROW(batch_size, arrow::util::ReferencedBufferSize(*batch)); if (batch_size > kMaxBatchSize) { + slice_offset = 0; slice_size = (batch_size + kMaxBatchSize - 1) / kMaxBatchSize; slice_left = batch->num_rows(); slice_len = (slice_left + slice_size - 1) / slice_size; diff --git a/dataproxy_sdk/python/dataproxy/dp_file_adapter.py b/dataproxy_sdk/python/dataproxy/dp_file_adapter.py index 5ba6f4d..6f0aaee 100644 --- a/dataproxy_sdk/python/dataproxy/dp_file_adapter.py +++ b/dataproxy_sdk/python/dataproxy/dp_file_adapter.py @@ -28,6 +28,10 @@ def close(self): def download_file( self, info: proto.DownloadInfo, file_path: str, file_format: proto.FileFormat ): + logging.info( + f"dataproxy sdk: start download_file[{file_path}], type[{file_format}]" + ) + self.data_proxy_file.download_file( info.SerializeToString(), file_path, file_format ) @@ -40,6 +44,10 @@ def download_file( def upload_file( self, info: proto.UploadInfo, file_path: str, file_format: 
proto.FileFormat ): + logging.info( + f"dataproxy sdk: start upload_file[{file_path}], type[{file_format}]" + ) + self.data_proxy_file.upload_file( info.SerializeToString(), file_path, file_format ) diff --git a/pom.xml b/pom.xml index 225327f..2f49697 100644 --- a/pom.xml +++ b/pom.xml @@ -40,6 +40,7 @@ 3.12.0 4.4 1.26.2 + 2.10.1 4.0.3 1.3.1 3.4.0 @@ -160,6 +161,13 @@ ${commons-compress.version} + + + org.apache.commons + commons-configuration2 + ${commons-configuration2.version} + + From 9f81ec603cddd66f4da01be9ccdd264ad25578b3 Mon Sep 17 00:00:00 2001 From: SuperMarz Date: Tue, 14 Jan 2025 11:14:44 +0800 Subject: [PATCH 2/2] fix mac build error --- .bazelrc | 13 ++++++++----- .circleci/config.yml | 2 +- MODULE.bazel | 3 ++- dataproxy_sdk/python/test/BUILD.bazel | 18 ++++++++++++++++++ dataproxy_sdk/python/test/exported_symbols.lds | 1 + dataproxy_sdk/python/test/version_script.lds | 9 +++++++++ 6 files changed, 39 insertions(+), 7 deletions(-) create mode 100644 dataproxy_sdk/python/test/exported_symbols.lds create mode 100644 dataproxy_sdk/python/test/version_script.lds diff --git a/.bazelrc b/.bazelrc index 8c195cb..47839a7 100644 --- a/.bazelrc +++ b/.bazelrc @@ -48,10 +48,13 @@ build:linux --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-l%:libgcc.a # platform specific config # Bazel will automatic pick platform config since we have enable_platform_specific_config set -build:linux --copt=-fopenmp -build:linux --linkopt=-fopenmp -build:macos --copt="-Xpreprocessor -fopenmp" +build:macos --copt=-Xclang=-fopenmp build:macos --copt=-Wno-unused-command-line-argument build:macos --features=-supports_dynamic_linker -build:macos --macos_minimum_os=12.0 -build:macos --host_macos_minimum_os=12.0 +build:macos --macos_minimum_os=13.0 +build:macos --host_macos_minimum_os=13.0 +build:macos --action_env MACOSX_DEPLOYMENT_TARGET=13.0 + +build:linux --copt=-fopenmp +build:linux --linkopt=-fopenmp + diff --git a/.circleci/config.yml b/.circleci/config.yml index ae98d8f..02f357d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -76,5 +76,5 @@ workflows: .circleci/continue-config.yml sdk-build-and-run true dataproxy_sdk/.* sdk-build-and-run true .bazelrc sdk-build-and-run true - WORKSPACE sdk-build-and-run true + MODULE.bazel sdk-build-and-run true - lint diff --git a/MODULE.bazel b/MODULE.bazel index 258dc67..60ace3e 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -44,13 +44,14 @@ pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip") pip.parse( hub_name = "dataproxy_sdk_pip", python_version = python_version, - requirements_linux = "//dataproxy_sdk/python:requirements_lock_{}.txt".format(python_version.replace(".", "_")), + requirements_lock = "//dataproxy_sdk/python:requirements_lock_{}.txt".format(python_version.replace(".", "_")), ) for python_version in SUPPORTED_PYTHON_VERSIONS ] use_repo(pip, "dataproxy_sdk_pip") +bazel_dep(name = "apple_support", version = "1.17.1") bazel_dep(name = "arrow", version = "14.0.2") bazel_dep(name = "rules_foreign_cc", version = "0.13.0") bazel_dep(name = "spdlog", version = "1.14.1") diff --git a/dataproxy_sdk/python/test/BUILD.bazel b/dataproxy_sdk/python/test/BUILD.bazel index bf2e23b..9ac3758 100644 --- a/dataproxy_sdk/python/test/BUILD.bazel +++ b/dataproxy_sdk/python/test/BUILD.bazel @@ -18,10 +18,28 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test") package(default_visibility = ["//visibility:public"]) +exports_files( + [ + "exported_symbols.lds", + "version_script.lds", + ], + visibility = ["//visibility:private"], 
+) + pybind_extension( name = "_dm_mock", srcs = ["_dm_mock.cc"], + linkopts = select({ + "@bazel_tools//src/conditions:darwin": [ + "-Wl,-exported_symbols_list,$(location :exported_symbols.lds)", + ], + "//conditions:default": [ + "-Wl,--version-script,$(location :version_script.lds)", + ], + }), deps = [ + ":exported_symbols.lds", + ":version_script.lds", "//dataproxy_sdk/cc:exception", "//dataproxy_sdk/test:data_mesh_mock", ], diff --git a/dataproxy_sdk/python/test/exported_symbols.lds b/dataproxy_sdk/python/test/exported_symbols.lds new file mode 100644 index 0000000..2637585 --- /dev/null +++ b/dataproxy_sdk/python/test/exported_symbols.lds @@ -0,0 +1 @@ +_PyInit_* \ No newline at end of file diff --git a/dataproxy_sdk/python/test/version_script.lds b/dataproxy_sdk/python/test/version_script.lds new file mode 100644 index 0000000..a7e3bc0 --- /dev/null +++ b/dataproxy_sdk/python/test/version_script.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + # Export symbols in pybind. + global: + PyInit_*; + + # Hide everything else. + local: + *; +};
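Note on the partition handling introduced in PATCH 1/2: the same kind of PartitionSpec string that OdpsDataWriter receives is what buildSql rewrites into a where clause on the read side. Below is a minimal, self-contained sketch of that conversion, assuming only the com.aliyun.odps.PartitionSpec calls already used in the patch (new PartitionSpec(String), keys(), get()); the class name PartitionSpecSketch and the spec value pt='20241011',region='cn' are hypothetical, for illustration only.

import com.aliyun.odps.PartitionSpec;

import java.util.List;

public class PartitionSpecSketch {

    // Turns a partition spec string such as "pt='20241011',region='cn'" into the
    // "pt='20241011' and region='cn'" clause that buildSql appends for partitioned tables.
    static String toWhereClause(String spec) {
        if (spec == null || spec.isEmpty()) {
            return ""; // non-partitioned tables get no where clause
        }
        PartitionSpec partitionSpec = new PartitionSpec(spec);
        // Values are quoted as strings, matching the writer, which creates inferred partition columns as STRING.
        List<String> terms = partitionSpec.keys().stream()
                .map(k -> k + "='" + partitionSpec.get(k) + "'")
                .toList();
        return String.join(" and ", terms);
    }

    public static void main(String[] args) {
        // Prints: pt='20241011' and region='cn'
        System.out.println(toWhereClause("pt='20241011',region='cn'"));
    }
}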