commit
6d10f136a0
@ -40,6 +40,7 @@ public enum AiPlatformEnum implements ArrayValuable<String> {
|
|||||||
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
|
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
|
||||||
MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney
|
MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney
|
||||||
SUNO("Suno", "Suno"), // Suno AI
|
SUNO("Suno", "Suno"), // Suno AI
|
||||||
|
GROK("Grok","Grok"), // Grok
|
||||||
|
|
||||||
;
|
;
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import cn.iocoder.yudao.module.ai.framework.ai.core.model.AiModelFactoryImpl;
|
|||||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.baichuan.BaiChuanChatModel;
|
import cn.iocoder.yudao.module.ai.framework.ai.core.model.baichuan.BaiChuanChatModel;
|
||||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.doubao.DouBaoChatModel;
|
import cn.iocoder.yudao.module.ai.framework.ai.core.model.doubao.DouBaoChatModel;
|
||||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.gemini.GeminiChatModel;
|
import cn.iocoder.yudao.module.ai.framework.ai.core.model.gemini.GeminiChatModel;
|
||||||
|
import cn.iocoder.yudao.module.ai.framework.ai.core.model.grok.GrokChatModel;
|
||||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.hunyuan.HunYuanChatModel;
|
import cn.iocoder.yudao.module.ai.framework.ai.core.model.hunyuan.HunYuanChatModel;
|
||||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
import cn.iocoder.yudao.module.ai.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||||
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
|
import cn.iocoder.yudao.module.ai.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
|
||||||
@ -17,6 +18,7 @@ import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchClient;
|
|||||||
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.bocha.AiBoChaWebSearchClient;
|
import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.bocha.AiBoChaWebSearchClient;
|
||||||
import cn.iocoder.yudao.module.ai.tool.method.PersonService;
|
import cn.iocoder.yudao.module.ai.tool.method.PersonService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
import org.springframework.ai.deepseek.DeepSeekChatModel;
|
import org.springframework.ai.deepseek.DeepSeekChatModel;
|
||||||
import org.springframework.ai.deepseek.DeepSeekChatOptions;
|
import org.springframework.ai.deepseek.DeepSeekChatOptions;
|
||||||
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
||||||
@ -40,6 +42,7 @@ import org.springframework.context.annotation.Bean;
|
|||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 芋道 AI 自动配置
|
* 芋道 AI 自动配置
|
||||||
@ -286,4 +289,25 @@ public class AiAutoConfiguration {
|
|||||||
return List.of(ToolCallbacks.from(personService));
|
return List.of(ToolCallbacks.from(personService));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ChatModel buildGrokChatClient(YudaoAiProperties.Grok properties) {
|
||||||
|
if (StrUtil.isEmpty(properties.getModel())) {
|
||||||
|
properties.setModel(GrokChatModel.MODEL_DEFAULT);
|
||||||
|
}
|
||||||
|
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||||
|
.openAiApi(OpenAiApi.builder()
|
||||||
|
.baseUrl(Optional.ofNullable(properties.getBaseUrl())
|
||||||
|
.orElse(GrokChatModel.BASE_URL))
|
||||||
|
.completionsPath(GrokChatModel.COMPLETE_PATH)
|
||||||
|
.apiKey(properties.getApiKey())
|
||||||
|
.build())
|
||||||
|
.defaultOptions(OpenAiChatOptions.builder()
|
||||||
|
.model(properties.getModel())
|
||||||
|
.temperature(properties.getTemperature())
|
||||||
|
.maxTokens(properties.getMaxTokens())
|
||||||
|
.topP(properties.getTopP())
|
||||||
|
.build())
|
||||||
|
.toolCallingManager(getToolCallingManager())
|
||||||
|
.build();
|
||||||
|
return new DouBaoChatModel(openAiChatModel);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@ -160,6 +160,20 @@ public class YudaoAiProperties {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Grok {
|
||||||
|
|
||||||
|
private String enable;
|
||||||
|
private String apiKey;
|
||||||
|
private String baseUrl;
|
||||||
|
|
||||||
|
private String model;
|
||||||
|
private Double temperature;
|
||||||
|
private Integer maxTokens;
|
||||||
|
private Double topP;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public static class WebSearch {
|
public static class WebSearch {
|
||||||
|
|
||||||
|
|||||||
@ -178,6 +178,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||||||
return buildGeminiChatModel(apiKey);
|
return buildGeminiChatModel(apiKey);
|
||||||
case OLLAMA:
|
case OLLAMA:
|
||||||
return buildOllamaChatModel(url);
|
return buildOllamaChatModel(url);
|
||||||
|
case GROK:
|
||||||
|
return buildGrokChatModel(apiKey,url);
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
}
|
}
|
||||||
@ -405,6 +407,12 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private ChatModel buildGrokChatModel(String apiKey,String url) {
|
||||||
|
YudaoAiProperties.Grok properties = new YudaoAiProperties.Grok()
|
||||||
|
.setBaseUrl(url)
|
||||||
|
.setApiKey(apiKey);
|
||||||
|
return new AiAutoConfiguration().buildGrokChatClient(properties);
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
* 可参考 {@link AiAutoConfiguration#douBaoChatClient(YudaoAiProperties)}
|
* 可参考 {@link AiAutoConfiguration#douBaoChatClient(YudaoAiProperties)}
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -0,0 +1,44 @@
|
|||||||
|
package cn.iocoder.yudao.module.ai.framework.ai.core.model.grok;
|
||||||
|
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.ai.chat.model.ChatModel;
|
||||||
|
import org.springframework.ai.chat.model.ChatResponse;
|
||||||
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||||
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Grok {@link ChatModel} 实现类
|
||||||
|
*
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class GrokChatModel implements ChatModel {
|
||||||
|
|
||||||
|
public static final String BASE_URL = "https://api.x.ai";
|
||||||
|
public static final String COMPLETE_PATH = "/v1/chat/completions";
|
||||||
|
public static final String MODEL_DEFAULT = "grok-4-fast-reasoning";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 兼容 OpenAI 接口,进行复用
|
||||||
|
*/
|
||||||
|
private final ChatModel openAiChatModel;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChatResponse call(Prompt prompt) {
|
||||||
|
return openAiChatModel.call(prompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||||
|
return openAiChatModel.stream(prompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChatOptions getDefaultOptions() {
|
||||||
|
return openAiChatModel.getDefaultOptions();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -79,6 +79,9 @@ public class AiUtils {
|
|||||||
case OLLAMA:
|
case OLLAMA:
|
||||||
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
|
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
|
||||||
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
|
case GROK:
|
||||||
|
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||||
|
.toolCallbacks(toolCallbacks).toolContext(toolContext).build();
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user