|
@@ -7,8 +7,13 @@ import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO;
|
|
import cn.iocoder.yudao.module.ai.enums.AiModelEnum;
|
|
import cn.iocoder.yudao.module.ai.enums.AiModelEnum;
|
|
import io.swagger.v3.oas.annotations.Operation;
|
|
import io.swagger.v3.oas.annotations.Operation;
|
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
|
|
|
+import jakarta.servlet.http.HttpServletResponse;
|
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
import org.springframework.ai.chat.ChatClient;
|
|
import org.springframework.ai.chat.ChatClient;
|
|
|
|
+import org.springframework.ai.chat.ChatResponse;
|
|
|
|
+import org.springframework.ai.chat.prompt.Prompt;
|
|
import org.springframework.ai.openai.OpenAiChatClient;
|
|
import org.springframework.ai.openai.OpenAiChatClient;
|
|
|
|
+import org.springframework.ai.openai.api.OpenAiApi;
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
import org.springframework.context.ApplicationContext;
|
|
import org.springframework.context.ApplicationContext;
|
|
import org.springframework.validation.annotation.Validated;
|
|
import org.springframework.validation.annotation.Validated;
|
|
@@ -16,6 +21,10 @@ import org.springframework.web.bind.annotation.PostMapping;
|
|
import org.springframework.web.bind.annotation.RequestBody;
|
|
import org.springframework.web.bind.annotation.RequestBody;
|
|
import org.springframework.web.bind.annotation.RequestMapping;
|
|
import org.springframework.web.bind.annotation.RequestMapping;
|
|
import org.springframework.web.bind.annotation.RestController;
|
|
import org.springframework.web.bind.annotation.RestController;
|
|
|
|
+import reactor.core.publisher.Flux;
|
|
|
|
+
|
|
|
|
+import java.util.Scanner;
|
|
|
|
+import java.util.function.Consumer;
|
|
|
|
|
|
/**
|
|
/**
|
|
* AI模块
|
|
* AI模块
|
|
@@ -26,6 +35,7 @@ import org.springframework.web.bind.annotation.RestController;
|
|
@Tag(name = "AI模块")
|
|
@Tag(name = "AI模块")
|
|
@RestController
|
|
@RestController
|
|
@RequestMapping("/ai-api")
|
|
@RequestMapping("/ai-api")
|
|
|
|
+@Slf4j
|
|
public class ChatController {
|
|
public class ChatController {
|
|
|
|
|
|
@Autowired
|
|
@Autowired
|
|
@@ -44,6 +54,39 @@ public class ChatController {
|
|
return CommonResult.success(res);
|
|
return CommonResult.success(res);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @PostMapping("/chatStream")
|
|
|
|
+ @Operation(summary = "对话聊天chatStream", description = "简单的ai聊天")
|
|
|
|
+ public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException {
|
|
|
|
+ OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class);
|
|
|
|
+ Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getInputText()));
|
|
|
|
+ chatResponse.subscribe(new Consumer<ChatResponse>() {
|
|
|
|
+ @Override
|
|
|
|
+ public void accept(ChatResponse chatResponse) {
|
|
|
|
+ System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
|
|
|
|
+ }
|
|
|
|
+ });
|
|
|
|
+ return CommonResult.success("1");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ public static void main(String[] args) {
|
|
|
|
+ OpenAiChatClient openAiChatClient = new OpenAiChatClient(new OpenAiApi("openkey"));
|
|
|
|
+ Flux<ChatResponse> responseFlux = openAiChatClient.stream(new Prompt("最好的编程语言!"));
|
|
|
|
+ long now = System.currentTimeMillis();
|
|
|
|
+ responseFlux.subscribe(new Consumer<ChatResponse>() {
|
|
|
|
+ @Override
|
|
|
|
+ public void accept(ChatResponse chatResponse) {
|
|
|
|
+ if (chatResponse.getResults().get(0).getOutput() == null) {
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+ System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
|
|
|
|
+ }
|
|
|
|
+ });
|
|
|
|
+
|
|
|
|
+ // 阻止退出
|
|
|
|
+ Scanner scanner = new Scanner(System.in);
|
|
|
|
+ scanner.nextLine();
|
|
|
|
+ }
|
|
|
|
+
|
|
/**
|
|
/**
|
|
* 根据 ai模型 获取对于的 模型实现类
|
|
* 根据 ai模型 获取对于的 模型实现类
|
|
*
|
|
*
|