Explorar o código

数据权限的逻辑处理,暂未测试

YunaiV %!s(int64=3) %!d(string=hai) anos
pai
achega
e9385219c2

+ 113 - 75
yudao-framework/yudao-spring-boot-starter-data-permission/src/main/java/cn/iocoder/yudao/framework/datapermission/core/interceptor/DataPermissionInterceptor.java

@@ -4,18 +4,20 @@ import cn.hutool.core.collection.CollUtil;
 import cn.iocoder.yudao.framework.common.util.collection.SetUtils;
 import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule;
 import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory;
+import cn.iocoder.yudao.framework.mybatis.core.util.MyBatisUtils;
 import com.alibaba.ttl.TransmittableThreadLocal;
 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
-import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
 import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
 import lombok.RequiredArgsConstructor;
 import net.sf.jsqlparser.expression.*;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
-import net.sf.jsqlparser.expression.operators.relational.*;
-import net.sf.jsqlparser.schema.Column;
+import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
+import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
+import net.sf.jsqlparser.expression.operators.relational.InExpression;
+import net.sf.jsqlparser.expression.operators.relational.ItemsList;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.delete.Delete;
 import net.sf.jsqlparser.statement.select.*;
@@ -32,6 +34,15 @@ import java.sql.Connection;
 import java.util.*;
 import java.util.concurrent.ConcurrentHashMap;
 
+/**
+ * 数据权限拦截器,通过 {@link DataPermissionRule} 数据权限规则,重写 SQL 的方式来实现
+ * 主要的 SQL 重写方法,可见 {@link #builderExpression(Expression, Table)} 方法
+ *
+ * 整体的代码实现上,参考 {@link com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor} 实现。
+ * 所以每次 MyBatis Plus 升级时,需要 Review 下其具体的实现是否有变更!
+ *
+ * @author 芋道源码
+ */
 @RequiredArgsConstructor
 public class DataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor {
 
@@ -40,7 +51,8 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
     private final MappedStatementCache mappedStatementCache = new MappedStatementCache();
 
     @Override // SELECT 场景
-    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
+    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter,
+                            RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
         // 获得 Mapper 对应的数据权限的规则
         List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
         if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
@@ -59,12 +71,12 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
         }
     }
 
-    @Override // 只处理 UPDATE / DELETE 场景
+    @Override // 只处理 UPDATE / DELETE 场景,不处理 INSERT 场景
     public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
         PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
         MappedStatement ms = mpSh.mappedStatement();
         SqlCommandType sct = ms.getSqlCommandType();
-        if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) { // 无需处理 Insert 语句
+        if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
             // 获得 Mapper 对应的数据权限的规则
             List<DataPermissionRule> rules = ruleFactory.getDataPermissionRule(ms.getId());
             if (mappedStatementCache.noRewritable(ms, rules)) { // 如果无需重写,则跳过
@@ -117,7 +129,8 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
     @Override
     protected void processUpdate(Update update, int index, String sql, Object obj) {
         final Table table = update.getTable();
-        update.setWhere(this.andExpression(table, update.getWhere()));
+//        update.setWhere(this.andExpression(table, update.getWhere()));
+        update.setWhere(this.builderExpression(update.getWhere(), table));
     }
 
     /**
@@ -125,26 +138,27 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
      */
     @Override
     protected void processDelete(Delete delete, int index, String sql, Object obj) {
-        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
+//        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
+        delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable()));
     }
 
