Browse Source

完善 OAuth2TokenServiceImplTest 单元测试

YunaiV 3 years ago
parent
commit
522d70a29b

+ 1 - 3
yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2AccessTokenDO.java

@@ -9,7 +9,6 @@ import com.baomidou.mybatisplus.annotation.TableName;
 import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
-import lombok.experimental.Accessors;
 
 import java.util.Date;
 import java.util.List;
@@ -22,11 +21,10 @@ import java.util.List;
  *
  * @author 芋道源码
  */
-@TableName("system_oauth2_access_token")
+@TableName(value = "system_oauth2_access_token", autoResultMap = true)
 @KeySequence("system_oauth2_access_token_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
 @Data
 @EqualsAndHashCode(callSuper = true)
-@Accessors(chain = true)
 public class OAuth2AccessTokenDO extends TenantBaseDO {
 
     /**

+ 1 - 1
yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/dataobject/oauth2/OAuth2RefreshTokenDO.java

@@ -18,7 +18,7 @@ import java.util.List;
  *
  * @author 芋道源码
  */
-@TableName("system_oauth2_refresh_token")
+@TableName(value = "system_oauth2_refresh_token", autoResultMap = true)
 // 由于 Oracle 的 SEQ 的名字长度有限制,所以就先用 system_oauth2_access_token_seq 吧,反正也没啥问题
 @KeySequence("system_oauth2_access_token_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
 @Data

+ 1 - 1
yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/dal/mysql/oauth2/OAuth2AccessTokenMapper.java

@@ -25,7 +25,7 @@ public interface OAuth2AccessTokenMapper extends BaseMapperX<OAuth2AccessTokenDO
         return selectPage(reqVO, new LambdaQueryWrapperX<OAuth2AccessTokenDO>()
                 .eqIfPresent(OAuth2AccessTokenDO::getUserId, reqVO.getUserId())
                 .eqIfPresent(OAuth2AccessTokenDO::getUserType, reqVO.getUserType())
-                .eqIfPresent(OAuth2AccessTokenDO::getClientId, reqVO.getClientId())
+                .likeIfPresent(OAuth2AccessTokenDO::getClientId, reqVO.getClientId())
                 .gt(OAuth2AccessTokenDO::getExpiresTime, new Date())
                 .orderByDesc(OAuth2AccessTokenDO::getId));
     }

+ 289 - 0
yudao-module-system/yudao-module-system-biz/src/test/java/cn/iocoder/yudao/module/system/service/oauth2/OAuth2TokenServiceImplTest.java

@@ -0,0 +1,289 @@
+package cn.iocoder.yudao.module.system.service.oauth2;
+
+import cn.hutool.core.util.RandomUtil;
+import cn.iocoder.yudao.framework.common.enums.UserTypeEnum;
+import cn.iocoder.yudao.framework.common.exception.ErrorCode;
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.common.util.date.DateUtils;
+import cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder;
+import cn.iocoder.yudao.framework.test.core.ut.BaseDbAndRedisUnitTest;
+import cn.iocoder.yudao.module.system.controller.admin.oauth2.vo.token.OAuth2AccessTokenPageReqVO;
+import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2AccessTokenDO;
+import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2ClientDO;
+import cn.iocoder.yudao.module.system.dal.dataobject.oauth2.OAuth2RefreshTokenDO;
+import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2AccessTokenMapper;
+import cn.iocoder.yudao.module.system.dal.mysql.oauth2.OAuth2RefreshTokenMapper;
+import cn.iocoder.yudao.module.system.dal.redis.oauth2.OAuth2AccessTokenRedisDAO;
+import org.assertj.core.util.Lists;
+import org.junit.jupiter.api.Test;
+import org.springframework.boot.test.mock.mockito.MockBean;
+import org.springframework.context.annotation.Import;
+
+import javax.annotation.Resource;
+import java.time.Duration;
+import java.util.Date;
+import java.util.List;
+
+import static cn.iocoder.yudao.framework.common.util.date.DateUtils.addTime;
+import static cn.iocoder.yudao.framework.common.util.object.ObjectUtils.cloneIgnoreId;
+import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertPojoEquals;
+import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertServiceException;
+import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.*;
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.when;
+
+/**
+ * {@link OAuth2TokenServiceImpl} 的单元测试类
+ *
+ * @author 芋道源码
+ */
+@Import({OAuth2TokenServiceImpl.class, OAuth2AccessTokenRedisDAO.class})
+public class OAuth2TokenServiceImplTest extends BaseDbAndRedisUnitTest {
+
+    @Resource
+    private OAuth2TokenServiceImpl oauth2TokenService;
+
+    @Resource
+    private OAuth2AccessTokenMapper oauth2AccessTokenMapper;
+    @Resource
+    private OAuth2RefreshTokenMapper oauth2RefreshTokenMapper;
+
+    @Resource
+    private OAuth2AccessTokenRedisDAO oauth2AccessTokenRedisDAO;
+
+    @MockBean
+    private OAuth2ClientService oauth2ClientService;
+
+    @Test
+    public void testCreateAccessToken() {
+        TenantContextHolder.setTenantId(0L);
+        // 准备参数
+        Long userId = randomLongId();
+        Integer userType = RandomUtil.randomEle(UserTypeEnum.values()).getValue();
+        String clientId = randomString();
+        List<String> scopes = Lists.newArrayList("read", "write");
+        // mock 方法
+        OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId)
+                .setAccessTokenValiditySeconds(30).setRefreshTokenValiditySeconds(60);
+        when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
+
+        // 调用
+        OAuth2AccessTokenDO accessTokenDO = oauth2TokenService.createAccessToken(userId, userType, clientId, scopes);
+        // 断言访问令牌
+        OAuth2AccessTokenDO dbAccessTokenDO = oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken());
+        assertPojoEquals(accessTokenDO, dbAccessTokenDO, "createTime", "updateTime", "deleted");
+        assertEquals(userId, accessTokenDO.getUserId());
+        assertEquals(userType, accessTokenDO.getUserType());
+        assertEquals(clientId, accessTokenDO.getClientId());
+        assertEquals(scopes, accessTokenDO.getScopes());
+        assertFalse(DateUtils.isExpired(accessTokenDO.getExpiresTime()));
+        // 断言访问令牌的缓存
+        OAuth2AccessTokenDO redisAccessTokenDO = oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken());
+        assertPojoEquals(accessTokenDO, redisAccessTokenDO, "createTime", "updateTime", "deleted");
+        // 断言刷新令牌
+        OAuth2RefreshTokenDO refreshTokenDO = oauth2RefreshTokenMapper.selectList().get(0);
+        assertPojoEquals(accessTokenDO, refreshTokenDO, "id", "expiresTime", "createTime", "updateTime", "deleted");
+        assertFalse(DateUtils.isExpired(refreshTokenDO.getExpiresTime()));
+    }
+
+    @Test
+    public void testRefreshAccessToken_null() {
+        // 准备参数
+        String refreshToken = randomString();
+        String clientId = randomString();
+        // mock 方法
+
+        // 调用,并断言
+        assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId),
+                new ErrorCode(400, "无效的刷新令牌"));
+    }
+
+    @Test
+    public void testRefreshAccessToken_clientIdError() {
+        // 准备参数
+        String refreshToken = randomString();
+        String clientId = randomString();
+        // mock 方法
+        OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId);
+        when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
+        // mock 数据(访问令牌)
+        OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
+                .setRefreshToken(refreshToken).setClientId("error");
+        oauth2RefreshTokenMapper.insert(refreshTokenDO);
+
+        // 调用,并断言
+        assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId),
+                new ErrorCode(400, "刷新令牌的客户端编号不正确"));
+    }
+
+    @Test
+    public void testRefreshAccessToken_expired() {
+        // 准备参数
+        String refreshToken = randomString();
+        String clientId = randomString();
+        // mock 方法
+        OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId);
+        when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
+        // mock 数据(访问令牌)
+        OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
+                .setRefreshToken(refreshToken).setClientId(clientId)
+                .setExpiresTime(addTime(Duration.ofDays(-1)));
+        oauth2RefreshTokenMapper.insert(refreshTokenDO);
+
+        // 调用,并断言
+        assertServiceException(() -> oauth2TokenService.refreshAccessToken(refreshToken, clientId),
+                new ErrorCode(401, "刷新令牌已过期"));
+    }
+
+    @Test
+    public void testRefreshAccessToken_success() {
+        TenantContextHolder.setTenantId(0L);
+        // 准备参数
+        String refreshToken = randomString();
+        String clientId = randomString();
+        // mock 方法
+        OAuth2ClientDO clientDO = randomPojo(OAuth2ClientDO.class).setClientId(clientId)
+                .setAccessTokenValiditySeconds(30);
+        when(oauth2ClientService.validOAuthClientFromCache(eq(clientId))).thenReturn(clientDO);
+        // mock 数据(访问令牌)
+        OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
+                .setRefreshToken(refreshToken).setClientId(clientId)
+                .setExpiresTime(addTime(Duration.ofDays(1)));
+        oauth2RefreshTokenMapper.insert(refreshTokenDO);
+        // mock 数据(访问令牌)
+        OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class).setRefreshToken(refreshToken);
+        oauth2AccessTokenMapper.insert(accessTokenDO);
+        oauth2AccessTokenRedisDAO.set(accessTokenDO);
+
+        // 调用
+        OAuth2AccessTokenDO newAccessTokenDO = oauth2TokenService.refreshAccessToken(refreshToken, clientId);
+        // 断言,老的访问令牌被删除
+        assertNull(oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken()));
+        assertNull(oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken()));
+        // 断言,新的访问令牌
+        OAuth2AccessTokenDO dbAccessTokenDO = oauth2AccessTokenMapper.selectByAccessToken(newAccessTokenDO.getAccessToken());
+        assertPojoEquals(newAccessTokenDO, dbAccessTokenDO, "createTime", "updateTime", "deleted");
+        assertPojoEquals(newAccessTokenDO, refreshTokenDO, "id", "expiresTime", "createTime", "updateTime", "deleted",
+                "creator", "updater");
+        assertFalse(DateUtils.isExpired(newAccessTokenDO.getExpiresTime()));
+        // 断言,新的访问令牌的缓存
+        OAuth2AccessTokenDO redisAccessTokenDO = oauth2AccessTokenRedisDAO.get(newAccessTokenDO.getAccessToken());
+        assertPojoEquals(newAccessTokenDO, redisAccessTokenDO, "createTime", "updateTime", "deleted");
+    }
+
+    @Test
+    public void testGetAccessToken() {
+        // mock 数据(访问令牌)
+        OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
+                .setExpiresTime(addTime(Duration.ofDays(1)));
+        oauth2AccessTokenMapper.insert(accessTokenDO);
+        // 准备参数
+        String accessToken = accessTokenDO.getAccessToken();
+
+        // 调用
+        OAuth2AccessTokenDO result = oauth2TokenService.getAccessToken(accessToken);
+        // 断言
+        assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted",
+                "creator", "updater");
+        assertPojoEquals(accessTokenDO, oauth2AccessTokenRedisDAO.get(accessToken), "createTime", "updateTime", "deleted",
+                "creator", "updater");
+    }
+
+    @Test
+    public void testCheckAccessToken_null() {
+        // 调研,并断言
+        assertServiceException(() -> oauth2TokenService.checkAccessToken(randomString()),
+                new ErrorCode(401, "访问令牌不存在"));
+    }
+
+    @Test
+    public void testCheckAccessToken_expired() {
+        // mock 数据(访问令牌)
+        OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
+                .setExpiresTime(addTime(Duration.ofDays(-1)));
+        oauth2AccessTokenMapper.insert(accessTokenDO);
+        // 准备参数
+        String accessToken = accessTokenDO.getAccessToken();
+
+        // 调研,并断言
+        assertServiceException(() -> oauth2TokenService.checkAccessToken(accessToken),
+                new ErrorCode(401, "访问令牌已过期"));
+    }
+
+    @Test
+    public void testCheckAccessToken_success() {
+        // mock 数据(访问令牌)
+        OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
+                .setExpiresTime(addTime(Duration.ofDays(1)));
+        oauth2AccessTokenMapper.insert(accessTokenDO);
+        // 准备参数
+        String accessToken = accessTokenDO.getAccessToken();
+
+        // 调研,并断言
+        OAuth2AccessTokenDO result = oauth2TokenService.getAccessToken(accessToken);
+        // 断言
+        assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted",
+                "creator", "updater");
+    }
+
+    @Test
+    public void testRemoveAccessToken_null() {
+        // 调用,并断言
+        assertNull(oauth2TokenService.removeAccessToken(randomString()));
+    }
+
+    @Test
+    public void testRemoveAccessToken_success() {
+        // mock 数据(访问令牌)
+        OAuth2AccessTokenDO accessTokenDO = randomPojo(OAuth2AccessTokenDO.class)
+                .setExpiresTime(addTime(Duration.ofDays(1)));
+        oauth2AccessTokenMapper.insert(accessTokenDO);
+        // mock 数据(刷新令牌)
+        OAuth2RefreshTokenDO refreshTokenDO = randomPojo(OAuth2RefreshTokenDO.class)
+                .setRefreshToken(accessTokenDO.getRefreshToken());
+        oauth2RefreshTokenMapper.insert(refreshTokenDO);
+        // 调用
+        OAuth2AccessTokenDO result = oauth2TokenService.removeAccessToken(accessTokenDO.getAccessToken());
+        assertPojoEquals(accessTokenDO, result, "createTime", "updateTime", "deleted",
+                "creator", "updater");
+        // 断言数据
+        assertNull(oauth2AccessTokenMapper.selectByAccessToken(accessTokenDO.getAccessToken()));
+        assertNull(oauth2RefreshTokenMapper.selectByRefreshToken(accessTokenDO.getRefreshToken()));
+        assertNull(oauth2AccessTokenRedisDAO.get(accessTokenDO.getAccessToken()));
+    }
+
+
+    @Test
+    public void testGetAccessTokenPage() {
+        // mock 数据
+        OAuth2AccessTokenDO dbAccessToken = randomPojo(OAuth2AccessTokenDO.class, o -> { // 等会查询到
+            o.setUserId(10L);
+            o.setUserType(1);
+            o.setClientId("test_client");
+            o.setExpiresTime(DateUtils.addTime(Duration.ofDays(1)));
+        });
+        oauth2AccessTokenMapper.insert(dbAccessToken);
+        // 测试 userId 不匹配
+        oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setUserId(20L)));
+        // 测试 userType 不匹配
+        oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setUserType(2)));
+        // 测试 userType 不匹配
+        oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setClientId("it_client")));
+        // 测试 expireTime 不匹配
+        oauth2AccessTokenMapper.insert(cloneIgnoreId(dbAccessToken, o -> o.setExpiresTime(new Date())));
+        // 准备参数
+        OAuth2AccessTokenPageReqVO reqVO = new OAuth2AccessTokenPageReqVO();
+        reqVO.setUserId(10L);
+        reqVO.setUserType(1);
+        reqVO.setClientId("test");
+
+        // 调用
+        PageResult<OAuth2AccessTokenDO> pageResult = oauth2TokenService.getAccessTokenPage(reqVO);
+        // 断言
+        assertEquals(1, pageResult.getTotal());
+        assertEquals(1, pageResult.getList().size());
+        assertPojoEquals(dbAccessToken, pageResult.getList().get(0));
+    }
+
+}

