调整对应的调用

This commit is contained in:
UnCLAS-Prommer
2025-07-30 17:07:55 +08:00
parent 3c40ceda4c
commit 6c0edd0ad7
40 changed files with 580 additions and 1236 deletions

View File

@@ -12,6 +12,7 @@ import traceback
from typing import Tuple, Any, Dict, List, Optional
from rich.traceback import install
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
@@ -31,7 +32,7 @@ logger = get_logger("generator_api")
def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
"""获取回复器对象
@@ -58,7 +59,7 @@ def get_replyer(
return replyer_manager.get_replyer(
chat_stream=chat_stream,
chat_id=chat_id,
model_configs=model_configs,
model_set_with_weight=model_set_with_weight,
request_type=request_type,
)
except Exception as e:
@@ -83,7 +84,7 @@ async def generate_reply(
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
return_prompt: bool = False,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复
@@ -106,7 +107,7 @@ async def generate_reply(
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
@@ -154,7 +155,7 @@ async def rewrite_reply(
chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
@@ -179,7 +180,7 @@ async def rewrite_reply(
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
@@ -245,17 +246,17 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
async def generate_response_custom(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
prompt: str = "",
) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return None
try:
logger.debug("[GeneratorAPI] 开始生成自定义回复")
response = await replyer.llm_generate_content(prompt)
response, _, _, _ = await replyer.llm_generate_content(prompt)
if response:
logger.debug("[GeneratorAPI] 自定义回复生成成功")
return response

View File

@@ -7,10 +7,11 @@
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
"""
from typing import Tuple, Dict, Any
from typing import Tuple, Dict
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.config.api_ada_configs import TaskConfig
logger = get_logger("llm_api")
@@ -19,9 +20,7 @@ logger = get_logger("llm_api")
# =============================================================================
def get_available_models() -> Dict[str, Any]:
def get_available_models() -> Dict[str, TaskConfig]:
"""获取所有可用的模型配置
Returns:
@@ -33,14 +32,14 @@ def get_available_models() -> Dict[str, Any]:
return {}
# 自动获取所有属性并转换为字典形式
rets = {}
models = global_config.model
models = model_config.model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
for attr in attrs:
if not attr.startswith("__"):
try:
value = getattr(models, attr)
if not callable(value): # 排除方法
if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value
except Exception as e:
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
@@ -53,8 +52,8 @@ def get_available_models() -> Dict[str, Any]:
async def generate_with_model(
prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str]:
prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
Args:
@@ -67,17 +66,16 @@ async def generate_with_model(
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
model_name = model_config.get("name")
logger.info(f"[LLMAPI] 使用模型 {model_name} 生成内容")
model_name_list = model_config.model_list
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
llm_request = LLMRequest(model_set=model_config, request_type=request_type, **kwargs)
# TODO: 复活这个_
response, _ = await llm_request.generate_response_async(prompt)
return True, response
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt)
return True, response, reasoning_content, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg
return False, error_msg, "", ""

View File

@@ -335,7 +335,7 @@ async def command_to_stream(
async def custom_to_stream(
message_type: str,
content: str,
content: str | dict,
stream_id: str,
display_message: str = "",
typing: bool = False,

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import process_llm_tool_calls
from src.chat.message_receive.chat_stream import get_chat_manager
@@ -52,10 +52,7 @@ class ToolExecutor:
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(
model=global_config.model.tool_use,
request_type="tool_executor",
)
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 缓存配置
self.enable_cache = enable_cache
@@ -137,7 +134,7 @@ class ToolExecutor:
return tool_results, used_tools, prompt
else:
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)