Skip to content

Commit

Permalink
BIGTOP-4227: Address SQL injection (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
timyuer authored Sep 14, 2024
1 parent 7e10c1b commit f5bc0c2
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,23 @@ public Object intercept(Invocation invocation) throws Throwable {
Object parameter = invocation.getArgs()[1];
log.debug("sqlCommandType {}", sqlCommandType);

Collection<Object> objects;
if (parameter instanceof MapperMethod.ParamMap) {
MapperMethod.ParamMap<Object> paramMap = ((MapperMethod.ParamMap<Object>) parameter);
if (paramMap.get("param1") instanceof Collection) {
objects = ((Collection<Object>) paramMap.get("param1"));
if (SqlCommandType.INSERT == sqlCommandType || SqlCommandType.UPDATE == sqlCommandType) {
Collection<Object> objects;
if (parameter instanceof MapperMethod.ParamMap) {
MapperMethod.ParamMap<Object> paramMap = ((MapperMethod.ParamMap<Object>) parameter);
if (paramMap.get("param1") instanceof Collection) {
objects = ((Collection<Object>) paramMap.get("param1"));
} else {
objects = Collections.singletonList(paramMap.get("param1"));
}
} else {
objects = Collections.singletonList(paramMap.get("param1"));
objects = Collections.singletonList(parameter);
}
} else {
objects = Collections.singletonList(parameter);
}

for (Object o : objects) {
setAuditFields(o, sqlCommandType);
for (Object o : objects) {
setAuditFields(o, sqlCommandType);
}
}

return invocation.proceed();
}

Expand All @@ -92,26 +93,25 @@ private void setAuditFields(Object object, SqlCommandType sqlCommandType) throws
Timestamp timestamp = new Timestamp(System.currentTimeMillis());

List<Field> fields = ClassUtils.getFields(object.getClass());
if (SqlCommandType.INSERT == sqlCommandType || SqlCommandType.UPDATE == sqlCommandType) {
for (Field field : fields) {
boolean accessible = field.canAccess(object);
field.setAccessible(true);
if (field.isAnnotationPresent(CreateBy.class)
&& SqlCommandType.INSERT == sqlCommandType
&& userId != null) {
field.set(object, userId);
}
if (field.isAnnotationPresent(CreateTime.class) && SqlCommandType.INSERT == sqlCommandType) {
field.set(object, timestamp);
}
if (field.isAnnotationPresent(UpdateBy.class) && userId != null) {
field.set(object, userId);
}
if (field.isAnnotationPresent(UpdateTime.class)) {
field.set(object, timestamp);
}
field.setAccessible(accessible);

for (Field field : fields) {
boolean accessible = field.canAccess(object);
field.setAccessible(true);
if (field.isAnnotationPresent(CreateBy.class)
&& SqlCommandType.INSERT == sqlCommandType
&& userId != null) {
field.set(object, userId);
}
if (field.isAnnotationPresent(CreateTime.class) && SqlCommandType.INSERT == sqlCommandType) {
field.set(object, timestamp);
}
if (field.isAnnotationPresent(UpdateBy.class) && userId != null) {
field.set(object, userId);
}
if (field.isAnnotationPresent(UpdateTime.class)) {
field.set(object, timestamp);
}
field.setAccessible(accessible);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,10 @@
import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.text.MessageFormat;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* Multiple data source support
Expand Down Expand Up @@ -117,7 +115,7 @@ public static <Entity> String update(TableMetaData tableMetaData, Entity entity,
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
if (!ObjectUtils.isEmpty(value)) {
sql.SET("`" + getEquals(entry.getValue() + "`", entry.getKey()));
sql.SET(getEquals(entry.getValue(), entry.getKey()));
}
}

Expand Down Expand Up @@ -158,7 +156,7 @@ public static String selectById(TableMetaData tableMetaData, String databaseId,
case MYSQL: {
sql.SELECT(tableMetaData.getBaseColumns());
sql.FROM(tableMetaData.getTableName());
sql.WHERE(tableMetaData.getPkColumn() + " = '" + id + "'");
sql.WHERE(getEquals(tableMetaData.getPkColumn(), tableMetaData.getPkProperty()));
break;
}
case POSTGRESQL: {
Expand All @@ -185,10 +183,20 @@ public static String selectByIds(
SQL sql = new SQL();
switch (DBType.toType(databaseId)) {
case MYSQL: {
String idsStr = ids.stream().map(String::valueOf).collect(Collectors.joining("', '"));
sql.SELECT(tableMetaData.getBaseColumns());
sql.FROM(tableMetaData.getTableName());
sql.WHERE(tableMetaData.getPkColumn() + " in ('" + idsStr + "')");
if (ids == null || ids.isEmpty()) {
sql.WHERE("1 = 0");
break;
}

StringBuilder idStr = new StringBuilder();
for (int i = 0; i < ids.size(); i++) {
idStr.append(getTokenParam("arg0[" + i + "]")).append(",");
}
idStr.deleteCharAt(idStr.lastIndexOf(","));

sql.WHERE(tableMetaData.getPkColumn() + " IN ( " + idStr + " )");
break;
}
case POSTGRESQL: {
Expand Down Expand Up @@ -240,7 +248,7 @@ public static String deleteById(TableMetaData tableMetaData, String databaseId,
switch (DBType.toType(databaseId)) {
case MYSQL: {
sql.DELETE_FROM(tableMetaData.getTableName());
sql.WHERE(tableMetaData.getPkColumn() + " = '" + id + "'");
sql.WHERE(getEquals(tableMetaData.getPkColumn(), tableMetaData.getPkProperty()));
break;
}
case POSTGRESQL: {
Expand All @@ -261,9 +269,18 @@ public static String deleteByIds(
SQL sql = new SQL();
switch (DBType.toType(databaseId)) {
case MYSQL: {
String idsStr = ids.stream().map(String::valueOf).collect(Collectors.joining("', '"));
if (ids == null || ids.isEmpty()) {
break;
}
sql.DELETE_FROM(tableMetaData.getTableName());
sql.WHERE(tableMetaData.getPkColumn() + " in ('" + idsStr + "')");

StringBuilder idStr = new StringBuilder();
for (int i = 0; i < ids.size(); i++) {
idStr.append(getTokenParam("arg0[" + i + "]")).append(",");
}
idStr.deleteCharAt(idStr.lastIndexOf(","));

sql.WHERE(tableMetaData.getPkColumn() + " IN ( " + idStr + " )");
break;
}
case POSTGRESQL: {
Expand All @@ -282,14 +299,13 @@ public static String deleteByIds(

public static <Condition> String findByCondition(
TableMetaData tableMetaData, String databaseId, Condition condition) throws IllegalAccessException {
String tableName = tableMetaData.getTableName();
log.info("databaseId: {}", databaseId);
SQL sql = new SQL();
switch (DBType.toType(databaseId)) {
case POSTGRESQL:
tableName = "\"" + tableName + "\"";
case MYSQL: {
sql = mysqlCondition(condition, tableName);
sql = mysqlCondition(condition, tableMetaData);
break;
}
default: {
Expand All @@ -301,14 +317,15 @@ public static <Condition> String findByCondition(
}

private static String getEquals(String column, String property) {
return column + " = " + getTokenParam(property);
return "`" + column + "` = " + getTokenParam(property);
}

private static String getTokenParam(String property) {
return "#{" + property + "}";
}

private static <Condition> SQL mysqlCondition(Condition condition, String tableName) throws IllegalAccessException {
private static <Condition> SQL mysqlCondition(Condition condition, TableMetaData tableMetaData)
throws IllegalAccessException {

Class<?> loadClass;
try {
Expand All @@ -321,83 +338,80 @@ private static <Condition> SQL mysqlCondition(Condition condition, String tableN
/* Prepare SQL */
SQL sql = new SQL();
sql.SELECT("*");
sql.FROM(tableName);
sql.FROM(tableMetaData.getTableName());
for (Field field : fieldList) {
field.setAccessible(true);
String fieldName = field.getName();
log.debug("[requestField] {}, [requestValue] {}", fieldName, field.get(condition));
if (field.isAnnotationPresent(QueryCondition.class) && Objects.nonNull(field.get(condition))) {
QueryCondition annotation = field.getAnnotation(QueryCondition.class);

String queryKey = fieldName;
String property = fieldName;
if (!annotation.queryKey().isEmpty()) {
queryKey = annotation.queryKey();
property = annotation.queryKey();
}

log.info(
"[queryKey] {}, [queryType] {}, [queryValue] {}",
queryKey,
annotation.queryType().toString(),
field.get(condition));

Object value = field.get(condition);
if (value != null) {
Map<String, String> fieldColumnMap = tableMetaData.getFieldColumnMap();

if (value != null && fieldColumnMap.containsKey(property)) {
String columnName = fieldColumnMap.get(property);

log.info(
"[queryKey] {}, [queryType] {}, [queryValue] {}",
property,
annotation.queryType().toString(),
field.get(condition));
switch (annotation.queryType()) {
case EQ:
sql.WHERE(MessageFormat.format("{0} = ''{1}''", queryKey, value));
sql.WHERE(getEquals(columnName, fieldName));
break;
case NOT_EQ:
sql.WHERE(MessageFormat.format("{0} != ''{1}''", queryKey, value));
sql.WHERE(columnName + " != " + getTokenParam(fieldName));
break;
case IN:
sql.WHERE(MessageFormat.format(
"{0} IN (''{1}'')",
queryKey,
String.join("','", value.toString().split(annotation.multipleDelimiter()))));
sql.WHERE(columnName + " IN ( REPLACE( " + getTokenParam(fieldName) + ", '"
+ annotation.multipleDelimiter() + "', ',') )");
break;
case NOT_IN:
sql.WHERE(MessageFormat.format(
"{0} NOT IN (''{1}'')",
queryKey,
String.join("','", value.toString().split(annotation.multipleDelimiter()))));
sql.WHERE(columnName + " NOT IN ( REPLACE( " + getTokenParam(fieldName) + ", '"
+ annotation.multipleDelimiter() + "', ',') )");
break;
case GT:
sql.WHERE(MessageFormat.format("{0} > ''{1}''", queryKey, value));
sql.WHERE(columnName + " > " + getTokenParam(fieldName));
break;
case GTE:
sql.WHERE(MessageFormat.format("{0} >= ''{1}''", queryKey, value));
sql.WHERE(columnName + " >= " + getTokenParam(fieldName));
break;
case LT:
sql.WHERE(MessageFormat.format("{0} < ''{1}''", queryKey, value));
sql.WHERE(columnName + " < " + getTokenParam(fieldName));
break;
case LTE:
sql.WHERE(MessageFormat.format("{0} <= ''{1}''", queryKey, value));
sql.WHERE(columnName + " <= " + getTokenParam(fieldName));
break;
case BETWEEN:
String[] valueArr = field.get(condition).toString().split(annotation.pairDelimiter());
if (valueArr.length == 2) {
sql.WHERE(MessageFormat.format(
"{0} BETWEEN ''{1}'' AND ''{2}''", queryKey, valueArr[0], valueArr[1]));
}
sql.WHERE(columnName + " BETWEEN SUBSTRING_INDEX( " + getTokenParam(fieldName) + ", '"
+ annotation.pairDelimiter() + "', 1) AND SUBSTRING_INDEX( "
+ getTokenParam(fieldName) + ", '"
+ annotation.pairDelimiter() + "', 2)");
break;
case PREFIX_LIKE:
sql.WHERE(MessageFormat.format("{0} LIKE CONCAT(''{1}'', ''%'')", queryKey, value));
sql.WHERE(columnName + " LIKE CONCAT( " + getTokenParam(fieldName) + ", '%')");
break;
case SUFFIX_LIKE:
sql.WHERE(MessageFormat.format("{0} LIKE CONCAT(''%'', ''{1}'')", queryKey, value));
sql.WHERE(columnName + " LIKE CONCAT('%', " + getTokenParam(fieldName) + ")");
break;
case LIKE:
sql.WHERE(MessageFormat.format("{0} LIKE CONCAT(''%'', ''{1}'', ''%'')", queryKey, value));
sql.WHERE(columnName + " LIKE CONCAT('%', " + getTokenParam(fieldName) + ", '%')");
break;
case NOT_LIKE:
sql.WHERE(MessageFormat.format(
"{0} NOT LIKE CONCAT(''%'', ''{1}'', ''%'')", queryKey, value));
sql.WHERE(columnName + " NOT LIKE CONCAT('%', " + getTokenParam(fieldName) + ", '%')");
break;
case NULL:
sql.WHERE(queryKey + " IS NULL");
sql.WHERE(columnName + " IS NULL");
break;
case NOT_NULL:
sql.WHERE(queryKey + " IS NOT NULL");
sql.WHERE(columnName + " IS NOT NULL");
break;
default:
log.warn("Unknown query type: {}", annotation.queryType());
Expand Down

0 comments on commit f5bc0c2

Please sign in to comment.