Correctly use LPMM to build the prompt

This commit is contained in:
UnCLAS-Prommer
2025-08-03 19:52:31 +08:00
parent 9a63a8030e
commit 1e5db5d7e1
12 changed files with 141 additions and 249 deletions

View File

@@ -12,8 +12,6 @@ import pandas as pd
# import tqdm
import faiss
-# from .llm_client import LLMClient
-# from .lpmmconfig import global_config
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install

View File

@@ -1,45 +0,0 @@
from openai import OpenAI


class LLMMessage:
    def __init__(self, role, content):
        self.role = role
        self.content = content

    def to_dict(self):
        return {"role": self.role, "content": self.content}


class LLMClient:
    """An LLMClient corresponds to a single API provider"""

    def __init__(self, url, api_key):
        self.client = OpenAI(
            base_url=url,
            api_key=api_key,
        )

    def send_chat_request(self, model, messages):
        """Send a chat request and wait for the result"""
        response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
        if hasattr(response.choices[0].message, "reasoning_content"):
            # The reasoning comes back as a separate content block
            reasoning_content = response.choices[0].message.reasoning_content
            content = response.choices[0].message.content
        else:
            # No separate reasoning block: split inline <think> tags out of the text
            parts = response.choices[0].message.content.split("<think>")[-1].split("</think>")
            # If reasoning is present, separate it from the content
            if len(parts) == 2:
                reasoning_content = parts[0]
                content = parts[1]
            else:
                reasoning_content = None
                content = parts[0]
        return reasoning_content, content

    def send_embedding_request(self, model, text):
        """Send an embedding request and wait for the result"""
        text = text.replace("\n", " ")
        return self.client.embeddings.create(input=[text], model=model).data[0].embedding
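
A detail worth noting in the deleted client: when the provider exposes no separate reasoning_content field, the reasoning is recovered from inline <think> tags. A minimal standalone sketch of that parsing (the helper name split_think_tags is illustrative, not part of the codebase):

def split_think_tags(text: str):
    """Split a raw LLM response into (reasoning, content) when the model
    inlines its chain of thought as <think>reasoning</think>content."""
    parts = text.split("<think>")[-1].split("</think>")
    if len(parts) == 2:
        # A closing tag is present: everything before it is the reasoning.
        return parts[0], parts[1]
    # No think block: the whole text is the content.
    return None, parts[0]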

View File

@@ -2,11 +2,7 @@ import time
from typing import Tuple, List, Dict, Optional
from .global_logger import logger
-# from . import prompt_template
from .embedding_store import EmbeddingManager
-# from .llm_client import LLMClient
from .kg_manager import KGManager
# from .lpmmconfig import global_config

View File

@@ -36,8 +36,6 @@ def init_prompt():
{chat_context_description}. Below is the actual chat content:
{chat_content_block}
{moderation_prompt}
Now, based on {by_what}, choose the appropriate action and the message that triggers the action:

View File

@@ -24,13 +24,13 @@ from src.chat.utils.chat_message_builder import (
    replace_user_references_sync,
)
from src.chat.express.expression_selector import expression_selector
-from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.instant_memory import InstantMemory
from src.mood.mood_manager import mood_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.base.component_types import ActionInfo
+from src.plugin_system.apis import llm_api

logger = get_logger("replyer")

@@ -102,6 +102,22 @@ def init_prompt():
        "s4u_style_prompt",
    )
+    Prompt(
+        """
+You are an assistant dedicated to retrieving knowledge. Your name is {bot_name}. The time is now {time_now}.
+The chat currently going on in the group:
+{chat_history}
+Now {sender} has sent this: {target_message}, and you want to reply.
+Analyze the chat content carefully and consider the following:
+1. Does the content contain a question that requires looking up information?
+2. Is there an explicit instruction to retrieve knowledge?
+If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed".
+""",
+        name="lpmm_get_knowledge_prompt",
+    )

class DefaultReplyer:
    def __init__(

@@ -698,7 +714,7 @@ class DefaultReplyer:
            self._time_and_run_task(
                self.build_tool_info(chat_talking_prompt_short, reply_to, enable_tool=enable_tool), "tool_info"
            ),
-            self._time_and_run_task(get_prompt_info(target, threshold=0.38), "prompt_info"),
+            self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, reply_to), "prompt_info"),
        )

        # Mapping between Chinese and English task names
@@ -1000,6 +1016,63 @@ class DefaultReplyer:
        logger.debug(f"replyer generated content: {content}")
        return content, reasoning_content, model_name, tool_calls

+    async def get_prompt_info(self, message: str, reply_to: str):
+        related_info = ""
+        start_time = time.time()
+        from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
+
+        if not reply_to:
+            logger.debug("No reply target; skipping knowledge base lookup")
+            return ""
+        sender, content = self._parse_reply_target(reply_to)
+        if not content:
+            logger.debug("Reply target content is empty; skipping knowledge base lookup")
+            return ""
+
+        logger.debug(f"Fetching knowledge base content, source message: {message[:30]}..., message length: {len(message)}")
+        # Retrieve knowledge from the LPMM knowledge base
+        try:
+            # Check whether the LPMM knowledge base is enabled
+            if not global_config.lpmm_knowledge.enable:
+                logger.debug("LPMM knowledge base is disabled; skipping knowledge base lookup")
+                return ""
+            time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+            bot_name = global_config.bot.nickname
+            prompt = await global_prompt_manager.format_prompt(
+                "lpmm_get_knowledge_prompt",
+                bot_name=bot_name,
+                time_now=time_now,
+                chat_history=message,
+                sender=sender,
+                target_message=content,
+            )
+            _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
+                prompt,
+                model_config=model_config.model_task_config.tool_use,
+                tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
+            )
+            if tool_calls:
+                result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
+                end_time = time.time()
+                if not result or not result.get("content"):
+                    logger.debug("Got no knowledge from the LPMM knowledge base; returning empty knowledge...")
+                    return ""
+                found_knowledge_from_lpmm = result.get("content", "")
+                logger.debug(
+                    f"Knowledge from the LPMM knowledge base: {found_knowledge_from_lpmm[:100]}..., length: {len(found_knowledge_from_lpmm)}"
+                )
+                related_info += found_knowledge_from_lpmm
+                logger.debug(f"Knowledge base lookup took {(end_time - start_time):.3f}s")
+                logger.debug(f"Knowledge base content: {related_info[:100]}..., length: {len(related_info)}")
+                return f"You have the following **knowledge**:\n{related_info}\nPlease **remember the knowledge above**; it may be needed later.\n"
+            else:
+                logger.debug("Got no knowledge from the LPMM knowledge base (possibly none was ever imported); returning empty knowledge...")
+                return ""
+        except Exception as e:
+            logger.error(f"Exception while fetching knowledge base content: {str(e)}")
+            return ""

def weighted_sample_no_replacement(items, weights, k) -> list:
    """
@@ -1035,36 +1108,4 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
    return selected

-async def get_prompt_info(message: str, threshold: float):
-    related_info = ""
-    start_time = time.time()
-
-    logger.debug(f"Fetching knowledge base content, source message: {message[:30]}..., message length: {len(message)}")
-    # Retrieve knowledge from the LPMM knowledge base
-    try:
-        # Check whether the LPMM knowledge base is enabled
-        if qa_manager is None:
-            logger.debug("LPMM knowledge base is disabled; skipping knowledge retrieval")
-            return ""
-        found_knowledge_from_lpmm = await qa_manager.get_knowledge(message)
-        end_time = time.time()
-        if found_knowledge_from_lpmm is not None:
-            logger.debug(
-                f"Knowledge from the LPMM knowledge base: {found_knowledge_from_lpmm[:100]}..., length: {len(found_knowledge_from_lpmm)}"
-            )
-            related_info += found_knowledge_from_lpmm
-            logger.debug(f"Knowledge base lookup took {(end_time - start_time):.3f}s")
-            logger.debug(f"Knowledge base content: {related_info[:100]}..., length: {len(related_info)}")
-            return f"You have the following **knowledge**:\n{related_info}\nPlease **remember the knowledge above**; it may be needed later.\n"
-        else:
-            logger.debug("Got no knowledge from the LPMM knowledge base (possibly none was ever imported); returning empty knowledge...")
-            return ""
-    except Exception as e:
-        logger.error(f"Exception while fetching knowledge base content: {str(e)}")
-        return ""

init_prompt()
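
For reference, a minimal sketch of how the new instance method is driven (variable names are illustrative; reply_to is the "sender:content" string that _parse_reply_target() splits):

prompt_info = await replyer.get_prompt_info(
    message=chat_talking_prompt_short,
    reply_to="Alice:what is LPMM?",
)
# Returns "" when LPMM is disabled, the model calls no tool, or the search
# finds nothing; otherwise a block starting with "You have the following **knowledge**".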

View File

@@ -281,20 +281,6 @@ class Memory(BaseModel):
        table_name = "memory"

-class Knowledges(BaseModel):
-    """
-    Model for storing knowledge base entries.
-    """
-
-    content = TextField()  # the knowledge text itself
-    embedding = TextField()  # embedding of the content, stored as a JSON-encoded list of floats
-    # Further metadata fields such as source or create_time could be added here
-
-    class Meta:
-        # database = db  # inherited from BaseModel
-        table_name = "knowledges"

class Expression(BaseModel):
    """
    Model for storing expression styles.

@@ -382,7 +368,6 @@ def create_tables():
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
-        Knowledges,
        Expression,
        ThinkingLog,
        GraphNodes,  # graph nodes table

@@ -408,7 +393,6 @@ def initialize_database():
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
-        Knowledges,
        Expression,
        Memory,
        ThinkingLog,

View File

@@ -181,7 +181,8 @@ class LLMRequest:
            endpoint="/chat/completions",
        )
        if not content:
-            raise RuntimeError("Failed to get LLM-generated content")
+            logger.warning("The generated response is empty")
+            content = "The generated response is empty; please check the model configuration and the input"
        return content, (reasoning_content, model_info.name, tool_calls)

View File

@@ -7,8 +7,9 @@
    success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
"""

-from typing import Tuple, Dict
+from typing import Tuple, Dict, List, Any, Optional
from src.common.logger import get_logger
+from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.config.api_ada_configs import TaskConfig

@@ -52,7 +53,11 @@ def get_available_models() -> Dict[str, TaskConfig]:

async def generate_with_model(
-    prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs
+    prompt: str,
+    model_config: TaskConfig,
+    request_type: str = "plugin.generate",
+    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str]:
    """Generate content with the specified model

@@ -60,7 +65,6 @@ async def generate_with_model(
        prompt: the prompt
        model_config: model configuration (as returned by get_available_models)
        request_type: request type identifier
-        **kwargs: other model-specific parameters, such as temperature and max_tokens

    Returns:
        Tuple[bool, str, str, str]: (success, generated content, reasoning, model name)

@@ -70,12 +74,53 @@ async def generate_with_model(
        logger.info(f"[LLMAPI] Generating content with model set {model_name_list}")
        logger.debug(f"[LLMAPI] Full prompt: {prompt}")
-        llm_request = LLMRequest(model_set=model_config, request_type=request_type, **kwargs)
-        response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt)
+        llm_request = LLMRequest(model_set=model_config, request_type=request_type)
+        response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
        return True, response, reasoning_content, model_name
    except Exception as e:
        error_msg = f"Error while generating content: {str(e)}"
        logger.error(f"[LLMAPI] {error_msg}")
        return False, error_msg, "", ""

+
+async def generate_with_model_with_tools(
+    prompt: str,
+    model_config: TaskConfig,
+    tool_options: List[Dict[str, Any]] | None = None,
+    request_type: str = "plugin.generate",
+    temperature: Optional[float] = None,
+    max_tokens: Optional[int] = None,
+) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
+    """Generate content with the specified model and tools
+
+    Args:
+        prompt: the prompt
+        model_config: model configuration (as returned by get_available_models)
+        tool_options: list of tool definitions offered to the model
+        request_type: request type identifier
+        temperature: sampling temperature
+        max_tokens: maximum number of tokens
+
+    Returns:
+        Tuple[bool, str, str, str, List[ToolCall] | None]: (success, generated content, reasoning, model name, tool calls)
+    """
+    try:
+        model_name_list = model_config.model_list
+        logger.info(f"[LLMAPI] Generating content with model set {model_name_list}")
+        logger.debug(f"[LLMAPI] Full prompt: {prompt}")
+        llm_request = LLMRequest(model_set=model_config, request_type=request_type)
+        response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
+            prompt,
+            tools=tool_options,
+            temperature=temperature,
+            max_tokens=max_tokens
+        )
+        return True, response, reasoning_content, model_name, tool_call
+    except Exception as e:
+        error_msg = f"Error while generating content: {str(e)}"
+        logger.error(f"[LLMAPI] {error_msg}")
+        return False, error_msg, "", "", None
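
A minimal call sketch for the new tool-enabled entry point, mirroring the call site added in replyer.py above (assumes a tool class exposing get_tool_definition() and a ToolExecutor instance named tool_executor):

success, content, reasoning, model_name, tool_calls = await llm_api.generate_with_model_with_tools(
    prompt,
    model_config=model_config.model_task_config.tool_use,
    tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
)
if success and tool_calls:
    # Hand the first call, plus an explicit tool instance, to the executor.
    result = await tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())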

View File

@@ -3,10 +3,11 @@ from typing import List, Type, Tuple, Union
from .plugin_base import PluginBase
from src.common.logger import get_logger
-from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo
+from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
from .base_action import BaseAction
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
+from .base_tool import BaseTool

logger = get_logger("base_plugin")

@@ -31,6 +32,7 @@ class BasePlugin(PluginBase):
            Tuple[ActionInfo, Type[BaseAction]],
            Tuple[CommandInfo, Type[BaseCommand]],
            Tuple[EventHandlerInfo, Type[BaseEventHandler]],
+            Tuple[ToolInfo, Type[BaseTool]],
        ]
    ]:
        """Get the list of components contained in the plugin

View File

@@ -1,6 +1,7 @@
import time
from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
+from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall

@@ -114,7 +115,7 @@ class ToolExecutor:
        )

        # Execute the tool calls
-        tool_results, used_tools = await self._execute_tool_calls(tool_calls)
+        tool_results, used_tools = await self.execute_tool_calls(tool_calls)

        # Cache the results
        if tool_results:

@@ -133,7 +134,7 @@ class ToolExecutor:
        user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
        return [definition for name, definition in all_tools if name not in user_disabled_tools]

-    async def _execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
+    async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
        """Execute tool calls

        Args:

@@ -158,7 +159,7 @@ class ToolExecutor:
            logger.debug(f"{self.log_prefix}Executing tool: {tool_name}")

            # Execute the tool
-            result = await self._execute_tool_call(tool_call)
+            result = await self.execute_tool_call(tool_call)

            if result:
                tool_info = {

@@ -191,7 +192,7 @@ class ToolExecutor:
        return tool_results, used_tools

-    async def _execute_tool_call(self, tool_call: ToolCall) -> Optional[Dict[str, Any]]:
+    async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
        # sourcery skip: use-assigned-variable
        """Execute a single tool call

@@ -207,7 +208,7 @@ class ToolExecutor:
        function_args["llm_called"] = True  # mark the call as made by the LLM

        # Get the corresponding tool instance
-        tool_instance = get_tool_instance(function_name)
+        tool_instance = tool_instance or get_tool_instance(function_name)
        if not tool_instance:
            logger.warning(f"Unknown tool name: {function_name}")
            return None

@@ -294,7 +295,7 @@ class ToolExecutor:
        if expired_keys:
            logger.debug(f"{self.log_prefix}Cleaned up {len(expired_keys)} expired cache entries")

-    async def execute_specific_tool(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
+    async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
        """Execute the specified tool directly

        Args:

@@ -314,7 +315,7 @@ class ToolExecutor:
        logger.info(f"{self.log_prefix}Executing tool directly: {tool_name}")

-        result = await self._execute_tool_call(tool_call)
+        result = await self.execute_tool_call(tool_call)

        if result:
            tool_info = {

@@ -405,7 +406,7 @@ results, used_tools, prompt = await executor.execute_from_chat_message(
)

# 5. Execute a specific tool directly
-result = await executor.execute_specific_tool(
+result = await executor.execute_specific_tool_simple(
    tool_name="get_knowledge",
    tool_args={"query": "machine learning"}
)

View File

@@ -1,131 +0,0 @@
import json  # for parsing stored embeddings
import math  # for cosine similarity
from typing import Any, Union, List

from src.chat.utils.utils import get_embedding
from src.common.database.database_model import Knowledges
from src.common.logger import get_logger
from src.plugin_system import BaseTool, ToolParamType

logger = get_logger("get_knowledge_tool")


class SearchKnowledgeTool(BaseTool):
    """Tool for searching the knowledge base for relevant information"""

    name = "search_knowledge"
    description = "Search the knowledge base for relevant information"
    parameters = [
        ("query", ToolParamType.STRING, "Search query keywords", True, None),
        ("threshold", ToolParamType.FLOAT, "Similarity threshold, between 0.0 and 1.0", False, None),
    ]

    async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
        """Run the knowledge base search

        Args:
            function_args: tool arguments

        Returns:
            dict: tool execution result
        """
        query = ""  # initialize so it is defined in the except block
        try:
            query = function_args.get("query")
            threshold = function_args.get("threshold", 0.4)

            # Search the knowledge base
            embedding = await get_embedding(query, request_type="info_retrieval")
            if embedding:
                knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
                if knowledge_info:
                    content = f"You know the following: {knowledge_info}"
                else:
                    content = f"You don't know much about {query}"
                return {"type": "knowledge", "id": query, "content": content}
            return {"type": "info", "id": query, "content": f"Could not get an embedding vector for '{query}'; the knowledge base is broken"}
        except Exception as e:
            logger.error(f"Knowledge base search tool failed: {str(e)}")
            return {"type": "info", "id": query, "content": f"Knowledge base search failed: {str(e)}"}

    @staticmethod
    def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
        """Compute the cosine similarity between two vectors"""
        dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False))
        magnitude1 = math.sqrt(sum(p * p for p in vec1))
        magnitude2 = math.sqrt(sum(q * q for q in vec2))
        if magnitude1 == 0 or magnitude2 == 0:
            return 0.0
        return dot_product / (magnitude1 * magnitude2)

    @staticmethod
    def get_info_from_db(
        query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
    ) -> Union[str, list]:
        """Fetch relevant information from the database

        Args:
            query_embedding: embedding vector of the query
            limit: maximum number of results
            threshold: similarity threshold
            return_raw: whether to return raw results

        Returns:
            Union[str, list]: a formatted info string, or the raw result list
        """
        if not query_embedding:
            return [] if return_raw else ""
        similar_items = []
        try:
            all_knowledges = Knowledges.select()
            for item in all_knowledges:
                try:
                    item_embedding_str = item.embedding
                    if not item_embedding_str:
                        logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
                        continue
                    item_embedding = json.loads(item_embedding_str)
                    if not isinstance(item_embedding, list) or not all(
                        isinstance(x, (int, float)) for x in item_embedding
                    ):
                        logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
                        continue
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
                    continue
                except AttributeError:
                    logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
                    continue

                similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
                if similarity >= threshold:
                    similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})

            # Sort by similarity, descending
            similar_items.sort(key=lambda x: x["similarity"], reverse=True)
            # Apply the limit
            results = similar_items[:limit]
            logger.debug(f"Number of matching results after knowledge base query: {len(results)}")
        except Exception as e:
            logger.error(f"Failed to fetch knowledge from the Peewee database: {str(e)}")
            return [] if return_raw else ""

        if not results:
            return [] if return_raw else ""

        if return_raw:
            # Peewee model instances are not directly JSON-serializable; callers that
            # need the raw models must handle them. Return dicts of content and similarity.
            return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
        else:
            # Return all found content, newline-separated
            return "\n".join(str(result["content"]) for result in results)


# Tool registration
# register_tool(SearchKnowledgeTool)
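
The deleted scan above computes one cosine similarity per stored row in pure Python. The LPMM path this commit standardizes on keeps vectors in a faiss index instead (see the import faiss hunk at the top of this diff). A rough sketch of the index-based equivalent, assuming L2-normalized float32 vectors so that inner product equals cosine similarity; this is not the actual LPMM implementation, which is outside this diff:

import faiss
import numpy as np

def build_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    mat = np.array(embeddings, dtype="float32")
    faiss.normalize_L2(mat)  # in-place normalization
    index = faiss.IndexFlatIP(mat.shape[1])
    index.add(mat)
    return index

def search(index: faiss.IndexFlatIP, query: np.ndarray, k: int = 3, threshold: float = 0.4):
    q = np.array(query, dtype="float32").reshape(1, -1)
    faiss.normalize_L2(q)
    scores, ids = index.search(q, k)  # returns (similarities, row ids)
    return [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if s >= threshold]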

View File

@@ -1,6 +1,7 @@
from typing import Dict, Any

from src.common.logger import get_logger
+from src.config.config import global_config
from src.chat.knowledge.knowledge_lib import qa_manager
from src.plugin_system import BaseTool, ToolParamType

@@ -16,6 +17,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
        ("query", ToolParamType.STRING, "Search query keywords", True, None),
        ("threshold", ToolParamType.FLOAT, "Similarity threshold, between 0.0 and 1.0", False, None),
    ]
+    available_for_llm = global_config.lpmm_knowledge.enable

    async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
        """Run the knowledge base search