QA: Refactor similarity calculation and improve state management logic

This commit is contained in:
晴猫
2025-05-01 06:07:59 +09:00
parent 2f669c7055
commit 3d001da30e
7 changed files with 91 additions and 86 deletions

View File

@@ -62,6 +62,7 @@ class MaiState(enum.Enum):
return MAX_NORMAL_CHAT_NUM_NORMAL
elif self == MaiState.FOCUSED_CHAT:
return MAX_NORMAL_CHAT_NUM_FOCUSED
return None
def get_focused_chat_max_num(self):
# 调试用
@@ -76,6 +77,7 @@ class MaiState(enum.Enum):
return MAX_FOCUSED_CHAT_NUM_NORMAL
elif self == MaiState.FOCUSED_CHAT:
return MAX_FOCUSED_CHAT_NUM_FOCUSED
return None
class MaiStateInfo:
@@ -135,7 +137,8 @@ class MaiStateManager:
def __init__(self):
pass
def check_and_decide_next_state(self, current_state_info: MaiStateInfo) -> Optional[MaiState]:
@staticmethod
def check_and_decide_next_state(current_state_info: MaiStateInfo) -> Optional[MaiState]:
"""
根据当前状态和规则检查是否需要转换状态,并决定下一个状态。

View File

@@ -78,7 +78,7 @@ def calculate_replacement_probability(similarity: float) -> float:
# p = 3.5 * s - 1.4
probability = 3.5 * similarity - 1.4
return max(0.0, probability)
elif 0.6 < similarity < 0.9:
else: # 0.6 < similarity < 0.9
# p = s + 0.1
probability = similarity + 0.1
return min(1.0, max(0.0, probability))
@@ -169,7 +169,7 @@ class SubMind:
last_cycle = history_cycle[-1] if history_cycle else None
# 上一次决策信息
if last_cycle != None:
if last_cycle is not None:
last_action = last_cycle.action_type
last_reasoning = last_cycle.reasoning
is_replan = last_cycle.replanned

View File

@@ -32,6 +32,40 @@ INACTIVE_THRESHOLD_SECONDS = 3600 # 子心流不活跃超时时间(秒)
NORMAL_CHAT_TIMEOUT_SECONDS = 30 * 60 # 30分钟
async def _try_set_subflow_absent_internal(subflow: "SubHeartflow", log_prefix: str) -> bool:
"""
尝试将给定的子心流对象状态设置为 ABSENT (内部方法,不处理锁)。
Args:
subflow: 子心流对象。
log_prefix: 用于日志记录的前缀 (例如 "[子心流管理]""[停用]")。
Returns:
bool: 如果状态成功变为 ABSENT 或原本就是 ABSENT返回 True否则返回 False。
"""
flow_id = subflow.subheartflow_id
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
if subflow.chat_state.chat_status != ChatState.ABSENT:
logger.debug(f"{log_prefix} 设置 {stream_name} 状态为 ABSENT")
try:
await subflow.change_chat_state(ChatState.ABSENT)
# 再次检查以确认状态已更改 (change_chat_state 内部应确保)
if subflow.chat_state.chat_status == ChatState.ABSENT:
return True
else:
logger.warning(
f"{log_prefix} 调用 change_chat_state 后,{stream_name} 状态仍为 {subflow.chat_state.chat_status.value}"
)
return False
except Exception as e:
logger.error(f"{log_prefix} 设置 {stream_name} 状态为 ABSENT 时失败: {e}", exc_info=True)
return False
else:
logger.debug(f"{log_prefix} {stream_name} 已是 ABSENT 状态")
return True # 已经是目标状态,视为成功
class SubHeartflowManager:
"""管理所有活跃的 SubHeartflow 实例。"""
@@ -109,38 +143,6 @@ class SubHeartflowManager:
return None
# --- 新增:内部方法,用于尝试将单个子心流设置为 ABSENT ---
async def _try_set_subflow_absent_internal(self, subflow: "SubHeartflow", log_prefix: str) -> bool:
"""
尝试将给定的子心流对象状态设置为 ABSENT (内部方法,不处理锁)。
Args:
subflow: 子心流对象。
log_prefix: 用于日志记录的前缀 (例如 "[子心流管理]""[停用]")。
Returns:
bool: 如果状态成功变为 ABSENT 或原本就是 ABSENT返回 True否则返回 False。
"""
flow_id = subflow.subheartflow_id
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
if subflow.chat_state.chat_status != ChatState.ABSENT:
logger.debug(f"{log_prefix} 设置 {stream_name} 状态为 ABSENT")
try:
await subflow.change_chat_state(ChatState.ABSENT)
# 再次检查以确认状态已更改 (change_chat_state 内部应确保)
if subflow.chat_state.chat_status == ChatState.ABSENT:
return True
else:
logger.warning(
f"{log_prefix} 调用 change_chat_state 后,{stream_name} 状态仍为 {subflow.chat_state.chat_status.value}"
)
return False
except Exception as e:
logger.error(f"{log_prefix} 设置 {stream_name} 状态为 ABSENT 时失败: {e}", exc_info=True)
return False
else:
logger.debug(f"{log_prefix} {stream_name} 已是 ABSENT 状态")
return True # 已经是目标状态,视为成功
# --- 结束新增 ---
@@ -154,7 +156,7 @@ class SubHeartflowManager:
logger.info(f"{log_prefix} 正在停止 {stream_name}, 原因: {reason}")
# 调用内部方法处理状态变更
success = await self._try_set_subflow_absent_internal(subheartflow, log_prefix)
success = await _try_set_subflow_absent_internal(subheartflow, log_prefix)
return success
# 锁在此处自动释放
@@ -241,7 +243,7 @@ class SubHeartflowManager:
# 记录原始状态,以便统计实际改变的数量
original_state_was_absent = subflow.chat_state.chat_status == ChatState.ABSENT
success = await self._try_set_subflow_absent_internal(subflow, log_prefix)
success = await _try_set_subflow_absent_internal(subflow, log_prefix)
# 如果成功设置为 ABSENT 且原始状态不是 ABSENT则计数
if success and not original_state_was_absent:

View File

@@ -15,6 +15,26 @@ if TYPE_CHECKING:
logger = get_module_logger("pfc")
def _calculate_similarity(goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
Args:
goal1: 第一个目标
goal2: 第二个目标
Returns:
float: 相似度得分 (0-1)
"""
# 简单实现:检查重叠字数比例
words1 = set(goal1)
words2 = set(goal2)
overlap = len(words1.intersection(words2))
total = len(words1.union(words2))
return overlap / total if total > 0 else 0
class GoalAnalyzer:
"""对话目标分析器"""
@@ -166,7 +186,7 @@ class GoalAnalyzer:
"""
# 检查新目标是否与现有目标相似
for i, (existing_goal, _, _) in enumerate(self.goals):
if self._calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
if _calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
# 更新现有目标
self.goals[i] = (new_goal, method, reasoning)
# 将此目标移到列表前面(最主要的位置)
@@ -180,25 +200,6 @@ class GoalAnalyzer:
if len(self.goals) > self.max_goals:
self.goals.pop() # 移除最老的目标
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
Args:
goal1: 第一个目标
goal2: 第二个目标
Returns:
float: 相似度得分 (0-1)
"""
# 简单实现:检查重叠字数比例
words1 = set(goal1)
words2 = set(goal2)
overlap = len(words1.intersection(words2))
total = len(words1.union(words2))
return overlap / total if total > 0 else 0
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
"""获取所有当前目标

View File

@@ -84,7 +84,7 @@ class ChatBot:
return
# 群聊黑名单拦截
if groupinfo != None and groupinfo.group_id not in global_config.talk_allowed_groups:
if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups:
logger.trace(f"{groupinfo.group_id}被禁止回复")
return

View File

@@ -1,5 +1,3 @@
from typing import List
from .llm_client import LLMMessage
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体并以JSON列表的形式输出。
@@ -13,7 +11,7 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统
"""
def build_entity_extract_context(paragraph: str) -> List[LLMMessage]:
def build_entity_extract_context(paragraph: str) -> list[LLMMessage]:
messages = [
LLMMessage("system", entity_extract_system_prompt).to_dict(),
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
@@ -38,7 +36,7 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF资源描
"""
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]:
messages = [
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
@@ -56,7 +54,7 @@ qa_system_prompt = """
"""
def build_qa_context(question: str, knowledge: list[(str, str, str)]) -> List[LLMMessage]:
def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
messages = [
LLMMessage("system", qa_system_prompt).to_dict(),

View File

@@ -65,6 +65,28 @@ error_code_mapping = {
}
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
image_base64: str = request_content.get("image_base64")
image_format: str = request_content.get("image_format")
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
# if isinstance(content, str) and len(content) > 100:
# payload["messages"][0]["content"] = content[:100]
return payload
class LLMRequest:
# 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
@@ -551,7 +573,7 @@ class LLMRequest:
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
)
# 安全地检查和记录请求详情
handled_payload = await self._safely_record(request_content, payload)
handled_payload = await _safely_record(request_content, payload)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
raise RuntimeError(
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
@@ -565,31 +587,10 @@ class LLMRequest:
else:
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
# 安全地检查和记录请求详情
handled_payload = await self._safely_record(request_content, payload)
handled_payload = await _safely_record(request_content, payload)
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {handled_payload}")
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
async def _safely_record(self, request_content: Dict[str, Any], payload: Dict[str, Any]):
image_base64: str = request_content.get("image_base64")
image_format: str = request_content.get("image_format")
if (
image_base64
and payload
and isinstance(payload, dict)
and "messages" in payload
and len(payload["messages"]) > 0
):
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
content = payload["messages"][0]["content"]
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
payload["messages"][0]["content"][1]["image_url"]["url"] = (
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
f"{image_base64[:10]}...{image_base64[-10:]}"
)
# if isinstance(content, str) and len(content) > 100:
# payload["messages"][0]["content"] = content[:100]
return payload
async def _transform_parameters(self, params: dict) -> dict:
"""
根据模型名称转换参数: