diff --git a/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java b/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java index 1f7207982..4cabbdb0a 100644 --- a/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java +++ b/mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java @@ -27,6 +27,7 @@ import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.operators.arithmetic.Subtraction; import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.relational.InExpression; @@ -154,6 +155,10 @@ protected void processDelete(Delete delete, int index, String sql, Object obj) { * @param expression ignore */ private void validExpression(Expression expression) { + while (expression instanceof Parenthesis) { + Parenthesis parenthesis = (Parenthesis) expression; + expression = parenthesis.getExpression(); + } //where条件使用了 or 关键字 if (expression instanceof OrExpression) { OrExpression orExpression = (OrExpression) expression; @@ -289,8 +294,10 @@ else if (leftExpression instanceof BinaryExpression) { } //获得右边表达式,并分解 - Expression rightExpression = ((BinaryExpression) expression).getRightExpression(); - validExpression(rightExpression); + if (joinTable != null) { + Expression rightExpression = ((BinaryExpression) expression).getRightExpression(); + validExpression(rightExpression); + } } } diff --git a/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptorTest.java b/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptorTest.java index e11471e50..62f7c6e08 100644 --- a/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptorTest.java +++ b/mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptorTest.java @@ -20,7 +20,7 @@ class IllegalSQLInnerInterceptorTest { private final IllegalSQLInnerInterceptor interceptor = new IllegalSQLInnerInterceptor(); // - // 待研究为啥H2读不到索引信息 + // 待研究为啥H2读不到索引信息 // private static Connection connection; // // @BeforeAll @@ -48,12 +48,17 @@ void test() { Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete from t_user set age = 18", null)); Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age != 1", null)); Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age = 1 or name = 'test'", null)); + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where (age = 1 or name = 'test')", null)); // Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where a = 1 and b = 2", connection)); + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("update t_user set age = 1 where age = 1 or name = 'test'", null)); + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("update t_user set age = 1 where (age = 1 or name = 'test')", null)); + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete t_user where age = 1 or name = 'test'", null)); + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete t_user where (age = 1 or name = 'test')", null)); } @Test @Disabled - void testForMysql(){ + void testForMysql() { /* * CREATE TABLE `t_demo` ( `a` int DEFAULT NULL, @@ -97,6 +102,30 @@ void testForMysql(){ Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `t_demo` a LEFT JOIN `test` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection())); Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `t_demo` a LEFT JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection())); + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where (c = 3 OR b = 2)", dataSource.getConnection())); + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where c = 3 OR b = 2", dataSource.getConnection())); + Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where a = 3 AND (c = 3 OR b = 2)", dataSource.getConnection())); + + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where (a = 3 AND c = 3 OR b = 2)", dataSource.getConnection())); + + Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where a in (1,3,2)", dataSource.getConnection())); + + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where a in (1,3,2) or b = 2", dataSource.getConnection())); + Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where a in (1,3,2) AND b = 2", dataSource.getConnection())); + + Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where (a = 3 AND c = 3 AND b = 2)", dataSource.getConnection())); + Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` a INNER JOIN test b ON a.a = b.a where a.a = 3 AND (b.c = 3 OR b.b = 2)", dataSource.getConnection())); + + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where a != (SELECT b FROM test limit 1) ", dataSource.getConnection())); + //TODO 低版本这里的抛异常了.看着应该不用抛出 + Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where a = (SELECT b FROM test limit 1) ", dataSource.getConnection())); + Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where a >= (SELECT b FROM test limit 1) ", dataSource.getConnection())); + Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where a <= (SELECT b FROM test limit 1) ", dataSource.getConnection())); + + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where b = (SELECT b FROM test limit 1) ", dataSource.getConnection())); + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where b >= (SELECT b FROM test limit 1) ", dataSource.getConnection())); + Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where b <= (SELECT b FROM test limit 1) ", dataSource.getConnection())); + } }