-    /**
-     * delete update 语句 where 处理
-     */
-    protected BinaryExpression andExpression(Table table, Expression where) {
-        //获得where条件表达式
-        EqualsTo equalsTo = new EqualsTo();
-        equalsTo.setLeftExpression(this.getAliasColumn(table));
-        equalsTo.setRightExpression(getTenantId());
-        if (null != where) {
-            if (where instanceof OrExpression) {
-                return new AndExpression(equalsTo, new Parenthesis(where));
-            } else {
-                return new AndExpression(equalsTo, where);
-            }
-        }
-        return equalsTo;
-    }
+//    /**
+//     * delete update 语句 where 处理
+//     */
+//    protected BinaryExpression andExpression(Table table, Expression where) {
+//        //获得where条件表达式
+//        EqualsTo equalsTo = new EqualsTo();
+//        equalsTo.setLeftExpression(this.getAliasColumn(table));
+//        equalsTo.setRightExpression(getTenantId());
+//        if (null != where) {
+//            if (where instanceof OrExpression) {
+//                return new AndExpression(equalsTo, new Parenthesis(where));
+//            } else {
+//                return new AndExpression(equalsTo, where);
+//            }
+//        }
+//        return equalsTo;
+//    }
 
     /**
      * 处理 PlainSelect
@@ -155,10 +169,11 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
         processWhereSubSelect(where);
         if (fromItem instanceof Table) {
             Table fromTable = (Table) fromItem;
-            if (!ignoreTable(fromTable.getName())) {
-                //#1186 github
-                plainSelect.setWhere(builderExpression(where, fromTable));
-            }
+//            if (!ignoreTable(fromTable.getName())) {
+//                //#1186 github
+//                plainSelect.setWhere(builderExpression(where, fromTable));
+//            }
+            plainSelect.setWhere(builderExpression(where, fromTable));
         } else {
             processFromItem(fromItem);
         }
@@ -311,20 +326,21 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
                     processJoin(join);
                     continue;
                 }
-                // 当前表是否忽略
-                boolean needIgnore = ignoreTable(fromTable.getName());
-                // 表名压栈,忽略的表压入 null,以便后续不处理
-                tables.push(needIgnore ? null : fromTable);
+//                // 当前表是否忽略
+//                boolean needIgnore = ignoreTable(fromTable.getName());
+//                // 表名压栈,忽略的表压入 null,以便后续不处理
+//                tables.push(needIgnore ? null : fromTable);
                 // 尾缀多个 on 表达式的时候统一处理
                 if (originOnExpressions.size() > 1) {
                     Collection<Expression> onExpressions = new LinkedList<>();
                     for (Expression originOnExpression : originOnExpressions) {
                         Table currentTable = tables.poll();
-                        if (currentTable == null) {
-                            onExpressions.add(originOnExpression);
-                        } else {
-                            onExpressions.add(builderExpression(originOnExpression, currentTable));
-                        }
+//                        if (currentTable == null) {
+//                            onExpressions.add(originOnExpression);
+//                        } else {
+//                            onExpressions.add(builderExpression(originOnExpression, currentTable));
+//                        }
+                        onExpressions.add(builderExpression(originOnExpression, currentTable));
                     }
                     join.setOnExpressions(onExpressions);
                 }
@@ -341,15 +357,18 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
     protected void processJoin(Join join) {
         if (join.getRightItem() instanceof Table) {
             Table fromTable = (Table) join.getRightItem();
-            if (ignoreTable(fromTable.getName())) {
-                // 过滤退出执行
-                return;
-            }
+//            if (ignoreTable(fromTable.getName())) {
+//                // 过滤退出执行
+//                return;
+//            }
             // 走到这里说明 on 表达式肯定只有一个
-            Collection<Expression> originOnExpressions = join.getOnExpressions();
-            List<Expression> onExpressions = new LinkedList<>();
-            onExpressions.add(builderExpression(originOnExpressions.iterator().next(), fromTable));
-            join.setOnExpressions(onExpressions);
+//            Collection<Expression> originOnExpressions = join.getOnExpressions();
+//            List<Expression> onExpressions = new LinkedList<>();
+//            onExpressions.add(builderExpression(originOnExpressions.iterator().next(), fromTable));
+//            join.setOnExpressions(onExpressions);
+            Expression originOnExpression = CollUtil.getFirst(join.getOnExpressions());
+            originOnExpression = builderExpression(originOnExpression, fromTable);
+            join.setOnExpressions(CollUtil.newArrayList(originOnExpression));
         }
     }
 
@@ -357,50 +376,69 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
      * 处理条件
      */
     protected Expression builderExpression(Expression currentExpression, Table table) {
-        EqualsTo equalsTo = new EqualsTo();
-        equalsTo.setLeftExpression(this.getAliasColumn(table));
-        equalsTo.setRightExpression(getTenantId());
+        // 获得 Table 对应的数据权限条件
+        Expression equalsTo = buildDataPermissionExpression(table);
+        if (equalsTo == null) { // 如果没条件,则返回 currentExpression 默认
+            return currentExpression;
+        }
+
+        // 表达式为空,则直接返回 equalsTo
         if (currentExpression == null) {
             return equalsTo;
         }
+        // 如果表达式为 Or,则需要 (currentExpression) AND equalsTo
         if (currentExpression instanceof OrExpression) {
             return new AndExpression(new Parenthesis(currentExpression), equalsTo);
-        } else {
-            return new AndExpression(currentExpression, equalsTo);
         }
+        // 如果表达式为 And,则直接返回 currentExpression AND equalsTo
+        return new AndExpression(currentExpression, equalsTo);
     }
 
