QA: Refactor similarity calculation and improve state management logic
This commit is contained in:
@@ -15,6 +15,26 @@ if TYPE_CHECKING:
|
||||
logger = get_module_logger("pfc")
|
||||
|
||||
|
||||
def _calculate_similarity(goal1: str, goal2: str) -> float:
|
||||
"""简单计算两个目标之间的相似度
|
||||
|
||||
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
||||
|
||||
Args:
|
||||
goal1: 第一个目标
|
||||
goal2: 第二个目标
|
||||
|
||||
Returns:
|
||||
float: 相似度得分 (0-1)
|
||||
"""
|
||||
# 简单实现:检查重叠字数比例
|
||||
words1 = set(goal1)
|
||||
words2 = set(goal2)
|
||||
overlap = len(words1.intersection(words2))
|
||||
total = len(words1.union(words2))
|
||||
return overlap / total if total > 0 else 0
|
||||
|
||||
|
||||
class GoalAnalyzer:
|
||||
"""对话目标分析器"""
|
||||
|
||||
@@ -166,7 +186,7 @@ class GoalAnalyzer:
|
||||
"""
|
||||
# 检查新目标是否与现有目标相似
|
||||
for i, (existing_goal, _, _) in enumerate(self.goals):
|
||||
if self._calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
|
||||
if _calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
|
||||
# 更新现有目标
|
||||
self.goals[i] = (new_goal, method, reasoning)
|
||||
# 将此目标移到列表前面(最主要的位置)
|
||||
@@ -180,25 +200,6 @@ class GoalAnalyzer:
|
||||
if len(self.goals) > self.max_goals:
|
||||
self.goals.pop() # 移除最老的目标
|
||||
|
||||
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
|
||||
"""简单计算两个目标之间的相似度
|
||||
|
||||
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
|
||||
|
||||
Args:
|
||||
goal1: 第一个目标
|
||||
goal2: 第二个目标
|
||||
|
||||
Returns:
|
||||
float: 相似度得分 (0-1)
|
||||
"""
|
||||
# 简单实现:检查重叠字数比例
|
||||
words1 = set(goal1)
|
||||
words2 = set(goal2)
|
||||
overlap = len(words1.intersection(words2))
|
||||
total = len(words1.union(words2))
|
||||
return overlap / total if total > 0 else 0
|
||||
|
||||
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
|
||||
"""获取所有当前目标
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ class ChatBot:
|
||||
return
|
||||
|
||||
# 群聊黑名单拦截
|
||||
if groupinfo != None and groupinfo.group_id not in global_config.talk_allowed_groups:
|
||||
if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups:
|
||||
logger.trace(f"群{groupinfo.group_id}被禁止回复")
|
||||
return
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import List
|
||||
|
||||
from .llm_client import LLMMessage
|
||||
|
||||
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体,并以JSON列表的形式输出。
|
||||
@@ -13,7 +11,7 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统
|
||||
"""
|
||||
|
||||
|
||||
def build_entity_extract_context(paragraph: str) -> List[LLMMessage]:
|
||||
def build_entity_extract_context(paragraph: str) -> list[LLMMessage]:
|
||||
messages = [
|
||||
LLMMessage("system", entity_extract_system_prompt).to_dict(),
|
||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
|
||||
@@ -38,7 +36,7 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描
|
||||
"""
|
||||
|
||||
|
||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
|
||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]:
|
||||
messages = [
|
||||
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
|
||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
|
||||
@@ -56,7 +54,7 @@ qa_system_prompt = """
|
||||
"""
|
||||
|
||||
|
||||
def build_qa_context(question: str, knowledge: list[(str, str, str)]) -> List[LLMMessage]:
|
||||
def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
|
||||
knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
|
||||
messages = [
|
||||
LLMMessage("system", qa_system_prompt).to_dict(),
|
||||
|
||||
@@ -65,6 +65,28 @@ error_code_mapping = {
|
||||
}
|
||||
|
||||
|
||||
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
|
||||
image_base64: str = request_content.get("image_base64")
|
||||
image_format: str = request_content.get("image_format")
|
||||
if (
|
||||
image_base64
|
||||
and payload
|
||||
and isinstance(payload, dict)
|
||||
and "messages" in payload
|
||||
and len(payload["messages"]) > 0
|
||||
):
|
||||
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
|
||||
content = payload["messages"][0]["content"]
|
||||
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
|
||||
payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||
)
|
||||
# if isinstance(content, str) and len(content) > 100:
|
||||
# payload["messages"][0]["content"] = content[:100]
|
||||
return payload
|
||||
|
||||
|
||||
class LLMRequest:
|
||||
# 定义需要转换的模型列表,作为类变量避免重复
|
||||
MODELS_NEEDING_TRANSFORMATION = [
|
||||
@@ -551,7 +573,7 @@ class LLMRequest:
|
||||
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
|
||||
)
|
||||
# 安全地检查和记录请求详情
|
||||
handled_payload = await self._safely_record(request_content, payload)
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
|
||||
raise RuntimeError(
|
||||
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
|
||||
@@ -565,31 +587,10 @@ class LLMRequest:
|
||||
else:
|
||||
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
|
||||
# 安全地检查和记录请求详情
|
||||
handled_payload = await self._safely_record(request_content, payload)
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
|
||||
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
|
||||
|
||||
async def _safely_record(self, request_content: Dict[str, Any], payload: Dict[str, Any]):
|
||||
image_base64: str = request_content.get("image_base64")
|
||||
image_format: str = request_content.get("image_format")
|
||||
if (
|
||||
image_base64
|
||||
and payload
|
||||
and isinstance(payload, dict)
|
||||
and "messages" in payload
|
||||
and len(payload["messages"]) > 0
|
||||
):
|
||||
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
|
||||
content = payload["messages"][0]["content"]
|
||||
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
|
||||
payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||
)
|
||||
# if isinstance(content, str) and len(content) > 100:
|
||||
# payload["messages"][0]["content"] = content[:100]
|
||||
return payload
|
||||
|
||||
async def _transform_parameters(self, params: dict) -> dict:
|
||||
"""
|
||||
根据模型名称转换参数:
|
||||
|
||||
Reference in New Issue
Block a user