+ 2 - 0
yudao-module-system/yudao-module-system-biz/src/test/resources/sql/clean.sql

@@ -22,3 +22,5 @@ DELETE FROM "system_tenant_package";
 DELETE FROM "system_sensitive_word";
 DELETE FROM "system_oauth2_client";
 DELETE FROM "system_oauth2_approve";
+DELETE FROM "system_oauth2_access_token";
+DELETE FROM "system_oauth2_refresh_token";

+ 36 - 0
yudao-module-system/yudao-module-system-biz/src/test/resources/sql/create_tables.sql

@@ -511,3 +511,39 @@ CREATE TABLE IF NOT EXISTS "system_oauth2_approve" (
   "deleted" bit NOT NULL DEFAULT FALSE,
   PRIMARY KEY ("id")
 ) COMMENT 'OAuth2 批准表';
+
+CREATE TABLE IF NOT EXISTS "system_oauth2_access_token" (
+   "id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY,
+   "user_id" bigint NOT NULL,
+   "user_type" tinyint NOT NULL,
+   "access_token" varchar NOT NULL,
+   "refresh_token" varchar NOT NULL,
+   "client_id" varchar NOT NULL,
+   "scopes" varchar NOT NULL,
+   "approved" bit NOT NULL DEFAULT FALSE,
+   "expires_time" datetime NOT NULL,
+   "creator" varchar DEFAULT '',
+   "create_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
+   "updater" varchar DEFAULT '',
+   "update_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+   "deleted" bit NOT NULL DEFAULT FALSE,
+   "tenant_id" bigint NOT NULL,
+   PRIMARY KEY ("id")
+) COMMENT 'OAuth2 访问令牌';
+
+CREATE TABLE IF NOT EXISTS "system_oauth2_refresh_token" (
+    "id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY,
+    "user_id" bigint NOT NULL,
+    "user_type" tinyint NOT NULL,
+    "refresh_token" varchar NOT NULL,
+    "client_id" varchar NOT NULL,
+    "scopes" varchar NOT NULL,
+    "approved" bit NOT NULL DEFAULT FALSE,
+    "expires_time" datetime NOT NULL,
+    "creator" varchar DEFAULT '',
+    "create_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    "updater" varchar DEFAULT '',
+    "update_time" datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+    "deleted" bit NOT NULL DEFAULT FALSE,
+    PRIMARY KEY ("id")
+) COMMENT 'OAuth2 刷新令牌';