+//    /**
+//     * 租户字段别名设置
+//     * <p>tenantId 或 tableAlias.tenantId</p>
+//     *
+//     * @param table 表对象
+//     * @return 字段
+//     */
+//    protected Column getAliasColumn(Table table) {
+//        StringBuilder column = new StringBuilder();
+//        if (table.getAlias() != null) {
+//            column.append(table.getAlias().getName()).append(StringPool.DOT);
+//        }
+//        column.append(getTenantIdColumn());
+//        return new Column(column.toString());
+//    }
+
     /**
-     * 租户字段别名设置
-     * <p>tenantId 或 tableAlias.tenantId</p>
+     * 构建指定表的数据权限的 Expression 过滤条件
      *
-     * @param table 表对象
-     * @return 字段
+     * @param table 表
+     * @return Expression 过滤条件
      */
-    protected Column getAliasColumn(Table table) {
-        StringBuilder column = new StringBuilder();
-        if (table.getAlias() != null) {
-            column.append(table.getAlias().getName()).append(StringPool.DOT);
+    private Expression buildDataPermissionExpression(Table table) {
+        // 生成条件
+        Expression allExpression = null;
+        for (DataPermissionRule rule : ContextHolder.getRules()) {
+            // 判断表名是否匹配
+            if (!rule.getTableNames().contains(table.getName())) {
+                continue;
+            }
+            // 单条规则的条件
+            String tableName = MyBatisUtils.getTableName(table);
+            Expression oneExpress = rule.getExpression(tableName, table.getAlias());
+            // 拼接到 allExpression 中
+            allExpression = allExpression == null ? oneExpress
+                    : new AndExpression(allExpression, oneExpress);
         }
-        column.append(getTenantIdColumn());
-        return new Column(column.toString());
-    }
-
-    // TODO 芋艿:未实现
 
-    private boolean ignoreTable(String tableName) {
-        return false;
-    }
-
-    private String getTenantIdColumn() {
-        return "dept_id";
-    }
-
-    private Expression getTenantId() {
-        return new LongValue(1L);
+        // 如果条件非空,说明已经重写了
+        if (allExpression != null) {
+            ContextHolder.setRewrite(true);
+        }
+        return allExpression;
     }
 
-
     /**
      * 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
      *

+ 19 - 0
yudao-framework/yudao-spring-boot-starter-mybatis/src/main/java/cn/iocoder/yudao/framework/mybatis/core/util/MyBatisUtils.java

@@ -7,6 +7,7 @@ import com.baomidou.mybatisplus.core.metadata.OrderItem;
 import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
 import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
+import net.sf.jsqlparser.schema.Table;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -18,6 +19,8 @@ import java.util.stream.Collectors;
  */
 public class MyBatisUtils {
 
+    private static final String MYSQL_ESCAPE_CHARACTER = "`";
+
     public static <T> Page<T> buildPage(PageParam pageParam) {
         return buildPage(pageParam, null);
     }
@@ -48,4 +51,20 @@ public class MyBatisUtils {
         interceptor.setInterceptors(inners);
     }
 
+    /**
+     * 获得 Table 对应的表名
+     *
+     * 兼容 MySQL 转义表名 `t_xxx`
+     *
+     * @param table 表
+     * @return 去除转移字符后的表名
+     */
+    public static String getTableName(Table table) {
+        String tableName = table.getName();
+        if (tableName.startsWith(MYSQL_ESCAPE_CHARACTER) && tableName.endsWith(MYSQL_ESCAPE_CHARACTER)) {
+            tableName = tableName.substring(1, tableName.length() - 1);
+        }
+        return tableName;
+    }
+
 }