QA: Refactor similarity calculation and improve state management logic

This commit is contained in:
晴猫
2025-05-01 06:07:59 +09:00
parent 2f669c7055
commit 3d001da30e
7 changed files with 91 additions and 86 deletions

View File

@@ -15,6 +15,26 @@ if TYPE_CHECKING:
logger = get_module_logger("pfc")
def _calculate_similarity(goal1: str, goal2: str) -> float:
    """Compute a simple similarity score between two goal strings.

    Each string is reduced to the set of its distinct characters; the score
    is the number of shared characters divided by the number of distinct
    characters across both strings. This is a deliberately simple
    implementation; a more sophisticated text-similarity algorithm could be
    substituted.

    Args:
        goal1: The first goal.
        goal2: The second goal.

    Returns:
        float: Similarity score in the range 0-1. Returns 0.0 when both
        strings are empty.
    """
    # Iterating a str yields characters, so these are character sets.
    chars1 = set(goal1)
    chars2 = set(goal2)
    total = len(chars1 | chars2)
    if total == 0:
        # Both inputs empty: return a float (not the int 0) so the result
        # type is the same on every path and matches the annotation.
        return 0.0
    return len(chars1 & chars2) / total
class GoalAnalyzer:
"""对话目标分析器"""
@@ -166,7 +186,7 @@ class GoalAnalyzer:
"""
# 检查新目标是否与现有目标相似
for i, (existing_goal, _, _) in enumerate(self.goals):
if self._calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
if _calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
# 更新现有目标
self.goals[i] = (new_goal, method, reasoning)
# 将此目标移到列表前面(最主要的位置)
@@ -180,25 +200,6 @@ class GoalAnalyzer:
if len(self.goals) > self.max_goals:
self.goals.pop() # 移除最老的目标
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
    """Compute a simple similarity score between two goal strings.

    Each string is reduced to the set of its distinct characters; the score
    is the number of shared characters divided by the number of distinct
    characters across both strings. This is a deliberately simple
    implementation; a more sophisticated text-similarity algorithm could be
    substituted. No instance state is read.

    Args:
        goal1: The first goal.
        goal2: The second goal.

    Returns:
        float: Similarity score in the range 0-1. Returns 0.0 when both
        strings are empty.
    """
    # Iterating a str yields characters, so these are character sets.
    chars1 = set(goal1)
    chars2 = set(goal2)
    total = len(chars1 | chars2)
    if total == 0:
        # Both inputs empty: return a float (not the int 0) so the result
        # type is the same on every path and matches the annotation.
        return 0.0
    return len(chars1 & chars2) / total
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
"""获取所有当前目标

View File

@@ -84,7 +84,7 @@ class ChatBot:
return
# 群聊黑名单拦截
if groupinfo != None and groupinfo.group_id not in global_config.talk_allowed_groups:
if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups:
logger.trace(f"{groupinfo.group_id}被禁止回复")
return

View File

