|
@@ -1,145 +1,162 @@
|
|
|
package cn.iocoder.yudao.framework.ai.core.model.xinghuo;
|
|
|
|
|
|
-import cn.hutool.core.bean.BeanUtil;
|
|
|
-import cn.hutool.core.exceptions.ExceptionUtil;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletion;
|
|
|
-import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletionRequest;
|
|
|
+import cn.hutool.core.lang.Assert;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
|
|
import org.springframework.ai.chat.model.ChatModel;
|
|
|
import org.springframework.ai.chat.model.ChatResponse;
|
|
|
import org.springframework.ai.chat.model.Generation;
|
|
|
-import org.springframework.ai.chat.model.StreamingChatModel;
|
|
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
+import org.springframework.ai.model.ModelOptionsUtils;
|
|
|
+import org.springframework.ai.openai.OpenAiChatOptions;
|
|
|
+import org.springframework.ai.openai.api.OpenAiApi;
|
|
|
+import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
|
|
|
+import org.springframework.ai.retry.RetryUtils;
|
|
|
import org.springframework.http.ResponseEntity;
|
|
|
-import org.springframework.retry.RetryCallback;
|
|
|
-import org.springframework.retry.RetryContext;
|
|
|
-import org.springframework.retry.RetryListener;
|
|
|
import org.springframework.retry.support.RetryTemplate;
|
|
|
import reactor.core.publisher.Flux;
|
|
|
|
|
|
-import java.time.Duration;
|
|
|
+import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
-import java.util.stream.Collectors;
|
|
|
+import java.util.Map;
|
|
|
+
|
|
|
+import static cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatOptions.MODEL_DEFAULT;
|
|
|
|
|
|
-// TODO @fan:参考 yiyan 的修改建议,调整下 xinghuo 的实现;可以等 yiyan 修改完建议,然后我 review 完,再改这个哈;
|
|
|
/**
|
|
|
- * 讯飞星火 client
|
|
|
- * <p>
|
|
|
- * author: fansili
|
|
|
- * time: 2024/3/11 10:19
|
|
|
+ * 讯飞星火 {@link ChatModel} 实现类
|
|
|
+ *
|
|
|
+ * @author fansili
|
|
|
*/
|
|
|
@Slf4j
|
|
|
-public class XingHuoChatClient implements ChatModel, StreamingChatModel {
|
|
|
-
|
|
|
- private XingHuoApi xingHuoApi;
|
|
|
-
|
|
|
- private XingHuoOptions xingHuoOptions;
|
|
|
-
|
|
|
- public final RetryTemplate retryTemplate = RetryTemplate.builder()
|
|
|
- // 最大重试次数 10
|
|
|
- .maxAttempts(3)
|
|
|
- .retryOn(ChatException.class)
|
|
|
- // 最大重试5次,第一次间隔3000ms,第二次3000ms * 2,第三次3000ms * 3,以此类推,最大间隔3 * 60000ms
|
|
|
- .exponentialBackoff(Duration.ofMillis(3000), 2, Duration.ofMillis(3 * 60000))
|
|
|
- .withListener(new RetryListener() {
|
|
|
- @Override
|
|
|
- public <T extends Object, E extends Throwable> void onError(RetryContext context,
|
|
|
- RetryCallback<T, E> callback, Throwable throwable) {
|
|
|
- System.err.println("正在重试... " + ExceptionUtil.getMessage(throwable));
|
|
|
- log.warn("重试异常:" + context.getRetryCount(), throwable);
|
|
|
- }
|
|
|
+public class XingHuoChatClient implements ChatModel {
|
|
|
+
|
|
|
+ private static final String BASE_URL = "https://spark-api-open.xf-yun.com";
|
|
|
|
|
|
- ;
|
|
|
- })
|
|
|
- .build();
|
|
|
+ private final XingHuoChatOptions defaultOptions;
|
|
|
+ private final RetryTemplate retryTemplate;
|
|
|
|
|
|
- public XingHuoChatClient(XingHuoApi xingHuoApi) {
|
|
|
- this.xingHuoApi = xingHuoApi;
|
|
|
+ /**
|
|
|
+ * 星火兼容 OpenAI 的 HTTP 接口,所以复用它的实现,简化接入成本
|
|
|
+ *
|
|
|
+ * 不过要注意,星火没有完全兼容,所以不能使用 {@link org.springframework.ai.openai.OpenAiChatModel} 调用,但是实现会参考它
|
|
|
+ */
|
|
|
+ private final OpenAiApi openAiApi;
|
|
|
+
|
|
|
+ public XingHuoChatClient(String apiKey, String secretKey) {
|
|
|
+ this(apiKey, secretKey,
|
|
|
+ XingHuoChatOptions.builder().model(MODEL_DEFAULT).temperature(0.7F).build());
|
|
|
}
|
|
|
|
|
|
- public XingHuoChatClient(XingHuoApi xingHuoApi, XingHuoOptions xingHuoOptions) {
|
|
|
- this.xingHuoApi = xingHuoApi;
|
|
|
- this.xingHuoOptions = xingHuoOptions;
|
|
|
+ public XingHuoChatClient(String apiKey, String secretKey, XingHuoChatOptions options) {
|
|
|
+ this(apiKey, secretKey, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
|
|
+ }
|
|
|
+
|
|
|
+ public XingHuoChatClient(String apiKey, String secretKey, XingHuoChatOptions options, RetryTemplate retryTemplate) {
|
|
|
+ Assert.notEmpty(apiKey, "apiKey 不能为空");
|
|
|
+ Assert.notEmpty(secretKey, "secretKey 不能为空");
|
|
|
+ Assert.notNull(options, "options 不能为空");
|
|
|
+ Assert.notNull(retryTemplate, "retryTemplate 不能为空");
|
|
|
+ this.openAiApi = new OpenAiApi(BASE_URL, apiKey + ":" + secretKey);
|
|
|
+ this.defaultOptions = options;
|
|
|
+ this.retryTemplate = retryTemplate;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public ChatResponse call(Prompt prompt) {
|
|
|
+ OpenAiApi.ChatCompletionRequest request = createRequest(prompt, false);
|
|
|
return this.retryTemplate.execute(ctx -> {
|
|
|
- // ctx 会有重试的信息
|
|
|
- // 获取 chatOptions 属性
|
|
|
- XingHuoOptions chatOptions = this.getChatOptions(prompt);
|
|
|
- // 创建 request 请求,stream模式需要供应商支持
|
|
|
- XingHuoChatCompletionRequest request = this.createRequest(prompt, chatOptions);
|
|
|
- // 调用 callWithFunctionSupport 发送请求
|
|
|
- ResponseEntity<XingHuoChatCompletion> response = xingHuoApi.chatCompletionEntity(request, chatOptions.getChatModel());
|
|
|
- // 获取结果封装 ChatResponse
|
|
|
- return new ChatResponse(List.of(new Generation(response.getBody().getPayload().getChoices().getText().get(0).getContent())));
|
|
|
+ // 1.1 发起调用
|
|
|
+ ResponseEntity<OpenAiApi.ChatCompletion> completionEntity = openAiApi.chatCompletionEntity(request);
|
|
|
+ // 1.2 校验结果
|
|
|
+ OpenAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
|
|
|
+ if (chatCompletion == null) {
|
|
|
+ log.warn("No chat completion returned for prompt: {}", prompt);
|
|
|
+ return new ChatResponse(List.of());
|
|
|
+ }
|
|
|
+ List<OpenAiApi.ChatCompletion.Choice> choices = chatCompletion.choices();
|
|
|
+ if (choices == null) {
|
|
|
+ log.warn("No choices returned for prompt: {}", prompt);
|
|
|
+ return new ChatResponse(List.of());
|
|
|
+ }
|
|
|
+
|
|
|
+ // 2. 转换 ChatResponse 返回
|
|
|
+ List<Generation> generations = choices.stream().map(choice -> {
|
|
|
+ Generation generation = new Generation(choice.message().content(), toMap(chatCompletion.id(), choice));
|
|
|
+ if (choice.finishReason() != null) {
|
|
|
+ generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
|
|
|
+ }
|
|
|
+ return generation;
|
|
|
+ }).toList();
|
|
|
+ return new ChatResponse(generations,
|
|
|
+ OpenAiChatResponseMetadata.from(completionEntity.getBody()));
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
- public ChatOptions getDefaultOptions() {
|
|
|
- // TODO 芋艿:需要跟进下
|
|
|
- throw new UnsupportedOperationException();
|
|
|
+ private Map<String, Object> toMap(String id, OpenAiApi.ChatCompletion.Choice choice) {
|
|
|
+ Map<String, Object> map = new HashMap<>();
|
|
|
+ OpenAiApi.ChatCompletionMessage message = choice.message();
|
|
|
+ if (message.role() != null) {
|
|
|
+ map.put("role", message.role().name());
|
|
|
+ }
|
|
|
+ if (choice.finishReason() != null) {
|
|
|
+ map.put("finishReason", choice.finishReason().name());
|
|
|
+ }
|
|
|
+ map.put("id", id);
|
|
|
+ return map;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public Flux<ChatResponse> stream(Prompt prompt) {
|
|
|
- // 获取 chatOptions 属性
|
|
|
- XingHuoOptions chatOptions = this.getChatOptions(prompt);
|
|
|
- // 创建 request 请求,stream模式需要供应商支持
|
|
|
- XingHuoChatCompletionRequest request = this.createRequest(prompt, chatOptions);
|
|
|
- // 发送请求
|
|
|
- Flux<XingHuoChatCompletion> response = this.xingHuoApi.chatCompletionStream(request, chatOptions.getChatModel());
|
|
|
- return response.map(res -> {
|
|
|
- String content = res.getPayload().getChoices().getText().stream()
|
|
|
- .map(item -> item.getContent()).collect(Collectors.joining());
|
|
|
- return new ChatResponse(List.of(new Generation(content)));
|
|
|
+ OpenAiApi.ChatCompletionRequest request = createRequest(prompt, true);
|
|
|
+ return this.retryTemplate.execute(ctx -> {
|
|
|
+ // 1. 发起调用
|
|
|
+ Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);
|
|
|
+ return response.map(chatCompletion -> {
|
|
|
+ String id = chatCompletion.id();
|
|
|
+ // 2. 转换 ChatResponse 返回
|
|
|
+ List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
|
|
|
+ String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
|
|
|
+ Generation generation = new Generation(choice.delta().content(),
|
|
|
+ Map.of("id", id, "role", choice.delta().role().name(), "finishReason", finish));
|
|
|
+ if (choice.finishReason() != null) {
|
|
|
+ generation = generation.withGenerationMetadata(
|
|
|
+ ChatGenerationMetadata.from(choice.finishReason().name(), null));
|
|
|
+ }
|
|
|
+ return generation;
|
|
|
+ }).toList();
|
|
|
+ return new ChatResponse(generations);
|
|
|
+ });
|
|
|
});
|
|
|
}
|
|
|
|
|
|
- private XingHuoOptions getChatOptions(Prompt prompt) {
|
|
|
- // 两个都为null 则没有配置文件
|
|
|
- if (xingHuoOptions == null && prompt.getOptions() == null) {
|
|
|
- throw new ChatException("ChatOptions 未配置参数!");
|
|
|
- }
|
|
|
- // 优先使用 Prompt 里面的 ChatOptions
|
|
|
- ChatOptions options = xingHuoOptions;
|
|
|
+ OpenAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
|
|
|
+ // 1. 构建 ChatCompletionMessage 对象
|
|
|
+ List<OpenAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m ->
|
|
|
+ new OpenAiApi.ChatCompletionMessage(m.getContent(), OpenAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()))).toList();
|
|
|
+ OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
|
|
|
+
|
|
|
+ // 2.1 补充 prompt 内置的 options
|
|
|
if (prompt.getOptions() != null) {
|
|
|
- options = (ChatOptions) prompt.getOptions();
|
|
|
+ if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
|
|
|
+ OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
|
|
|
+ ChatOptions.class, OpenAiChatOptions.class);
|
|
|
+ request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.ChatCompletionRequest.class);
|
|
|
+ } else {
|
|
|
+ throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
|
|
|
+ + prompt.getOptions().getClass().getSimpleName());
|
|
|
+ }
|
|
|
}
|
|
|
- // Prompt 里面是一个 ChatOptions,用户可以随意传入,这里做一下判断
|
|
|
- if (!(options instanceof XingHuoOptions)) {
|
|
|
- throw new ChatException("Prompt 传入的不是 XingHuoOptions!");
|
|
|
+ // 2.2 补充默认 options
|
|
|
+ if (this.defaultOptions != null) {
|
|
|
+ request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.ChatCompletionRequest.class);
|
|
|
}
|
|
|
- return (XingHuoOptions) options;
|
|
|
+ return request;
|
|
|
}
|
|
|
|
|
|
- private XingHuoChatCompletionRequest createRequest(Prompt prompt, XingHuoOptions xingHuoOptions) {
|
|
|
- // 创建 header
|
|
|
- XingHuoChatCompletionRequest.Header header = new XingHuoChatCompletionRequest.Header().setApp_id(xingHuoApi.getAppId());
|
|
|
- // 创建 params
|
|
|
- XingHuoChatCompletionRequest.Parameter.Chat chatParameter = new XingHuoChatCompletionRequest.Parameter.Chat();
|
|
|
- BeanUtil.copyProperties(xingHuoOptions, chatParameter);
|
|
|
- chatParameter.setDomain(xingHuoOptions.getChatModel().getModel());
|
|
|
- XingHuoChatCompletionRequest.Parameter parameter = new XingHuoChatCompletionRequest.Parameter().setChat(chatParameter);
|
|
|
- // 创建 payload text 信息
|
|
|
- List<XingHuoChatCompletionRequest.Payload.Message.Text> texts = prompt.getInstructions().stream().map(message -> {
|
|
|
- XingHuoChatCompletionRequest.Payload.Message.Text text = new XingHuoChatCompletionRequest.Payload.Message.Text();
|
|
|
- text.setContent(message.getContent());
|
|
|
- text.setRole(message.getMessageType().getValue());
|
|
|
- return text;
|
|
|
- }).collect(Collectors.toList());
|
|
|
- // 创建 payload
|
|
|
- XingHuoChatCompletionRequest.Payload payload = new XingHuoChatCompletionRequest.Payload()
|
|
|
- .setMessage(new XingHuoChatCompletionRequest.Payload.Message().setText(texts));
|
|
|
- // 创建 request
|
|
|
- return new XingHuoChatCompletionRequest()
|
|
|
- .setHeader(header)
|
|
|
- .setParameter(parameter)
|
|
|
- .setPayload(payload);
|
|
|
+ @Override
|
|
|
+ public ChatOptions getDefaultOptions() {
|
|
|
+ return XingHuoChatOptions.fromOptions(defaultOptions);
|
|
|
}
|
|
|
+
|
|
|
}
|