Merge branch 'dev' of https://github.com/Windpicker-owo/MaiBot into dev
@@ -48,6 +48,7 @@ retry_interval = 10 # retry interval (seconds)
 | `timeout` | ❌ | API request timeout (seconds) | 30 |
 | `retry_interval` | ❌ | Retry interval (seconds) | 10 |
 
+**Note: for models whose `client_type` is `gemini`, the `base_url` field has no effect.**
 ### 2.3 Supported provider examples
 
 #### DeepSeek
@@ -132,6 +133,7 @@ thinking = {type = "disabled"} # disable thinking
 ```
 Note that the `extra_params` configuration must form a valid TOML dictionary; its exact contents depend on the API provider's requirements.
 
+**Note: for models whose `client_type` is `gemini`, this field has no effect.**
 ### 3.3 Configuration parameter reference
 
 | Parameter | Required | Description |
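For orientation, a hedged sketch of what an `extra_params` entry such as the `thinking` example in the hunk header above would look like once the TOML is parsed on the Python side (the exact keys depend on the provider):

```python
# Hypothetical extra_params after TOML parsing; keys depend on the API provider.
extra_params = {
    "thinking": {"type": "disabled"},  # mirrors: thinking = {type = "disabled"}
}
```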
@@ -24,7 +24,11 @@ def get_available_models() -> Dict[str, TaskConfig]:
 ### 2. Generating content with a model
 ```python
 async def generate_with_model(
-    prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs
+    prompt: str,
+    model_config: TaskConfig,
+    request_type: str = "plugin.generate",
+    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = None,
 ) -> Tuple[bool, str, str, str]:
 ```
 Generates content with the specified model.
@@ -33,7 +37,29 @@ async def generate_with_model(
 - `prompt`: the prompt text.
 - `model_config`: the model configuration object (obtained from `get_available_models`).
 - `request_type`: request type identifier; defaults to `"plugin.generate"`.
-- `**kwargs`: other model-specific parameters, such as `temperature`, `max_tokens`, etc.
+- `temperature`: sampling temperature for generation; affects output randomness.
+- `max_tokens`: maximum number of tokens to generate.
 
 **Return:**
 - `Tuple[bool, str, str, str]`: a tuple of (success, generated content, reasoning, model name).
+
+### 3. Generating content with tools
+```python
+async def generate_with_model_with_tools(
+    prompt: str,
+    model_config: TaskConfig,
+    tool_options: List[Dict[str, Any]] | None = None,
+    request_type: str = "plugin.generate",
+    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = None,
+) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
+```
+Generates content with the specified model, with tool-calling support.
+
+**Args:**
+- `prompt`: the prompt text.
+- `model_config`: the model configuration object (obtained from `get_available_models`).
+- `tool_options`: list of tool options holding the configuration of the available tools; each dict is one tool's definition, see [tool-components.md](../tool-components.md#属性说明). They can be obtained and filtered via `tool_api.get_llm_available_tool_definitions()`.
+- `request_type`: request type identifier; defaults to `"plugin.generate"`.
+- `temperature`: sampling temperature for generation; affects output randomness.
+- `max_tokens`: maximum number of tokens to generate.
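A minimal usage sketch for the API documented above; the `"replyer"` task key is hypothetical and only illustrates the shape of the mapping returned by `get_available_models` (the tools variant is used the same way, as sketched later in this commit):

```python
from src.plugin_system.apis import llm_api

async def demo() -> None:
    models = llm_api.get_available_models()  # Dict[str, TaskConfig]
    task_config = models["replyer"]          # hypothetical task key
    success, content, reasoning, model_name = await llm_api.generate_with_model(
        prompt="Introduce yourself in one sentence.",
        model_config=task_config,
        temperature=0.7,
        max_tokens=256,
    )
    if success:
        print(f"[{model_name}] {content}")
```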
@@ -36,7 +36,7 @@ def get_llm_available_tool_definitions():
 
 **Returns**:
 - `List[Tuple[str, Dict[str, Any]]]`: list of tool definitions; each element is a `(tool name, tool definition dict)` tuple
-- For the exact definition format, see [tool-components.md](../tool-components.md).
+- For the exact definition format, see [tool-components.md](../tool-components.md#属性说明).
 
 #### Example:
 
 ```python
@@ -78,7 +78,7 @@ class MyTool(BaseTool):
 
 The tool definition constructed from it is:
 ```python
-{"name": cls.name, "description": cls.description, "parameters": cls.parameters}
+definition: Dict[str, Any] = {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
 ```
 
 ### Method reference
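A minimal sketch of how such a definition comes together for a concrete tool, using the tuple-style `parameters` format that appears elsewhere in this commit (for example in `SearchKnowledgeFromLPMMTool`); the `WeatherTool` class itself is hypothetical:

```python
from typing import Any, Dict
from src.plugin_system import BaseTool, ToolParamType

class WeatherTool(BaseTool):  # hypothetical example tool
    name = "get_weather"
    description = "Look up the current weather for a city"
    parameters = [
        ("city", ToolParamType.STRING, "City name to query", True, None),
    ]

# Constructed exactly as shown above:
definition: Dict[str, Any] = {
    "name": WeatherTool.name,
    "description": WeatherTool.description,
    "parameters": WeatherTool.parameters,
}
```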
@@ -15,6 +15,7 @@ matplotlib
 networkx
 numpy
 openai
+google-genai
 pandas
 peewee
 pyarrow
@@ -24,46 +24,6 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
 
 logger = get_logger("OpenIE导入")
 
-ENV_FILE = os.path.join(ROOT_PATH, ".env")
-
-if os.path.exists(".env"):
-    load_dotenv(".env", override=True)
-    print("Environment variables loaded successfully")
-else:
-    print(".env file not found; make sure the required environment variables are set")
-    raise FileNotFoundError(".env file does not exist; create it and configure the required environment variables")
-
-env_mask = {key: os.getenv(key) for key in os.environ}
-def scan_provider(env_config: dict):
-    provider = {}
-
-    # Use env_mask, captured before the env was initialized, to de-duplicate the new set of environment variables
-    # and keep variables like GPG_KEY from interfering with the check
-    env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
-
-    # Iterate over all keys in env_config
-    for key in env_config:
-        # Check whether the key matches the {provider}_BASE_URL or {provider}_KEY pattern
-        if key.endswith("_BASE_URL") or key.endswith("_KEY"):
-            # Extract the provider name
-            provider_name = key.split("_", 1)[0]  # split once from the left, take the first part
-
-            # Initialize the provider's dict (if not already initialized)
-            if provider_name not in provider:
-                provider[provider_name] = {"url": None, "key": None}
-
-            # Fill in url or key depending on the key type
-            if key.endswith("_BASE_URL"):
-                provider[provider_name]["url"] = env_config[key]
-            elif key.endswith("_KEY"):
-                provider[provider_name]["key"] = env_config[key]
-
-    # Check that every provider has both a url and a key
-    for provider_name, config in provider.items():
-        if config["url"] is None or config["key"] is None:
-            logger.error(f"provider contents: {config}\nenv_config contents: {env_config}")
-            raise ValueError(f"Check whether provider '{provider_name}' is missing its BASE_URL or KEY environment variable")
-
 def ensure_openie_dir():
     """Ensure the OpenIE data directory exists"""
     if not os.path.exists(OPENIE_DIR):
@@ -214,8 +174,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
 
 def main():  # sourcery skip: dict-comprehension
     # Confirmation prompt (newly added)
-    env_config = {key: os.getenv(key) for key in os.environ}
-    scan_provider(env_config)
    print("=== Important operation confirmation ===")
     print("OpenIE import sends a very large number of requests and may hit rate limits; choose your model carefully")
     print("As in the earlier example: with a local model we sent about 80,000 requests in 70 minutes; with good network conditions it can be faster")
@@ -27,7 +27,6 @@ from rich.progress import (
 from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
-from dotenv import load_dotenv
 
 logger = get_logger("LPMM知识库-信息提取")
 
@@ -36,45 +35,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 TEMP_DIR = os.path.join(ROOT_PATH, "temp")
 # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
 OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
-ENV_FILE = os.path.join(ROOT_PATH, ".env")
-
-if os.path.exists(".env"):
-    load_dotenv(".env", override=True)
-    print("Environment variables loaded successfully")
-else:
-    print(".env file not found; make sure the required environment variables are set")
-    raise FileNotFoundError(".env file does not exist; create it and configure the required environment variables")
-
-env_mask = {key: os.getenv(key) for key in os.environ}
-def scan_provider(env_config: dict):
-    provider = {}
-
-    # Use env_mask, captured before the env was initialized, to de-duplicate the new set of environment variables
-    # and keep variables like GPG_KEY from interfering with the check
-    env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
-
-    # Iterate over all keys in env_config
-    for key in env_config:
-        # Check whether the key matches the {provider}_BASE_URL or {provider}_KEY pattern
-        if key.endswith("_BASE_URL") or key.endswith("_KEY"):
-            # Extract the provider name
-            provider_name = key.split("_", 1)[0]  # split once from the left, take the first part
-
-            # Initialize the provider's dict (if not already initialized)
-            if provider_name not in provider:
-                provider[provider_name] = {"url": None, "key": None}
-
-            # Fill in url or key depending on the key type
-            if key.endswith("_BASE_URL"):
-                provider[provider_name]["url"] = env_config[key]
-            elif key.endswith("_KEY"):
-                provider[provider_name]["key"] = env_config[key]
-
-    # Check that every provider has both a url and a key
-    for provider_name, config in provider.items():
-        if config["url"] is None or config["key"] is None:
-            logger.error(f"provider contents: {config}\nenv_config contents: {env_config}")
-            raise ValueError(f"Check whether provider '{provider_name}' is missing its BASE_URL or KEY environment variable")
-
 def ensure_dirs():
     """Ensure the temp and output directories exist"""
@@ -158,8 +118,6 @@ def main():  # sourcery skip: comprehension-to-generator, extract-method
     # Set up the signal handler
     signal.signal(signal.SIGINT, signal_handler)
     ensure_dirs()  # make sure the directories exist
-    env_config = {key: os.getenv(key) for key in os.environ}
-    scan_provider(env_config)
     # User confirmation prompt (newly added)
     print("=== Important operation confirmation, please read the following carefully ===")
     print("Entity extraction consumes a significant amount of API credit and time; run it during idle hours.")
@@ -12,8 +12,6 @@ import pandas as pd
 # import tqdm
 import faiss
 
-# from .llm_client import LLMClient
-# from .lpmmconfig import global_config
 from .utils.hash import get_sha256
 from .global_logger import logger
 from rich.traceback import install
@@ -1,45 +0,0 @@
-from openai import OpenAI
-
-
-class LLMMessage:
-    def __init__(self, role, content):
-        self.role = role
-        self.content = content
-
-    def to_dict(self):
-        return {"role": self.role, "content": self.content}
-
-
-class LLMClient:
-    """LLM client; corresponds to one API provider"""
-
-    def __init__(self, url, api_key):
-        self.client = OpenAI(
-            base_url=url,
-            api_key=api_key,
-        )
-
-    def send_chat_request(self, model, messages):
-        """Send a chat request and wait for the result"""
-        response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
-        if hasattr(response.choices[0].message, "reasoning_content"):
-            # There is a separate reasoning-content block
-            reasoning_content = response.choices[0].message.reasoning_content
-            content = response.choices[0].message.content
-        else:
-            # No separate reasoning-content block
-            response = response.choices[0].message.content.split("<think>")[-1].split("</think>")
-            # If there is reasoning content, split it from the main content
-            if len(response) == 2:
-                reasoning_content = response[0]
-                content = response[1]
-            else:
-                reasoning_content = None
-                content = response[0]
-
-        return reasoning_content, content
-
-    def send_embedding_request(self, model, text):
-        """Send an embedding request and wait for the result"""
-        text = text.replace("\n", " ")
-        return self.client.embeddings.create(input=[text], model=model).data[0].embedding
@@ -2,11 +2,7 @@ import time
 from typing import Tuple, List, Dict, Optional
 
 from .global_logger import logger
-
-# from . import prompt_template
 from .embedding_store import EmbeddingManager
-
-# from .llm_client import LLMClient
 from .kg_manager import KGManager
 
 # from .lpmmconfig import global_config
@@ -36,8 +36,6 @@ def init_prompt():
 {chat_context_description}, here is the actual chat content
 {chat_content_block}
-
-
 
 {moderation_prompt}
 
 Now, based on {by_what}, choose a suitable action and the message that triggers it:
@@ -24,13 +24,13 @@ from src.chat.utils.chat_message_builder import (
     replace_user_references_sync,
 )
 from src.chat.express.expression_selector import expression_selector
-from src.chat.knowledge.knowledge_lib import qa_manager
 from src.chat.memory_system.memory_activator import MemoryActivator
 from src.chat.memory_system.instant_memory import InstantMemory
 from src.mood.mood_manager import mood_manager
 from src.person_info.relationship_fetcher import relationship_fetcher_manager
 from src.person_info.person_info import get_person_info_manager
 from src.plugin_system.base.component_types import ActionInfo
+from src.plugin_system.apis import llm_api
 
 logger = get_logger("replyer")
 
@@ -102,6 +102,22 @@ def init_prompt():
         "s4u_style_prompt",
     )
 
+    Prompt(
+        """
+You are an assistant dedicated to knowledge retrieval. Your name is {bot_name}. The current time is {time_now}.
+The conversation currently going on in the group:
+{chat_history}
+
+Now, {sender} has sent: {target_message}, and you want to reply.
+Analyze the chat content carefully, considering the following:
+1. Whether the content contains a question that requires looking up information
+2. Whether there is an explicit instruction to retrieve knowledge
+
+If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed".
+""",
+        name="lpmm_get_knowledge_prompt",
+    )
+
 class DefaultReplyer:
     def __init__(
@@ -698,7 +714,7 @@ class DefaultReplyer:
             self._time_and_run_task(
                 self.build_tool_info(chat_talking_prompt_short, reply_to, enable_tool=enable_tool), "tool_info"
             ),
-            self._time_and_run_task(get_prompt_info(target, threshold=0.38), "prompt_info"),
+            self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, reply_to), "prompt_info"),
         )
 
         # Chinese-English mapping of task names
@@ -1000,6 +1016,63 @@ class DefaultReplyer:
         logger.debug(f"replyer generated content: {content}")
         return content, reasoning_content, model_name, tool_calls
 
+    async def get_prompt_info(self, message: str, reply_to: str):
+        related_info = ""
+        start_time = time.time()
+        from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
+        if not reply_to:
+            logger.debug("No reply target; skipping knowledge-base lookup")
+            return ""
+        sender, content = self._parse_reply_target(reply_to)
+        if not content:
+            logger.debug("Reply target content is empty; skipping knowledge-base lookup")
+            return ""
+        logger.debug(f"Fetching knowledge-base content, source message: {message[:30]}..., message length: {len(message)}")
+        # Fetch knowledge from the LPMM knowledge base
+        try:
+            # Check whether the LPMM knowledge base is enabled
+            if not global_config.lpmm_knowledge.enable:
+                logger.debug("LPMM knowledge base is not enabled; skipping lookup")
+                return ""
+            time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+            bot_name = global_config.bot.nickname
+            prompt = await global_prompt_manager.format_prompt(
+                "lpmm_get_knowledge_prompt",
+                bot_name=bot_name,
+                time_now=time_now,
+                chat_history=message,
+                sender=sender,
+                target_message=content,
+            )
+            _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
+                prompt,
+                model_config=model_config.model_task_config.tool_use,
+                tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
+            )
+            if tool_calls:
+                result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
+                end_time = time.time()
+                if not result or not result.get("content"):
+                    logger.debug("Knowledge lookup from the LPMM knowledge base failed; returning empty knowledge...")
+                    return ""
+                found_knowledge_from_lpmm = result.get("content", "")
+                logger.debug(
+                    f"Fetched knowledge from the LPMM knowledge base, info: {found_knowledge_from_lpmm[:100]}..., length: {len(found_knowledge_from_lpmm)}"
+                )
+                related_info += found_knowledge_from_lpmm
+                logger.debug(f"Knowledge-base lookup took {(end_time - start_time):.3f}s")
+                logger.debug(f"Knowledge-base content, info: {related_info[:100]}..., length: {len(related_info)}")
+                return f"You have the following **knowledge**:\n{related_info}\nPlease **remember the knowledge above**; it may be needed later.\n"
+            else:
+                logger.debug("Knowledge lookup from the LPMM knowledge base failed, possibly because no knowledge has ever been imported; returning empty knowledge...")
+                return ""
+        except Exception as e:
+            logger.error(f"Exception while fetching knowledge-base content: {str(e)}")
+            return ""
+
+
 def weighted_sample_no_replacement(items, weights, k) -> list:
     """
@@ -1035,36 +1108,4 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
     return selected
 
 
-async def get_prompt_info(message: str, threshold: float):
-    related_info = ""
-    start_time = time.time()
-
-    logger.debug(f"Fetching knowledge-base content, source message: {message[:30]}..., message length: {len(message)}")
-    # Fetch knowledge from the LPMM knowledge base
-    try:
-        # Check whether the LPMM knowledge base is enabled
-        if qa_manager is None:
-            logger.debug("LPMM knowledge base is disabled; skipping knowledge lookup")
-            return ""
-
-        found_knowledge_from_lpmm = await qa_manager.get_knowledge(message)
-
-        end_time = time.time()
-        if found_knowledge_from_lpmm is not None:
-            logger.debug(
-                f"Fetched knowledge from the LPMM knowledge base, info: {found_knowledge_from_lpmm[:100]}..., length: {len(found_knowledge_from_lpmm)}"
-            )
-            related_info += found_knowledge_from_lpmm
-            logger.debug(f"Knowledge-base lookup took {(end_time - start_time):.3f}s")
-            logger.debug(f"Knowledge-base content, info: {related_info[:100]}..., length: {len(related_info)}")
-
-            return f"You have the following **knowledge**:\n{related_info}\nPlease **remember the knowledge above**; it may be needed later.\n"
-        else:
-            logger.debug("Knowledge lookup from the LPMM knowledge base failed, possibly because no knowledge has ever been imported; returning empty knowledge...")
-            return ""
-    except Exception as e:
-        logger.error(f"Exception while fetching knowledge-base content: {str(e)}")
-        return ""
-
-
 init_prompt()
@@ -281,20 +281,6 @@ class Memory(BaseModel):
         table_name = "memory"
 
 
-class Knowledges(BaseModel):
-    """
-    Model for storing knowledge-base entries.
-    """
-
-    content = TextField()  # text of the knowledge content
-    embedding = TextField()  # embedding vector of the knowledge content, stored as a JSON string of floats
-    # Other metadata fields such as source, create_time, etc. could be added
-
-    class Meta:
-        # database = db  # inherited from BaseModel
-        table_name = "knowledges"
-
-
 class Expression(BaseModel):
     """
     Model for storing expression styles.
@@ -382,7 +368,6 @@ def create_tables():
         ImageDescriptions,
         OnlineTime,
         PersonInfo,
-        Knowledges,
         Expression,
         ThinkingLog,
         GraphNodes,  # add the graph-node table
@@ -408,7 +393,6 @@ def initialize_database():
         ImageDescriptions,
         OnlineTime,
         PersonInfo,
-        Knowledges,
         Expression,
         Memory,
         ThinkingLog,
@@ -35,7 +35,7 @@ class APIProvider(ConfigBase):
         """Make sure api_key is not shown in repr"""
         if not self.api_key:
             raise ValueError("API key must not be empty; set a valid API key in the config.")
-        if not self.base_url:
+        if not self.base_url and self.client_type != "gemini":
             raise ValueError("API base URL must not be empty; set a valid base URL in the config.")
         if not self.name:
             raise ValueError("API provider name must not be empty; set a valid name in the config.")
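The practical effect of the relaxed check, in a hedged sketch; the field names follow the diff, while the constructor call itself is illustrative:

```python
# Illustrative only: with client_type="gemini" an empty base_url now passes validation,
# consistent with the docs note above that base_url has no effect for gemini models.
gemini_provider = APIProvider(
    name="google",      # hypothetical provider name
    api_key="AIza...",  # placeholder key
    base_url="",        # allowed for gemini after this change
    client_type="gemini",
)
```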
@@ -16,6 +16,9 @@ from google.genai.types import (
     GenerateContentConfig,
     EmbedContentResponse,
     EmbedContentConfig,
+    SafetySetting,
+    HarmCategory,
+    HarmBlockThreshold,
 )
 from google.genai.errors import (
     ClientError,
@@ -41,6 +44,14 @@ from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
 
 logger = get_logger("Gemini客户端")
 
+gemini_safe_settings = [
+    SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
+    SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.BLOCK_NONE),
+    SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.BLOCK_NONE),
+    SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.BLOCK_NONE),
+    SafetySetting(category=HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY, threshold=HarmBlockThreshold.BLOCK_NONE),
+]
+
 
 def _convert_messages(
     messages: list[Message],
@@ -322,7 +333,7 @@ class GeminiClient(BaseClient):
         message_list: list[Message],
         tool_options: list[ToolOption] | None = None,
         max_tokens: int = 1024,
-        temperature: float = 0.7,
+        temperature: float = 0.4,
         response_format: RespFormat | None = None,
         stream_response_handler: Optional[
             Callable[
@@ -369,9 +380,12 @@ class GeminiClient(BaseClient):
             "thinking_config": ThinkingConfig(
                 include_thoughts=True,
                 thinking_budget=(
-                    extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else None
+                    extra_params["thinking_budget"]
+                    if extra_params and "thinking_budget" in extra_params
+                    else int(max_tokens / 2)  # default thinking budget is half of max_tokens, to prevent empty replies
                 ),
             ),
+            "safety_settings": gemini_safe_settings,  # guards against the empty-reply problem
         }
         if tools:
             generation_config_dict["tools"] = Tool(function_declarations=tools)
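A self-contained sketch of the default-budget rule introduced above, assuming nothing beyond what the diff shows: when the caller does not supply `thinking_budget` via `extra_params`, half of `max_tokens` is reserved for thinking.

```python
# Mirrors the expression in the diff above; the helper function itself is illustrative.
def resolve_thinking_budget(extra_params: dict | None, max_tokens: int) -> int:
    if extra_params and "thinking_budget" in extra_params:
        return extra_params["thinking_budget"]
    return int(max_tokens / 2)  # e.g. max_tokens=1024 gives a budget of 512

assert resolve_thinking_budget(None, 1024) == 512
assert resolve_thinking_budget({"thinking_budget": 256}, 1024) == 256
```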
@@ -483,6 +497,61 @@ class GeminiClient(BaseClient):
 
         return response
 
+    def get_audio_transcriptions(
+        self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
+    ) -> APIResponse:
+        """
+        Get an audio transcription
+        :param model_info: model info
+        :param audio_base64: Base64-encoded audio file
+        :param extra_params: extra parameters (optional)
+        :return: transcription response
+        """
+        generation_config_dict = {
+            "max_output_tokens": 2048,
+            "response_modalities": ["TEXT"],
+            "thinking_config": ThinkingConfig(
+                include_thoughts=True,
+                thinking_budget=(
+                    extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else 1024
+                ),
+            ),
+            "safety_settings": gemini_safe_settings,
+        }
+        generate_content_config = GenerateContentConfig(**generation_config_dict)
+        prompt = "Generate a transcript of the speech. The language of the transcript should **match the language of the speech**."
+        try:
+            raw_response: GenerateContentResponse = self.client.models.generate_content(
+                model=model_info.model_identifier,
+                contents=[
+                    Content(
+                        role="user",
+                        parts=[
+                            Part.from_text(text=prompt),
+                            Part.from_bytes(data=base64.b64decode(audio_base64), mime_type="audio/wav"),
+                        ],
+                    )
+                ],
+                config=generate_content_config,
+            )
+            resp, usage_record = _default_normal_response_parser(raw_response)
+        except (ClientError, ServerError) as e:
+            # Re-wrap ClientError and ServerError as RespNotOkException
+            raise RespNotOkException(e.code) from None
+        except Exception as e:
+            raise NetworkConnectionError() from e
+
+        if usage_record:
+            resp.usage = UsageRecord(
+                model_name=model_info.name,
+                provider_name=model_info.api_provider,
+                prompt_tokens=usage_record[0],
+                completion_tokens=usage_record[1],
+                total_tokens=usage_record[2],
+            )
+
+        return resp
+
     def get_support_image_formats(self) -> list[str]:
         """
         Get supported image formats
@@ -48,8 +48,10 @@ class LLMRequest:
         self.task_name = request_type
         self.model_for_task = model_set
         self.request_type = request_type
-        self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list}
-        """Model usage record for load balancing, as (total_tokens, penalty); the penalty allows adjustment when a model is underperforming"""
+        self.model_usage: Dict[str, Tuple[int, int, int]] = {
+            model: (0, 0, 0) for model in self.model_for_task.model_list
+        }
+        """Model usage record for load balancing, as (total_tokens, penalty, usage_penalty); the penalty values allow adjustment when a model is underperforming or currently in use"""
 
         self.pri_in = 0
         self.pri_out = 0
@@ -181,7 +183,8 @@ class LLMRequest:
             endpoint="/chat/completions",
         )
         if not content:
-            raise RuntimeError("Failed to get LLM-generated content")
+            logger.warning("Generated response is empty")
+            content = "The generated response was empty; check that the model configuration and input are correct"
 
         return content, (reasoning_content, model_info.name, tool_calls)
 
@@ -225,12 +228,15 @@ class LLMRequest:
             The model selected based on total tokens and penalty values
         """
         least_used_model_name = min(
-            self.model_usage, key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300
+            self.model_usage,
+            key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + self.model_usage[k][2] * 1000,
         )
         model_info = model_config.get_model_info(least_used_model_name)
         api_provider = model_config.get_provider(model_info.api_provider)
         client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider))
         logger.debug(f"Selected model for the request: {model_info.name}")
+        total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+        self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)  # bump the usage penalty to avoid back-to-back use
         return model_info, api_provider, client
 
     async def _execute_request(
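A hedged, self-contained sketch of the selection score introduced above; the weights 300 and 1000 come straight from the diff, while the model names and counters are made up:

```python
# score = total_tokens + penalty * 300 + usage_penalty * 1000
model_usage = {
    "model-a": (50_000, 0, 2),  # two in-flight requests
    "model-b": (51_000, 1, 0),  # more tokens and one failure penalty, but idle
}

def score(usage: tuple[int, int, int]) -> int:
    total_tokens, penalty, usage_penalty = usage
    return total_tokens + penalty * 300 + usage_penalty * 1000

# model-a scores 52_000, model-b scores 51_300, so the idle model wins
# even though it has consumed more tokens.
assert min(model_usage, key=lambda k: score(model_usage[k])) == "model-b"
```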
@@ -288,8 +294,8 @@ class LLMRequest:
             except Exception as e:
                 logger.debug(f"Request failed: {str(e)}")
                 # Handle the exception
-                total_tokens, penalty = self.model_usage[model_info.name]
-                self.model_usage[model_info.name] = (total_tokens, penalty + 1)
+                total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+                self.model_usage[model_info.name] = (total_tokens, penalty + 1, usage_penalty)
 
                 wait_interval, compressed_messages = self._default_exception_handler(
                     e,
@@ -308,6 +314,8 @@ class LLMRequest:
             finally:
                 # In finally, to prevent an infinite loop
                 retry_remain -= 1
+                total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+                self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1)  # request finished; release the usage penalty
         logger.error(f"Model '{model_info.name}' request failed after reaching the maximum of {api_provider.max_retry} retries")
         raise RuntimeError("Request failed; maximum number of retries reached")
 
@@ -7,8 +7,9 @@
     success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
 """
 
-from typing import Tuple, Dict
+from typing import Tuple, Dict, List, Any, Optional
 from src.common.logger import get_logger
+from src.llm_models.payload_content.tool_option import ToolCall
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import global_config, model_config
 from src.config.api_ada_configs import TaskConfig
@@ -52,7 +53,11 @@ def get_available_models() -> Dict[str, TaskConfig]:
 
 
 async def generate_with_model(
-    prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs
+    prompt: str,
+    model_config: TaskConfig,
+    request_type: str = "plugin.generate",
+    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = None,
 ) -> Tuple[bool, str, str, str]:
     """Generate content with the specified model
 
@@ -60,7 +65,6 @@ async def generate_with_model(
         prompt: the prompt text
         model_config: model configuration (obtained from get_available_models)
         request_type: request type identifier
-        **kwargs: other model-specific parameters such as temperature, max_tokens, etc.
 
     Returns:
         Tuple[bool, str, str, str]: (success, generated content, reasoning, model name)
@@ -70,12 +74,53 @@ async def generate_with_model(
         logger.info(f"[LLMAPI] Generating content with model set {model_name_list}")
         logger.debug(f"[LLMAPI] Full prompt: {prompt}")
 
-        llm_request = LLMRequest(model_set=model_config, request_type=request_type, **kwargs)
+        llm_request = LLMRequest(model_set=model_config, request_type=request_type)
 
-        response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt)
+        response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
         return True, response, reasoning_content, model_name
 
     except Exception as e:
         error_msg = f"Error while generating content: {str(e)}"
         logger.error(f"[LLMAPI] {error_msg}")
         return False, error_msg, "", ""
+
+
+async def generate_with_model_with_tools(
+    prompt: str,
+    model_config: TaskConfig,
+    tool_options: List[Dict[str, Any]] | None = None,
+    request_type: str = "plugin.generate",
+    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = None,
+) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
+    """Generate content with the specified model and tools
+
+    Args:
+        prompt: the prompt text
+        model_config: model configuration (obtained from get_available_models)
+        tool_options: list of tool options
+        request_type: request type identifier
+        temperature: temperature parameter
+        max_tokens: maximum number of tokens
+
+    Returns:
+        Tuple[bool, str, str, str, List[ToolCall] | None]: (success, generated content, reasoning, model name, tool calls)
+    """
+    try:
+        model_name_list = model_config.model_list
+        logger.info(f"[LLMAPI] Generating content with model set {model_name_list}")
+        logger.debug(f"[LLMAPI] Full prompt: {prompt}")
+
+        llm_request = LLMRequest(model_set=model_config, request_type=request_type)
+
+        response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
+            prompt,
+            tools=tool_options,
+            temperature=temperature,
+            max_tokens=max_tokens
+        )
+        return True, response, reasoning_content, model_name, tool_call
+
+    except Exception as e:
+        error_msg = f"Error while generating content: {str(e)}"
+        logger.error(f"[LLMAPI] {error_msg}")
+        return False, error_msg, "", "", None
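A minimal usage sketch for the tools variant above, assuming the tool definitions come from `tool_api.get_llm_available_tool_definitions()` as the docs earlier in this commit suggest; the `"tool_use"` task key and the prompt are illustrative:

```python
from src.plugin_system.apis import llm_api, tool_api

async def demo_with_tools() -> None:
    models = llm_api.get_available_models()
    # get_llm_available_tool_definitions() returns (name, definition) tuples; keep the definitions.
    tools = [definition for _name, definition in tool_api.get_llm_available_tool_definitions()]
    success, content, reasoning, model_name, tool_calls = await llm_api.generate_with_model_with_tools(
        prompt="What is the weather in Berlin?",
        model_config=models["tool_use"],  # hypothetical task key
        tool_options=tools,
    )
    if success and tool_calls:
        # Each ToolCall can then be dispatched, e.g. via ToolExecutor.execute_tool_call (see below).
        print(f"[{model_name}] requested {len(tool_calls)} tool call(s)")
```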
@@ -3,10 +3,11 @@ from typing import List, Type, Tuple, Union
 from .plugin_base import PluginBase
 
 from src.common.logger import get_logger
-from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo
+from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
 from .base_action import BaseAction
 from .base_command import BaseCommand
 from .base_events_handler import BaseEventHandler
+from .base_tool import BaseTool
 
 logger = get_logger("base_plugin")
 
@@ -31,6 +32,7 @@ class BasePlugin(PluginBase):
                 Tuple[ActionInfo, Type[BaseAction]],
                 Tuple[CommandInfo, Type[BaseCommand]],
                 Tuple[EventHandlerInfo, Type[BaseEventHandler]],
+                Tuple[ToolInfo, Type[BaseTool]],
             ]
         ]:
         """Get the list of components contained in the plugin
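With the union extended above, a plugin can register tools alongside actions, commands, and event handlers. A hedged sketch; the method name and the `ToolInfo` constructor arguments are assumptions, not taken from this diff:

```python
class MyToolPlugin(BasePlugin):  # illustrative plugin
    def get_plugin_components(self):  # assumed method name
        return [
            (ToolInfo(name="search_knowledge"), SearchKnowledgeFromLPMMTool),  # assumed ToolInfo fields
        ]
```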
@@ -1,6 +1,7 @@
 import time
 from typing import List, Dict, Tuple, Optional, Any
 from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
+from src.plugin_system.base.base_tool import BaseTool
 from src.plugin_system.core.global_announcement_manager import global_announcement_manager
 from src.llm_models.utils_model import LLMRequest
 from src.llm_models.payload_content import ToolCall
@@ -114,7 +115,7 @@ class ToolExecutor:
         )
 
         # Execute the tool calls
-        tool_results, used_tools = await self._execute_tool_calls(tool_calls)
+        tool_results, used_tools = await self.execute_tool_calls(tool_calls)
 
         # Cache the results
         if tool_results:
@@ -133,7 +134,7 @@ class ToolExecutor:
         user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
         return [definition for name, definition in all_tools if name not in user_disabled_tools]
 
-    async def _execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
+    async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
         """Execute tool calls
 
         Args:
@@ -161,7 +162,7 @@ class ToolExecutor:
             logger.debug(f"{self.log_prefix}Executing tool: {tool_name}")
 
             # Execute the tool
-            result = await self._execute_tool_call(tool_call)
+            result = await self.execute_tool_call(tool_call)
 
             if result:
                 tool_info = {
@@ -194,7 +195,7 @@ class ToolExecutor:
 
         return tool_results, used_tools
 
-    async def _execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
+    async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
         # sourcery skip: use-assigned-variable
         """Execute a single tool call
 
@@ -210,7 +211,7 @@ class ToolExecutor:
         function_args["llm_called"] = True  # mark as called by the LLM
 
         # Get the corresponding tool instance
-        tool_instance = get_tool_instance(function_name)
+        tool_instance = tool_instance or get_tool_instance(function_name)
         if not tool_instance:
             logger.warning(f"Unknown tool name: {function_name}")
             return None
@@ -297,7 +298,7 @@ class ToolExecutor:
         if expired_keys:
             logger.debug(f"{self.log_prefix}Cleared {len(expired_keys)} expired cache entries")
 
-    async def execute_specific_tool(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
+    async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
         """Execute a specific tool directly
 
         Args:
@@ -317,7 +318,7 @@ class ToolExecutor:
 
         logger.info(f"{self.log_prefix}Directly executing tool: {tool_name}")
 
-        result = await self._execute_tool_call(tool_call)
+        result = await self.execute_tool_call(tool_call)
 
         if result:
             tool_info = {
@@ -408,7 +409,7 @@ results, used_tools, prompt = await executor.execute_from_chat_message(
 )
 
 # 5. Directly execute a specific tool
-result = await executor.execute_specific_tool(
+result = await executor.execute_specific_tool_simple(
     tool_name="get_knowledge",
     tool_args={"query": "machine learning"}
 )
@@ -1,131 +0,0 @@
-import json  # Added for parsing embedding
-import math  # Added for cosine similarity
-from typing import Any, Union, List  # Added List
-
-from src.chat.utils.utils import get_embedding
-from src.common.database.database_model import Knowledges  # Updated import
-from src.common.logger import get_logger
-from src.plugin_system import BaseTool, ToolParamType
-
-
-logger = get_logger("get_knowledge_tool")
-
-
-class SearchKnowledgeTool(BaseTool):
-    """Tool for searching relevant information in the knowledge base"""
-
-    name = "search_knowledge"
-    description = "Use this tool to search the knowledge base for relevant information"
-    parameters = [
-        ("query", ToolParamType.STRING, "search query keywords", True, None),
-        ("threshold", ToolParamType.FLOAT, "similarity threshold, between 0.0 and 1.0", False, None),
-    ]
-
-    async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
-        """Run the knowledge-base search
-
-        Args:
-            function_args: tool arguments
-
-        Returns:
-            dict: tool execution result
-        """
-        query = ""  # Initialize query to ensure it's defined in except block
-        try:
-            query = function_args.get("query")
-            threshold = function_args.get("threshold", 0.4)
-
-            # Run the knowledge-base search
-            embedding = await get_embedding(query, request_type="info_retrieval")
-            if embedding:
-                knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
-                if knowledge_info:
-                    content = f"You know the following: {knowledge_info}"
-                else:
-                    content = f"You don't know much about {query}"
-                return {"type": "knowledge", "id": query, "content": content}
-            return {"type": "info", "id": query, "content": f"Could not get an embedding for '{query}'; the knowledge base is broken"}
-        except Exception as e:
-            logger.error(f"Knowledge-base search tool failed: {str(e)}")
-            return {"type": "info", "id": query, "content": f"Knowledge-base search failed: {str(e)}"}
-
-    @staticmethod
-    def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
-        """Compute the cosine similarity between two vectors"""
-        dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False))
-        magnitude1 = math.sqrt(sum(p * p for p in vec1))
-        magnitude2 = math.sqrt(sum(q * q for q in vec2))
-        if magnitude1 == 0 or magnitude2 == 0:
-            return 0.0
-        return dot_product / (magnitude1 * magnitude2)
-
-    @staticmethod
-    def get_info_from_db(
-        query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
-    ) -> Union[str, list]:
-        """Fetch relevant information from the database
-
-        Args:
-            query_embedding: the query's embedding vector
-            limit: maximum number of results to return
-            threshold: similarity threshold
-            return_raw: whether to return raw results
-
-        Returns:
-            Union[str, list]: a formatted info string, or the raw result list
-        """
-        if not query_embedding:
-            return [] if return_raw else ""
-
-        similar_items = []
-        try:
-            all_knowledges = Knowledges.select()
-            for item in all_knowledges:
-                try:
-                    item_embedding_str = item.embedding
-                    if not item_embedding_str:
-                        logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
-                        continue
-                    item_embedding = json.loads(item_embedding_str)
-                    if not isinstance(item_embedding, list) or not all(
-                        isinstance(x, (int, float)) for x in item_embedding
-                    ):
-                        logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
-                        continue
-                except json.JSONDecodeError:
-                    logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
-                    continue
-                except AttributeError:
-                    logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
-                    continue
-
-                similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
-
-                if similarity >= threshold:
-                    similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
-
-            # Sort by similarity, descending
-            similar_items.sort(key=lambda x: x["similarity"], reverse=True)
-
-            # Apply the limit
-            results = similar_items[:limit]
-            logger.debug(f"Number of matching results after the knowledge-base query: {len(results)}")
-
-        except Exception as e:
-            logger.error(f"Failed to fetch knowledge info from the Peewee database: {str(e)}")
-            return [] if return_raw else ""
-
-        if not results:
-            return [] if return_raw else ""
-
-        if return_raw:
-            # Peewee model instances cannot be serialized to JSON directly; callers that need raw models must handle them
-            # Here we return a list of dicts with content and similarity
-            return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
-        else:
-            # Return all found content, newline-separated
-            return "\n".join(str(result["content"]) for result in results)
-
-
-# Register the tool
-# register_tool(SearchKnowledgeTool)
@@ -1,6 +1,7 @@
 from typing import Dict, Any
 
 from src.common.logger import get_logger
+from src.config.config import global_config
 from src.chat.knowledge.knowledge_lib import qa_manager
 from src.plugin_system import BaseTool, ToolParamType
@@ -16,6 +17,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
         ("query", ToolParamType.STRING, "search query keywords", True, None),
         ("threshold", ToolParamType.FLOAT, "similarity threshold, between 0.0 and 1.0", False, None),
     ]
+    available_for_llm = global_config.lpmm_knowledge.enable
 
     async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
         """Run the knowledge-base search