diff --git a/docs/model_configuration_guide.md b/docs/model_configuration_guide.md index 6bbe05aff..d5afbd296 100644 --- a/docs/model_configuration_guide.md +++ b/docs/model_configuration_guide.md @@ -48,6 +48,7 @@ retry_interval = 10 # 重试间隔(秒) | `timeout` | ❌ | API请求超时时间(秒) | 30 | | `retry_interval` | ❌ | 重试间隔时间(秒) | 10 | +**请注意,对于`client_type`为`gemini`的模型,`base_url`字段无效。** ### 2.3 支持的服务商示例 #### DeepSeek @@ -132,6 +133,7 @@ thinking = {type = "disabled"} # 禁用思考 ``` 请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。 +**请注意,对于`client_type`为`gemini`的模型,此字段无效。** ### 3.3 配置参数说明 | 参数 | 必填 | 说明 | diff --git a/docs/plugins/api/llm-api.md b/docs/plugins/api/llm-api.md index 9a266933b..d35ea68b6 100644 --- a/docs/plugins/api/llm-api.md +++ b/docs/plugins/api/llm-api.md @@ -24,7 +24,11 @@ def get_available_models() -> Dict[str, TaskConfig]: ### 2. 使用模型生成内容 ```python async def generate_with_model( - prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs + prompt: str, + model_config: TaskConfig, + request_type: str = "plugin.generate", + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, ) -> Tuple[bool, str, str, str]: ``` 使用指定模型生成内容。 @@ -33,7 +37,29 @@ async def generate_with_model( - `prompt`:提示词。 - `model_config`:模型配置对象(从 `get_available_models` 获取)。 - `request_type`:请求类型标识,默认为 `"plugin.generate"`。 -- `**kwargs`:其他模型特定参数,如 `temperature`、`max_tokens` 等。 +- `temperature`:生成内容的温度设置,影响输出的随机性。 +- `max_tokens`:生成内容的最大token数。 **Return:** -- `Tuple[bool, str, str, str]`:返回一个元组,包含(是否成功, 生成的内容, 推理过程, 模型名称)。 \ No newline at end of file +- `Tuple[bool, str, str, str]`:返回一个元组,包含(是否成功, 生成的内容, 推理过程, 模型名称)。 + +### 3. 有Tool情况下使用模型生成内容 +```python +async def generate_with_model_with_tools( + prompt: str, + model_config: TaskConfig, + tool_options: List[Dict[str, Any]] | None = None, + request_type: str = "plugin.generate", + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, +) -> Tuple[bool, str, str, str, List[ToolCall] | None]: +``` +使用指定模型生成内容,并支持工具调用。 + +**Args:** +- `prompt`:提示词。 +- `model_config`:模型配置对象(从 `get_available_models` 获取)。 +- `tool_options`:工具选项列表,包含可用工具的配置,字典为每一个工具的定义,参见[tool-components.md](../tool-components.md#属性说明),可用`tool_api.get_llm_available_tool_definitions()`获取并选择。 +- `request_type`:请求类型标识,默认为 `"plugin.generate"`。 +- `temperature`:生成内容的温度设置,影响输出的随机性。 +- `max_tokens`:生成内容的最大token数。 \ No newline at end of file diff --git a/docs/plugins/api/tool-api.md b/docs/plugins/api/tool-api.md index d86734fcd..bd6e7d2ef 100644 --- a/docs/plugins/api/tool-api.md +++ b/docs/plugins/api/tool-api.md @@ -36,7 +36,7 @@ def get_llm_available_tool_definitions(): **Returns**: - `List[Tuple[str, Dict[str, Any]]]`: 工具定义列表,每个元素为 `(工具名称, 工具定义字典)` 的元组 - - 其具体定义请参照[tool-components.md](../tool-components.md)中的工具定义格式。 + - 其具体定义请参照[tool-components.md](../tool-components.md#属性说明)中的工具定义格式。 #### 示例: ```python diff --git a/docs/plugins/tool-components.md b/docs/plugins/tool-components.md index 059656aa4..b9dc35704 100644 --- a/docs/plugins/tool-components.md +++ b/docs/plugins/tool-components.md @@ -78,7 +78,7 @@ class MyTool(BaseTool): 其构造而成的工具定义为: ```python -{"name": cls.name, "description": cls.description, "parameters": cls.parameters} +definition: Dict[str, Any] = {"name": cls.name, "description": cls.description, "parameters": cls.parameters} ``` ### 方法说明 diff --git a/requirements.txt b/requirements.txt index a09637a91..999bd5fd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ matplotlib networkx numpy openai +google-genai pandas peewee pyarrow diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 63a4d9852..1177650d4 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -24,46 +24,6 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") logger = get_logger("OpenIE导入") -ENV_FILE = os.path.join(ROOT_PATH, ".env") - -if os.path.exists(".env"): - load_dotenv(".env", override=True) - print("成功加载环境变量配置") -else: - print("未找到.env文件,请确保程序所需的环境变量被正确设置") - raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") - -env_mask = {key: os.getenv(key) for key in os.environ} -def scan_provider(env_config: dict): - provider = {} - - # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 - # 避免 GPG_KEY 这样的变量干扰检查 - env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) - - # 遍历 env_config 的所有键 - for key in env_config: - # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 - if key.endswith("_BASE_URL") or key.endswith("_KEY"): - # 提取 provider 名称 - provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 - - # 初始化 provider 的字典(如果尚未初始化) - if provider_name not in provider: - provider[provider_name] = {"url": None, "key": None} - - # 根据键的类型填充 url 或 key - if key.endswith("_BASE_URL"): - provider[provider_name]["url"] = env_config[key] - elif key.endswith("_KEY"): - provider[provider_name]["key"] = env_config[key] - - # 检查每个 provider 是否同时存在 url 和 key - for provider_name, config in provider.items(): - if config["url"] is None or config["key"] is None: - logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") - raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") - def ensure_openie_dir(): """确保OpenIE数据目录存在""" if not os.path.exists(OPENIE_DIR): @@ -214,8 +174,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k def main(): # sourcery skip: dict-comprehension # 新增确认提示 - env_config = {key: os.getenv(key) for key in os.environ} - scan_provider(env_config) print("=== 重要操作确认 ===") print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快") diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index cb545a44d..47ad55a8b 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -27,7 +27,6 @@ from rich.progress import ( from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from dotenv import load_dotenv logger = get_logger("LPMM知识库-信息提取") @@ -36,45 +35,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) TEMP_DIR = os.path.join(ROOT_PATH, "temp") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") -ENV_FILE = os.path.join(ROOT_PATH, ".env") - -if os.path.exists(".env"): - load_dotenv(".env", override=True) - print("成功加载环境变量配置") -else: - print("未找到.env文件,请确保程序所需的环境变量被正确设置") - raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") - -env_mask = {key: os.getenv(key) for key in os.environ} -def scan_provider(env_config: dict): - provider = {} - - # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 - # 避免 GPG_KEY 这样的变量干扰检查 - env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) - - # 遍历 env_config 的所有键 - for key in env_config: - # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 - if key.endswith("_BASE_URL") or key.endswith("_KEY"): - # 提取 provider 名称 - provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 - - # 初始化 provider 的字典(如果尚未初始化) - if provider_name not in provider: - provider[provider_name] = {"url": None, "key": None} - - # 根据键的类型填充 url 或 key - if key.endswith("_BASE_URL"): - provider[provider_name]["url"] = env_config[key] - elif key.endswith("_KEY"): - provider[provider_name]["key"] = env_config[key] - - # 检查每个 provider 是否同时存在 url 和 key - for provider_name, config in provider.items(): - if config["url"] is None or config["key"] is None: - logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") - raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") def ensure_dirs(): """确保临时目录和输出目录存在""" @@ -158,8 +118,6 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method # 设置信号处理器 signal.signal(signal.SIGINT, signal_handler) ensure_dirs() # 确保目录存在 - env_config = {key: os.getenv(key) for key in os.environ} - scan_provider(env_config) # 新增用户确认提示 print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。") diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 447ef8e7e..d0f6e7744 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -12,8 +12,6 @@ import pandas as pd # import tqdm import faiss -# from .llm_client import LLMClient -# from .lpmmconfig import global_config from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install diff --git a/src/chat/knowledge/llm_client.py b/src/chat/knowledge/llm_client.py deleted file mode 100644 index 52d0dca06..000000000 --- a/src/chat/knowledge/llm_client.py +++ /dev/null @@ -1,45 +0,0 @@ -from openai import OpenAI - - -class LLMMessage: - def __init__(self, role, content): - self.role = role - self.content = content - - def to_dict(self): - return {"role": self.role, "content": self.content} - - -class LLMClient: - """LLM客户端,对应一个API服务商""" - - def __init__(self, url, api_key): - self.client = OpenAI( - base_url=url, - api_key=api_key, - ) - - def send_chat_request(self, model, messages): - """发送对话请求,等待返回结果""" - response = self.client.chat.completions.create(model=model, messages=messages, stream=False) - if hasattr(response.choices[0].message, "reasoning_content"): - # 有单独的推理内容块 - reasoning_content = response.choices[0].message.reasoning_content - content = response.choices[0].message.content - else: - # 无单独的推理内容块 - response = response.choices[0].message.content.split("")[-1].split("") - # 如果有推理内容,则分割推理内容和内容 - if len(response) == 2: - reasoning_content = response[0] - content = response[1] - else: - reasoning_content = None - content = response[0] - - return reasoning_content, content - - def send_embedding_request(self, model, text): - """发送嵌入请求,等待返回结果""" - text = text.replace("\n", " ") - return self.client.embeddings.create(input=[text], model=model).data[0].embedding diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index 1a47767cb..5354447af 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -2,11 +2,7 @@ import time from typing import Tuple, List, Dict, Optional from .global_logger import logger - -# from . import prompt_template from .embedding_store import EmbeddingManager - -# from .llm_client import LLMClient from .kg_manager import KGManager # from .lpmmconfig import global_config diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 04e17ad6e..85dd5e637 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -36,8 +36,6 @@ def init_prompt(): {chat_context_description},以下是具体的聊天内容 {chat_content_block} - - {moderation_prompt} 现在请你根据{by_what}选择合适的action和触发action的消息: diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 3c8a54922..c2b6e1cb9 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -24,13 +24,13 @@ from src.chat.utils.chat_message_builder import ( replace_user_references_sync, ) from src.chat.express.expression_selector import expression_selector -from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.memory_system.memory_activator import MemoryActivator from src.chat.memory_system.instant_memory import InstantMemory from src.mood.mood_manager import mood_manager from src.person_info.relationship_fetcher import relationship_fetcher_manager from src.person_info.person_info import get_person_info_manager from src.plugin_system.base.component_types import ActionInfo +from src.plugin_system.apis import llm_api logger = get_logger("replyer") @@ -102,6 +102,22 @@ def init_prompt(): "s4u_style_prompt", ) + Prompt( + """ +你是一个专门获取知识的助手。你的名字是{bot_name}。现在是{time_now}。 +群里正在进行的聊天内容: +{chat_history} + +现在,{sender}发送了内容:{target_message},你想要回复ta。 +请仔细分析聊天内容,考虑以下几点: +1. 内容中是否包含需要查询信息的问题 +2. 是否有明确的知识获取指令 + +If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed". +""", + name="lpmm_get_knowledge_prompt", + ) + class DefaultReplyer: def __init__( @@ -698,7 +714,7 @@ class DefaultReplyer: self._time_and_run_task( self.build_tool_info(chat_talking_prompt_short, reply_to, enable_tool=enable_tool), "tool_info" ), - self._time_and_run_task(get_prompt_info(target, threshold=0.38), "prompt_info"), + self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, reply_to), "prompt_info"), ) # 任务名称中英文映射 @@ -1000,6 +1016,63 @@ class DefaultReplyer: logger.debug(f"replyer生成内容: {content}") return content, reasoning_content, model_name, tool_calls + async def get_prompt_info(self, message: str, reply_to: str): + related_info = "" + start_time = time.time() + from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool + if not reply_to: + logger.debug("没有回复对象,跳过获取知识库内容") + return "" + sender, content = self._parse_reply_target(reply_to) + if not content: + logger.debug("回复对象内容为空,跳过获取知识库内容") + return "" + logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") + # 从LPMM知识库获取知识 + try: + # 检查LPMM知识库是否启用 + if not global_config.lpmm_knowledge.enable: + logger.debug("LPMM知识库未启用,跳过获取知识库内容") + return "" + time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + + bot_name = global_config.bot.nickname + + prompt = await global_prompt_manager.format_prompt( + "lpmm_get_knowledge_prompt", + bot_name=bot_name, + time_now=time_now, + chat_history=message, + sender=sender, + target_message=content, + ) + _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools( + prompt, + model_config=model_config.model_task_config.tool_use, + tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()], + ) + if tool_calls: + result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool()) + end_time = time.time() + if not result or not result.get("content"): + logger.debug("从LPMM知识库获取知识失败,返回空知识...") + return "" + found_knowledge_from_lpmm = result.get("content", "") + logger.debug( + f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}" + ) + related_info += found_knowledge_from_lpmm + logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") + logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") + + return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" + else: + logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...") + return "" + except Exception as e: + logger.error(f"获取知识库内容时发生异常: {str(e)}") + return "" + def weighted_sample_no_replacement(items, weights, k) -> list: """ @@ -1035,36 +1108,4 @@ def weighted_sample_no_replacement(items, weights, k) -> list: return selected -async def get_prompt_info(message: str, threshold: float): - related_info = "" - start_time = time.time() - - logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - # 从LPMM知识库获取知识 - try: - # 检查LPMM知识库是否启用 - if qa_manager is None: - logger.debug("LPMM知识库已禁用,跳过知识获取") - return "" - - found_knowledge_from_lpmm = await qa_manager.get_knowledge(message) - - end_time = time.time() - if found_knowledge_from_lpmm is not None: - logger.debug( - f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}" - ) - related_info += found_knowledge_from_lpmm - logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") - logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") - - return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" - else: - logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...") - return "" - except Exception as e: - logger.error(f"获取知识库内容时发生异常: {str(e)}") - return "" - - init_prompt() diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 1d0b8a397..d2b3acce7 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -281,20 +281,6 @@ class Memory(BaseModel): table_name = "memory" -class Knowledges(BaseModel): - """ - 用于存储知识库条目的模型。 - """ - - content = TextField() # 知识内容的文本 - embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 - # 可以添加其他元数据字段,如 source, create_time 等 - - class Meta: - # database = db # 继承自 BaseModel - table_name = "knowledges" - - class Expression(BaseModel): """ 用于存储表达风格的模型。 @@ -382,7 +368,6 @@ def create_tables(): ImageDescriptions, OnlineTime, PersonInfo, - Knowledges, Expression, ThinkingLog, GraphNodes, # 添加图节点表 @@ -408,7 +393,6 @@ def initialize_database(): ImageDescriptions, OnlineTime, PersonInfo, - Knowledges, Expression, Memory, ThinkingLog, diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 5f3398e0e..9692aced3 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -35,7 +35,7 @@ class APIProvider(ConfigBase): """确保api_key在repr中不被显示""" if not self.api_key: raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。") - if not self.base_url: + if not self.base_url and self.client_type != "gemini": raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。") if not self.name: raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。") diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index d00ae8b55..a74b466f1 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -16,6 +16,9 @@ from google.genai.types import ( GenerateContentConfig, EmbedContentResponse, EmbedContentConfig, + SafetySetting, + HarmCategory, + HarmBlockThreshold, ) from google.genai.errors import ( ClientError, @@ -41,6 +44,14 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall logger = get_logger("Gemini客户端") +gemini_safe_settings = [ + SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE), + SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE), + SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE), + SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE), + SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE), +] + def _convert_messages( messages: list[Message], @@ -322,7 +333,7 @@ class GeminiClient(BaseClient): message_list: list[Message], tool_options: list[ToolOption] | None = None, max_tokens: int = 1024, - temperature: float = 0.7, + temperature: float = 0.4, response_format: RespFormat | None = None, stream_response_handler: Optional[ Callable[ @@ -369,9 +380,12 @@ class GeminiClient(BaseClient): "thinking_config": ThinkingConfig( include_thoughts=True, thinking_budget=( - extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else None + extra_params["thinking_budget"] + if extra_params and "thinking_budget" in extra_params + else int(max_tokens / 2) # 默认思考预算为最大token数的一半,防止空回复 ), ), + "safety_settings": gemini_safe_settings, # 防止空回复问题 } if tools: generation_config_dict["tools"] = Tool(function_declarations=tools) @@ -483,6 +497,61 @@ class GeminiClient(BaseClient): return response + def get_audio_transcriptions( + self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None + ) -> APIResponse: + """ + 获取音频转录 + :param model_info: 模型信息 + :param audio_base64: 音频文件的Base64编码字符串 + :param extra_params: 额外参数(可选) + :return: 转录响应 + """ + generation_config_dict = { + "max_output_tokens": 2048, + "response_modalities": ["TEXT"], + "thinking_config": ThinkingConfig( + include_thoughts=True, + thinking_budget=( + extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024 + ), + ), + "safety_settings": gemini_safe_settings, + } + generate_content_config = GenerateContentConfig(**generation_config_dict) + prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**." + try: + raw_response: GenerateContentResponse = self.client.models.generate_content( + model=model_info.model_identifier, + contents=[ + Content( + role="user", + parts=[ + Part.from_text(text=prompt), + Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"), + ], + ) + ], + config=generate_content_config, + ) + resp, usage_record = _default_normal_response_parser(raw_response) + except (ClientError, ServerError) as e: + # 重封装ClientError和ServerError为RespNotOkException + raise RespNotOkException(e.code) from None + except Exception as e: + raise NetworkConnectionError() from e + + if usage_record: + resp.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=usage_record[0], + completion_tokens=usage_record[1], + total_tokens=usage_record[2], + ) + + return resp + def get_support_image_formats(self) -> list[str]: """ 获取支持的图片格式 diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index d2a960f1d..48ef0c082 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -48,8 +48,10 @@ class LLMRequest: self.task_name = request_type self.model_for_task = model_set self.request_type = request_type - self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list} - """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整""" + self.model_usage: Dict[str, Tuple[int, int, int]] = { + model: (0, 0, 0) for model in self.model_for_task.model_list + } + """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整""" self.pri_in = 0 self.pri_out = 0 @@ -181,7 +183,8 @@ class LLMRequest: endpoint="/chat/completions", ) if not content: - raise RuntimeError("获取LLM生成内容失败") + logger.warning("生成的响应为空") + content = "生成的响应为空,请检查模型配置或输入内容是否正确" return content, (reasoning_content, model_info.name, tool_calls) @@ -225,12 +228,15 @@ class LLMRequest: 根据总tokens和惩罚值选择的模型 """ least_used_model_name = min( - self.model_usage, key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage, + key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000, ) model_info = model_config.get_model_info(least_used_model_name) api_provider = model_config.get_provider(model_info.api_provider) client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider)) logger.debug(f"选择请求模型: {model_info.name}") + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用 return model_info, api_provider, client async def _execute_request( @@ -288,8 +294,8 @@ class LLMRequest: except Exception as e: logger.debug(f"请求失败: {str(e)}") # 处理异常 - total_tokens, penalty = self.model_usage[model_info.name] - self.model_usage[model_info.name] = (total_tokens, penalty + 1) + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty) wait_interval, compressed_messages = self._default_exception_handler( e, @@ -308,6 +314,8 @@ class LLMRequest: finally: # 放在finally防止死循环 retry_remain -= 1 + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) # 使用结束,减少使用惩罚值 logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") raise RuntimeError("请求失败,已达到最大重试次数") diff --git a/src/plugin_system/apis/llm_api.py b/src/plugin_system/apis/llm_api.py index eaf48556b..9d37a8e34 100644 --- a/src/plugin_system/apis/llm_api.py +++ b/src/plugin_system/apis/llm_api.py @@ -7,8 +7,9 @@ success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config) """ -from typing import Tuple, Dict +from typing import Tuple, Dict, List, Any, Optional from src.common.logger import get_logger +from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.config.api_ada_configs import TaskConfig @@ -52,7 +53,11 @@ def get_available_models() -> Dict[str, TaskConfig]: async def generate_with_model( - prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs + prompt: str, + model_config: TaskConfig, + request_type: str = "plugin.generate", + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, ) -> Tuple[bool, str, str, str]: """使用指定模型生成内容 @@ -60,7 +65,6 @@ async def generate_with_model( prompt: 提示词 model_config: 模型配置(从 get_available_models 获取的模型配置) request_type: 请求类型标识 - **kwargs: 其他模型特定参数,如temperature、max_tokens等 Returns: Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) @@ -70,12 +74,53 @@ async def generate_with_model( logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容") logger.debug(f"[LLMAPI] 完整提示词: {prompt}") - llm_request = LLMRequest(model_set=model_config, request_type=request_type, **kwargs) + llm_request = LLMRequest(model_set=model_config, request_type=request_type) - response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt) + response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens) return True, response, reasoning_content, model_name except Exception as e: error_msg = f"生成内容时出错: {str(e)}" logger.error(f"[LLMAPI] {error_msg}") return False, error_msg, "", "" + +async def generate_with_model_with_tools( + prompt: str, + model_config: TaskConfig, + tool_options: List[Dict[str, Any]] | None = None, + request_type: str = "plugin.generate", + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, +) -> Tuple[bool, str, str, str, List[ToolCall] | None]: + """使用指定模型和工具生成内容 + + Args: + prompt: 提示词 + model_config: 模型配置(从 get_available_models 获取的模型配置) + tool_options: 工具选项列表 + request_type: 请求类型标识 + temperature: 温度参数 + max_tokens: 最大token数 + + Returns: + Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) + """ + try: + model_name_list = model_config.model_list + logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容") + logger.debug(f"[LLMAPI] 完整提示词: {prompt}") + + llm_request = LLMRequest(model_set=model_config, request_type=request_type) + + response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async( + prompt, + tools=tool_options, + temperature=temperature, + max_tokens=max_tokens + ) + return True, response, reasoning_content, model_name, tool_call + + except Exception as e: + error_msg = f"生成内容时出错: {str(e)}" + logger.error(f"[LLMAPI] {error_msg}") + return False, error_msg, "", "", None diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 3cf82390e..ea28c5143 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -3,10 +3,11 @@ from typing import List, Type, Tuple, Union from .plugin_base import PluginBase from src.common.logger import get_logger -from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo +from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler +from .base_tool import BaseTool logger = get_logger("base_plugin") @@ -31,6 +32,7 @@ class BasePlugin(PluginBase): Tuple[ActionInfo, Type[BaseAction]], Tuple[CommandInfo, Type[BaseCommand]], Tuple[EventHandlerInfo, Type[BaseEventHandler]], + Tuple[ToolInfo, Type[BaseTool]], ] ]: """获取插件包含的组件列表 diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index d1b3ba158..9a37bc1d8 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -1,6 +1,7 @@ import time from typing import List, Dict, Tuple, Optional, Any from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance +from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.llm_models.utils_model import LLMRequest from src.llm_models.payload_content import ToolCall @@ -114,7 +115,7 @@ class ToolExecutor: ) # 执行工具调用 - tool_results, used_tools = await self._execute_tool_calls(tool_calls) + tool_results, used_tools = await self.execute_tool_calls(tool_calls) # 缓存结果 if tool_results: @@ -133,7 +134,7 @@ class ToolExecutor: user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) return [definition for name, definition in all_tools if name not in user_disabled_tools] - async def _execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: + async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: """执行工具调用 Args: @@ -161,7 +162,7 @@ class ToolExecutor: logger.debug(f"{self.log_prefix}执行工具: {tool_name}") # 执行工具 - result = await self._execute_tool_call(tool_call) + result = await self.execute_tool_call(tool_call) if result: tool_info = { @@ -194,7 +195,7 @@ class ToolExecutor: return tool_results, used_tools - async def _execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]: + async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]: # sourcery skip: use-assigned-variable """执行单个工具调用 @@ -210,7 +211,7 @@ class ToolExecutor: function_args["llm_called"] = True # 标记为LLM调用 # 获取对应工具实例 - tool_instance = get_tool_instance(function_name) + tool_instance = tool_instance or get_tool_instance(function_name) if not tool_instance: logger.warning(f"未知工具名称: {function_name}") return None @@ -297,7 +298,7 @@ class ToolExecutor: if expired_keys: logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存") - async def execute_specific_tool(self, tool_name: str, tool_args: Dict) -> Optional[Dict]: + async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]: """直接执行指定工具 Args: @@ -317,7 +318,7 @@ class ToolExecutor: logger.info(f"{self.log_prefix}直接执行工具: {tool_name}") - result = await self._execute_tool_call(tool_call) + result = await self.execute_tool_call(tool_call) if result: tool_info = { @@ -408,7 +409,7 @@ results, used_tools, prompt = await executor.execute_from_chat_message( ) # 5. 直接执行特定工具 -result = await executor.execute_specific_tool( +result = await executor.execute_specific_tool_simple( tool_name="get_knowledge", tool_args={"query": "机器学习"} ) diff --git a/src/plugins/built_in/knowledge/get_knowledge.py b/src/plugins/built_in/knowledge/get_knowledge.py deleted file mode 100644 index ce90cb680..000000000 --- a/src/plugins/built_in/knowledge/get_knowledge.py +++ /dev/null @@ -1,131 +0,0 @@ -import json # Added for parsing embedding -import math # Added for cosine similarity -from typing import Any, Union, List # Added List - -from src.chat.utils.utils import get_embedding -from src.common.database.database_model import Knowledges # Updated import -from src.common.logger import get_logger -from src.plugin_system import BaseTool, ToolParamType - - -logger = get_logger("get_knowledge_tool") - - -class SearchKnowledgeTool(BaseTool): - """从知识库中搜索相关信息的工具""" - - name = "search_knowledge" - description = "使用工具从知识库中搜索相关信息" - parameters = [ - ("query", ToolParamType.STRING, "搜索查询关键词", True, None), - ("threshold", ToolParamType.FLOAT, "相似度阈值,0.0到1.0之间", False, None), - ] - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行知识库搜索 - - Args: - function_args: 工具参数 - - Returns: - dict: 工具执行结果 - """ - query = "" # Initialize query to ensure it's defined in except block - try: - query = function_args.get("query") - threshold = function_args.get("threshold", 0.4) - - # 调用知识库搜索 - embedding = await get_embedding(query, request_type="info_retrieval") - if embedding: - knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - if knowledge_info: - content = f"你知道这些知识: {knowledge_info}" - else: - content = f"你不太了解有关{query}的知识" - return {"type": "knowledge", "id": query, "content": content} - return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你知识库炸了"} - except Exception as e: - logger.error(f"知识库搜索工具执行失败: {str(e)}") - return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} - - @staticmethod - def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float: - """计算两个向量之间的余弦相似度""" - dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False)) - magnitude1 = math.sqrt(sum(p * p for p in vec1)) - magnitude2 = math.sqrt(sum(q * q for q in vec2)) - if magnitude1 == 0 or magnitude2 == 0: - return 0.0 - return dot_product / (magnitude1 * magnitude2) - - @staticmethod - def get_info_from_db( - query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False - ) -> Union[str, list]: - """从数据库中获取相关信息 - - Args: - query_embedding: 查询的嵌入向量 - limit: 最大返回结果数 - threshold: 相似度阈值 - return_raw: 是否返回原始结果 - - Returns: - Union[str, list]: 格式化的信息字符串或原始结果列表 - """ - if not query_embedding: - return [] if return_raw else "" - - similar_items = [] - try: - all_knowledges = Knowledges.select() - for item in all_knowledges: - try: - item_embedding_str = item.embedding - if not item_embedding_str: - logger.warning(f"Knowledge item ID {item.id} has empty embedding string.") - continue - item_embedding = json.loads(item_embedding_str) - if not isinstance(item_embedding, list) or not all( - isinstance(x, (int, float)) for x in item_embedding - ): - logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.") - continue - except json.JSONDecodeError: - logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}") - continue - except AttributeError: - logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.") - continue - - similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding) - - if similarity >= threshold: - similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item}) - - # 按相似度降序排序 - similar_items.sort(key=lambda x: x["similarity"], reverse=True) - - # 应用限制 - results = similar_items[:limit] - logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}") - - except Exception as e: - logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}") - return [] if return_raw else "" - - if not results: - return [] if return_raw else "" - - if return_raw: - # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理 - # 这里返回包含内容和相似度的字典列表 - return [{"content": r["content"], "similarity": r["similarity"]} for r in results] - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) - - -# 注册工具 -# register_tool(SearchKnowledgeTool) diff --git a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py index da20c348b..fd3d811b2 100644 --- a/src/plugins/built_in/knowledge/lpmm_get_knowledge.py +++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py @@ -1,6 +1,7 @@ from typing import Dict, Any from src.common.logger import get_logger +from src.config.config import global_config from src.chat.knowledge.knowledge_lib import qa_manager from src.plugin_system import BaseTool, ToolParamType @@ -16,6 +17,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool): ("query", ToolParamType.STRING, "搜索查询关键词", True, None), ("threshold", ToolParamType.FLOAT, "相似度阈值,0.0到1.0之间", False, None), ] + available_for_llm = global_config.lpmm_knowledge.enable async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: """执行知识库搜索