@@ -1,5 +1,3 @@
from typing import List
from .llm_client import LLMMessage
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体并以JSON列表的形式输出。
@@ -13,7 +11,7 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统
"""
def build_entity_extract_context(paragraph: str) -> List[LLMMessage]:
def build_entity_extract_context(paragraph: str) -> list[LLMMessage]:
messages = [
LLMMessage("system", entity_extract_system_prompt).to_dict(),
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
@@ -38,7 +36,7 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF资源描
"""
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]:
messages = [
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
@@ -56,7 +54,7 @@ qa_system_prompt = """
"""
def build_qa_context(question: str, knowledge: list[(str, str, str)]) -> List[LLMMessage]:
def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
messages = [
LLMMessage("system", qa_system_prompt).to_dict(),

View File

@@ -65,6 +65,28 @@ error_code_mapping = {
}
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
    """Return ``payload`` with its embedded base64 image shortened for logging.

    When ``request_content`` carries a non-empty ``image_base64`` value and the
    first message of ``payload`` has list content whose second element holds an
    ``image_url`` mapping, that mapping's ``url`` is overwritten with a data URL
    containing only the first and last 10 characters of the image data. In
    every other case ``payload`` is returned untouched.

    This helper is called while reporting a failed request, so it never raises
    on an unexpectedly shaped payload; it simply leaves such a payload as is.

    Args:
        request_content: Request data; ``image_base64`` and ``image_format``
            are read from it.
        payload: Request body about to be logged. Modified in place when an
            image URL is shortened.

    Returns:
        The same ``payload`` object that was passed in.
    """
    image_base64 = request_content.get("image_base64")
    image_format = request_content.get("image_format")
    if not image_base64 or not isinstance(payload, dict):
        return payload

    messages = payload.get("messages")
    if not isinstance(messages, (list, tuple)) or not messages:
        return payload

    first_message = messages[0]
    if not isinstance(first_message, dict):
        return payload

    content = first_message.get("content")
    if not isinstance(content, list) or len(content) < 2:
        return payload

    # The second content part must be a mapping: on a plain string,
    # `"image_url" in part` would be a substring test and the item
    # assignment below would raise TypeError.
    image_part = content[1]
    if not isinstance(image_part, dict):
        return payload

    image_url = image_part.get("image_url")
    if not isinstance(image_url, dict):
        return payload

    fmt = image_format.lower() if image_format else "jpeg"
    image_url["url"] = f"data:image/{fmt};base64,{image_base64[:10]}...{image_base64[-10:]}"
    return payload
class LLMRequest:
# 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
@@ -551,7 +573,7 @@ class LLMRequest:
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
)
# 安全地检查和记录请求详情
handled_payload = await self._safely_record(request_content, payload)
handled_payload = await _safely_record(request_content, payload)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
raise RuntimeError(
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
@@ -565,31 +587,10 @@ class LLMRequest:
else:
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
# 安全地检查和记录请求详情
handled_payload = await self._safely_record(request_content, payload)
handled_payload = await _safely_record(request_content, payload)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
async def _safely_record(self, request_content: Dict[str, Any], payload: Dict[str, Any]):
    """Return ``payload`` with its embedded base64 image shortened for logging.

    When ``request_content`` carries a non-empty ``image_base64`` value and the
    first message of ``payload`` has list content whose second element holds an
    ``image_url`` mapping, that mapping's ``url`` is overwritten with a data URL
    containing only the first and last 10 characters of the image data. In
    every other case ``payload`` is returned untouched. No instance state is
    read.

    This helper is called while reporting a failed request, so it never raises
    on an unexpectedly shaped payload; it simply leaves such a payload as is.

    Args:
        request_content: Request data; ``image_base64`` and ``image_format``
            are read from it.
        payload: Request body about to be logged. Modified in place when an
            image URL is shortened.

    Returns:
        The same ``payload`` object that was passed in.
    """
    image_base64 = request_content.get("image_base64")
    image_format = request_content.get("image_format")
    if not image_base64 or not isinstance(payload, dict):
        return payload

    messages = payload.get("messages")
    if not isinstance(messages, (list, tuple)) or not messages:
        return payload

    first_message = messages[0]
    if not isinstance(first_message, dict):
        return payload

    content = first_message.get("content")
    if not isinstance(content, list) or len(content) < 2:
        return payload

    # The second content part must be a mapping: on a plain string,
    # `"image_url" in part` would be a substring test and the item
    # assignment below would raise TypeError.
    image_part = content[1]
    if not isinstance(image_part, dict):
        return payload

    image_url = image_part.get("image_url")
    if not isinstance(image_url, dict):
        return payload

    fmt = image_format.lower() if image_format else "jpeg"
    image_url["url"] = f"data:image/{fmt};base64,{image_base64[:10]}...{image_base64[-10:]}"
    return payload
async def _transform_parameters(self, params: dict) -> dict:
"""
根据模型名称转换参数: