SengokuCola
2025-07-13 12:59:10 +08:00
89 changed files with 1682 additions and 2588 deletions

View File

@@ -5,20 +5,19 @@ import os
import random
import time
import traceback
from typing import Optional, Tuple, List, Any
from PIL import Image
import io
import re
# from gradio_client import file
import binascii
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
@@ -26,7 +25,7 @@ logger = get_logger("emoji")
BASE_DIR = os.path.join("data")
EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
EMOJI_REGISTED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
"""
@@ -85,7 +84,7 @@ class MaiEmoji:
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
try:
with Image.open(io.BytesIO(image_bytes)) as img:
self.format = img.format.lower()
self.format = img.format.lower() # type: ignore
logger.debug(f"[初始化] 格式获取成功: {self.format}")
except Exception as pil_error:
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
@@ -100,7 +99,7 @@ class MaiEmoji:
logger.error(f"[初始化错误] 文件在处理过程中丢失: {self.full_path}")
self.is_deleted = True
return None
except base64.binascii.Error as b64_error:
except (binascii.Error, ValueError) as b64_error:
logger.error(f"[初始化错误] Base64解码失败 ({self.filename}): {b64_error}")
self.is_deleted = True
return None
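The hunks above switch format detection to Pillow (with a type-ignore, since img.format can be None) and broaden the base64 failure handling from base64.binascii.Error to an explicitly imported (binascii.Error, ValueError). A minimal standalone sketch of the same decode-and-probe flow; the helper name decode_and_probe is illustrative, not from the repo:

import base64
import binascii
import io

from PIL import Image


def decode_and_probe(image_base64: str):
    """Decode a base64 image string and return (bytes, lowercase format), or None on failure."""
    try:
        image_bytes = base64.b64decode(image_base64, validate=True)
    except (binascii.Error, ValueError) as err:  # b64decode failures surface as binascii.Error, a ValueError subclass
        print(f"Base64 decode failed: {err}")
        return None
    try:
        with Image.open(io.BytesIO(image_bytes)) as img:
            fmt = (img.format or "").lower()  # img.format may be None, hence the guard
    except Exception as pil_err:
        print(f"Pillow could not identify the image: {pil_err}")
        return None
    return image_bytes, fmt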
@@ -113,7 +112,7 @@ class MaiEmoji:
async def register_to_db(self) -> bool:
"""
注册表情包
将表情包对应的文件从当前路径移动到EMOJI_REGISTED_DIR目录下
将表情包对应的文件从当前路径移动到EMOJI_REGISTERED_DIR目录下
并修改对应的实例属性,然后将表情包信息保存到数据库中
"""
try:
@@ -122,7 +121,7 @@ class MaiEmoji:
# 源路径是当前实例的完整路径 self.full_path
source_full_path = self.full_path
# 目标完整路径
destination_full_path = os.path.join(EMOJI_REGISTED_DIR, self.filename)
destination_full_path = os.path.join(EMOJI_REGISTERED_DIR, self.filename)
# 检查源文件是否存在
if not os.path.exists(source_full_path):
@@ -139,7 +138,7 @@ class MaiEmoji:
logger.debug(f"[移动] 文件从 {source_full_path} 移动到 {destination_full_path}")
# 更新实例的路径属性为新路径
self.full_path = destination_full_path
self.path = EMOJI_REGISTED_DIR
self.path = EMOJI_REGISTERED_DIR
# self.filename 保持不变
except Exception as move_error:
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
@@ -202,7 +201,7 @@ class MaiEmoji:
try:
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist:
except Emoji.DoesNotExist: # type: ignore
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
result = 0 # Indicate no DB record was deleted
@@ -298,7 +297,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
def _ensure_emoji_dir() -> None:
"""确保表情存储目录存在"""
os.makedirs(EMOJI_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTED_DIR, exist_ok=True)
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
async def clear_temp_emoji() -> None:
@@ -331,10 +330,10 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
return removed_count
cleaned_count = 0
try:
# 获取内存中所有有效表情包的完整路径集合
tracked_full_paths = {emoji.full_path for emoji in emoji_objects if not emoji.is_deleted}
cleaned_count = 0
# 遍历指定目录中的所有文件
for file_name in os.listdir(emoji_dir):
@@ -358,11 +357,11 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
else:
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
return removed_count + cleaned_count
except Exception as e:
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
return removed_count + cleaned_count
class EmojiManager:
_instance = None
@@ -414,7 +413,7 @@ class EmojiManager:
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
emoji_update.save() # Persist changes to DB
except Emoji.DoesNotExist:
except Emoji.DoesNotExist: # type: ignore
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -570,8 +569,8 @@ class EmojiManager:
if objects_to_remove:
self.emoji_objects = [e for e in self.emoji_objects if e not in objects_to_remove]
# 清理 EMOJI_REGISTED_DIR 目录中未被追踪的文件
removed_count = await clean_unused_emojis(EMOJI_REGISTED_DIR, self.emoji_objects, removed_count)
# 清理 EMOJI_REGISTERED_DIR 目录中未被追踪的文件
removed_count = await clean_unused_emojis(EMOJI_REGISTERED_DIR, self.emoji_objects, removed_count)
# 输出清理结果
if removed_count > 0:
@@ -850,11 +849,13 @@ class EmojiManager:
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 调用AI获取描述
if image_format == "gif" or image_format == "GIF":
image_base64 = get_image_manager().transform_gif(image_base64)
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
if not image_base64:
raise RuntimeError("GIF表情包转换失败")
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
else:

View File

@@ -1,14 +1,16 @@
import time
import random
import json
import os
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import os
from src.chat.message_receive.chat_stream import get_chat_manager
import json
MAX_EXPRESSION_COUNT = 300
@@ -74,7 +76,8 @@ class ExpressionLearner:
)
self.llm_model = None
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
# sourcery skip: extract-duplicate-method, remove-unnecessary-cast
"""
获取指定chat_id的style和grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
@@ -119,10 +122,10 @@ class ExpressionLearner:
min_len = min(len(s1), len(s2))
if min_len < 5:
return False
same = sum(1 for a, b in zip(s1, s2, strict=False) if a == b)
same = sum(a == b for a, b in zip(s1, s2, strict=False))
return same / min_len > 0.8
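The rewritten check above counts positionally matching characters and treats two expressions as duplicates when more than 80% of the shorter string matches. A small self-contained sketch of the same idea (the function name is_similar is illustrative, not from the repo):

def is_similar(s1: str, s2: str, threshold: float = 0.8, min_length: int = 5) -> bool:
    """True when the two strings agree on more than `threshold` of the shorter length."""
    shortest = min(len(s1), len(s2))
    if shortest < min_length:
        return False
    same = sum(a == b for a, b in zip(s1, s2))  # positional character matches
    return same / shortest > threshold


# is_similar("今天天气真不错啊", "今天天气真不错哦") -> True (7 of 8 positions match)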
async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]:
async def learn_and_store_expression(self) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, str]]]:
"""
学习并存储表达方式,分别学习语言风格和句法特点
同时对所有已存储的表达方式进行全局衰减
@@ -158,12 +161,12 @@ class ExpressionLearner:
for _ in range(3):
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
if not learnt_style:
return []
return [], []
for _ in range(1):
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
if not learnt_grammar:
return []
return [], []
return learnt_style, learnt_grammar
@@ -214,6 +217,7 @@ class ExpressionLearner:
return result
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
# sourcery skip: use-join
"""
选择从当前到最近1小时内的随机num条消息然后学习这些消息的表达方式
type: "style" or "grammar"
@@ -249,7 +253,7 @@ class ExpressionLearner:
return []
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, str]]] = {}
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for chat_id, situation, style in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []

View File

@@ -1,14 +1,16 @@
from .expression_learner import get_expression_learner
import random
from typing import List, Dict, Tuple
from json_repair import repair_json
import json
import os
import time
import random
from typing import List, Dict, Tuple, Optional
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .expression_learner import get_expression_learner
logger = get_logger("expression_selector")
@@ -82,6 +84,7 @@ class ExpressionSelector:
def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
# sourcery skip: extract-duplicate-method, move-assign
(
learnt_style_expressions,
learnt_grammar_expressions,
@@ -165,8 +168,14 @@ class ExpressionSelector:
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
async def select_suitable_expressions_llm(
self, chat_id: str, chat_info: str, max_num: int = 10, min_num: int = 5, target_message: str = None
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
min_num: int = 5,
target_message: Optional[str] = None,
) -> List[Dict[str, str]]:
# sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式"""
# 1. 获取35个随机表达方式现在按权重抽取

View File

@@ -1,23 +1,25 @@
import asyncio
import time
import traceback
from typing import Optional, List
from src.chat.message_receive.chat_stream import get_chat_manager
import random
from typing import List, Optional, Dict, Any
from rich.traceback import install
from src.chat.utils.prompt_builder import global_prompt_manager
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
from src.chat.planner_actions.planner import ActionPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.config.config import global_config
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.chat.focus_chat.hfc_utils import CycleDetail
import random
from src.chat.focus_chat.hfc_utils import get_recent_message_stats
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import generator_api,send_api,message_api
from src.plugin_system.base.component_types import ActionInfo, ChatMode
from src.plugin_system.apis import generator_api, send_api, message_api
from src.chat.willing.willing_manager import get_willing_manager
from ...mais4u.mais4u_chat.priority_manager import PriorityManager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
@@ -79,13 +81,15 @@ class HeartFChatting:
"""
# 基础属性
self.stream_id: str = chat_id # 聊天流ID
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
if not self.chat_stream:
raise ValueError(f"无法找到聊天流: {self.stream_id}")
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
self.loop_mode = "normal"
# 新增:消息计数器和疲惫阈值
self._message_count = 0 # 发送的消息计数
# 基于exit_focus_threshold动态计算疲惫阈值
@@ -93,7 +97,6 @@ class HeartFChatting:
self._message_threshold = max(10, int(30 * global_config.chat.exit_focus_threshold))
self._fatigue_triggered = False # 是否已触发疲惫退出
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
@@ -105,18 +108,16 @@ class HeartFChatting:
# 添加循环信息管理相关的属性
self.history_loop: List[CycleDetail] = []
self._cycle_counter = 0
self._current_cycle_detail: Optional[CycleDetail] = None
self._current_cycle_detail: CycleDetail = None # type: ignore
self.reply_timeout_count = 0
self.plan_timeout_count = 0
self.last_read_time = time.time()-1
self.last_read_time = time.time() - 1
self.willing_amplifier = 1
self.willing_manager = get_willing_manager()
self.reply_mode = self.chat_stream.context.get_priority_mode()
if self.reply_mode == "priority":
self.priority_manager = PriorityManager(
@@ -125,13 +126,11 @@ class HeartFChatting:
self.loop_mode = "priority"
else:
self.priority_manager = None
logger.info(
f"{self.log_prefix} HeartFChatting 初始化完成,消息疲惫阈值: {self._message_threshold}基于exit_focus_threshold={global_config.chat.exit_focus_threshold}计算仅在auto模式下生效"
)
self.energy_value = 100
async def start(self):
@@ -160,100 +159,97 @@ class HeartFChatting:
def _handle_loop_completion(self, task: asyncio.Task):
"""当 _hfc_loop 任务完成时执行的回调。"""
try:
exception = task.exception()
if exception:
if exception := task.exception():
logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
logger.error(traceback.format_exc()) # Log full traceback for exceptions
else:
logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
def start_cycle(self):
self._cycle_counter += 1
self._current_cycle_detail = CycleDetail(self._cycle_counter)
self._current_cycle_detail.thinking_id = "tid" + str(round(time.time(), 2))
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
cycle_timers = {}
return cycle_timers, self._current_cycle_detail.thinking_id
def end_cycle(self,loop_info,cycle_timers):
def end_cycle(self, loop_info, cycle_timers):
self._current_cycle_detail.set_loop_info(loop_info)
self.history_loop.append(self._current_cycle_detail)
self._current_cycle_detail.timers = cycle_timers
self._current_cycle_detail.end_time = time.time()
def print_cycle_info(self,cycle_timers):
# 记录循环信息和计时器结果
timer_strings = []
for name, elapsed in cycle_timers.items():
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}"
timer_strings.append(f"{name}: {formatted_time}")
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, "
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
def print_cycle_info(self, cycle_timers):
# 记录循环信息和计时器结果
timer_strings = []
for name, elapsed in cycle_timers.items():
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}"
timer_strings.append(f"{name}: {formatted_time}")
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
f"选择动作: {self._current_cycle_detail.loop_plan_info.get('action_result', {}).get('action_type', '未知动作')}"
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
async def _loopbody(self):
if self.loop_mode == "focus":
self.energy_value -= 5 * (1/global_config.chat.exit_focus_threshold)
self.energy_value -= 5 * (1 / global_config.chat.exit_focus_threshold)
if self.energy_value <= 0:
self.loop_mode = "normal"
return True
return await self._observe()
elif self.loop_mode == "normal":
new_messages_data = get_raw_msg_by_timestamp_with_chat(
chat_id=self.stream_id, timestamp_start=self.last_read_time, timestamp_end=time.time(),limit=10,limit_mode="earliest",fliter_bot=True
chat_id=self.stream_id,
timestamp_start=self.last_read_time,
timestamp_end=time.time(),
limit=10,
limit_mode="earliest",
filter_bot=True,
)
if len(new_messages_data) > 4 * global_config.chat.auto_focus_threshold:
self.loop_mode = "focus"
self.energy_value = 100
return True
if new_messages_data:
earliest_messages_data = new_messages_data[0]
self.last_read_time = earliest_messages_data.get("time")
await self.normal_response(earliest_messages_data)
return True
await asyncio.sleep(1)
return True
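The loop body above switches between "normal" and "focus" with a simple energy budget: focus mode drains 5 * (1 / exit_focus_threshold) per iteration and falls back to normal at zero, while a burst of more than 4 * auto_focus_threshold unread messages promotes the chat back to focus with full energy. A stripped-down sketch of that state machine, with the thresholds passed in explicitly since the real values live in global_config:

class ModeSwitcher:
    """Toy version of the focus/normal switching in _loopbody (no message handling)."""

    def __init__(self, exit_focus_threshold: float, auto_focus_threshold: float):
        self.exit_focus_threshold = exit_focus_threshold
        self.auto_focus_threshold = auto_focus_threshold
        self.loop_mode = "normal"
        self.energy_value = 100.0

    def tick(self, new_message_count: int) -> str:
        if self.loop_mode == "focus":
            self.energy_value -= 5 * (1 / self.exit_focus_threshold)  # drain per loop
            if self.energy_value <= 0:
                self.loop_mode = "normal"  # fatigue: leave focus mode
        elif new_message_count > 4 * self.auto_focus_threshold:
            self.loop_mode = "focus"      # message burst: enter focus mode
            self.energy_value = 100.0
        return self.loop_mode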
async def build_reply_to_str(self,message_data:dict):
async def build_reply_to_str(self, message_data: dict):
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id(
message_data.get("chat_info_platform"), message_data.get("user_id")
message_data.get("chat_info_platform"), # type: ignore
message_data.get("user_id"), # type: ignore
)
person_name = await person_info_manager.get_value(person_id, "person_name")
reply_to_str = f"{person_name}:{message_data.get('processed_plain_text')}"
return reply_to_str
return f"{person_name}:{message_data.get('processed_plain_text')}"
async def _observe(self,message_data:dict = None):
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
if not message_data:
message_data = {}
# 创建新的循环信息
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
async with global_prompt_manager.async_message_scope(
self.chat_stream.context.get_template_name()
):
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
loop_start_time = time.time()
# await self.loop_info.observe()
await self.relationship_builder.build_relation()
# 第一步:动作修改
with Timer("动作修改", cycle_timers):
try:
@@ -261,18 +257,15 @@ class HeartFChatting:
available_actions = self.action_manager.get_using_actions()
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
#如果normal开始一个回复生成进程先准备好回复其实是和planer同时进行的
# 如果normal开始一个回复生成进程先准备好回复其实是和planer同时进行的
if self.loop_mode == "normal":
reply_to_str = await self.build_reply_to_str(message_data)
gen_task = asyncio.create_task(self._generate_response(message_data, available_actions,reply_to_str))
gen_task = asyncio.create_task(self._generate_response(message_data, available_actions, reply_to_str))
with Timer("规划器", cycle_timers):
plan_result = await self.action_planner.plan(mode=self.loop_mode)
action_result = plan_result.get("action_result", {})
action_type, action_data, reasoning, is_parallel = (
action_result.get("action_type", "error"),
@@ -282,7 +275,7 @@ class HeartFChatting:
)
action_data["loop_start_time"] = loop_start_time
if self.loop_mode == "normal":
if action_type == "no_action":
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定进行回复")
@@ -293,8 +286,6 @@ class HeartFChatting:
else:
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定执行{action_type}动作")
if action_type == "no_action":
# 等待回复生成完毕
gather_timeout = global_config.chat.thinking_timeout
@@ -307,27 +298,22 @@ class HeartFChatting:
content = " ".join([item[1] for item in response_set if item[0] == "text"])
# 模型炸了,没有回复内容生成
if not response_set or (
action_type not in ["no_action"] and not is_parallel
):
if not response_set:
logger.warning(f"[{self.log_prefix}] 模型未生成回复内容")
elif action_type not in ["no_action"] and not is_parallel:
logger.info(
f"[{self.log_prefix}] {global_config.bot.nickname} 原本想要回复:{content},但选择执行{action_type},不发表回复"
)
if not response_set:
logger.warning(f"[{self.log_prefix}] 模型未生成回复内容")
return False
elif action_type not in ["no_action"] and not is_parallel:
logger.info(
f"[{self.log_prefix}] {global_config.bot.nickname} 原本想要回复:{content},但选择执行{action_type},不发表回复"
)
return False
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}")
# 发送回复 (不再需要传入 chat)
await self._send_response(response_set, reply_to_str, loop_start_time)
return True
else:
# 动作执行计时
with Timer("动作执行", cycle_timers):
@@ -350,18 +336,16 @@ class HeartFChatting:
if loop_info["loop_action_info"]["command"] == "stop_focus_chat":
logger.info(f"{self.log_prefix} 麦麦决定停止专注聊天")
return False
#停止该聊天模式的循环
# 停止该聊天模式的循环
self.end_cycle(loop_info,cycle_timers)
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
if self.loop_mode == "normal":
await self.willing_manager.after_generate_reply_handle(message_data.get("message_id"))
await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
return True
async def _main_chat_loop(self):
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
try:
@@ -370,7 +354,7 @@ class HeartFChatting:
await asyncio.sleep(0.1)
if not success:
break
logger.info(f"{self.log_prefix} 麦麦已强制离开聊天")
except asyncio.CancelledError:
# 设置了关闭标志位后被取消是正常流程
@@ -425,12 +409,9 @@ class HeartFChatting:
# 处理动作并获取结果
result = await action_handler.handle_action()
if len(result) == 3:
success, reply_text, command = result
else:
success, reply_text = result
command = ""
success, reply_text = result
command = ""
if reply_text == "timeout":
self.reply_timeout_count += 1
if self.reply_timeout_count > 5:
@@ -446,32 +427,40 @@ class HeartFChatting:
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
traceback.print_exc()
return False, "", ""
async def shutdown(self):
"""优雅关闭HeartFChatting实例取消活动循环任务"""
logger.info(f"{self.log_prefix} 正在关闭HeartFChatting...")
self.running = False # <-- 在开始关闭时设置标志位
# 取消循环任务
if self._loop_task and not self._loop_task.done():
logger.info(f"{self.log_prefix} 正在取消HeartFChatting循环任务")
self._loop_task.cancel()
try:
await asyncio.wait_for(self._loop_task, timeout=1.0)
logger.info(f"{self.log_prefix} HeartFChatting循环任务已取消")
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
except Exception as e:
logger.error(f"{self.log_prefix} 取消循环任务出错: {e}")
else:
logger.info(f"{self.log_prefix} 没有活动的HeartFChatting循环任务")
# async def shutdown(self):
# """优雅关闭HeartFChatting实例取消活动循环任务"""
# logger.info(f"{self.log_prefix} 正在关闭HeartFChatting...")
# self.running = False # <-- 在开始关闭时设置标志位
# 清理状态
self.running = False
self._loop_task = None
logger.info(f"{self.log_prefix} HeartFChatting关闭完成")
# # 记录最终的消息统计
# if self._message_count > 0:
# logger.info(f"{self.log_prefix} 本次focus会话共发送了 {self._message_count} 条消息")
# if self._fatigue_triggered:
# logger.info(f"{self.log_prefix} 因疲惫而退出focus模式")
# # 取消循环任务
# if self._loop_task and not self._loop_task.done():
# logger.info(f"{self.log_prefix} 正在取消HeartFChatting循环任务")
# self._loop_task.cancel()
# try:
# await asyncio.wait_for(self._loop_task, timeout=1.0)
# logger.info(f"{self.log_prefix} HeartFChatting循环任务已取消")
# except (asyncio.CancelledError, asyncio.TimeoutError):
# pass
# except Exception as e:
# logger.error(f"{self.log_prefix} 取消循环任务出错: {e}")
# else:
# logger.info(f"{self.log_prefix} 没有活动的HeartFChatting循环任务")
# # 清理状态
# self.running = False
# self._loop_task = None
# # 重置消息计数器,为下次启动做准备
# self.reset_message_count()
# logger.info(f"{self.log_prefix} HeartFChatting关闭完成")
def adjust_reply_frequency(self):
"""
@@ -542,18 +531,16 @@ class HeartFChatting:
f"[{self.log_prefix}] 调整回复意愿。10分钟内回复: {bot_reply_count_10_min} (目标: {target_replies_in_window:.0f}) -> "
f"意愿放大器更新为: {self.willing_amplifier:.2f}"
)
async def normal_response(self, message_data: dict) -> None:
"""
处理接收到的消息。
"兴趣"模式下,判断是否回复并生成内容。
"""
is_mentioned = message_data.get("is_mentioned", False)
interested_rate = message_data.get("interest_rate", 0.0) * self.willing_amplifier
reply_probability = (
1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0
) # 如果被提及且开启了提及必回复则基础概率为1否则需要意愿判断
@@ -565,7 +552,7 @@ class HeartFChatting:
# 仅在未被提及或基础概率不为1时查询意愿概率
if reply_probability < 1: # 简化逻辑,如果未提及 (reply_probability 为 0),则获取意愿概率
# is_willing = True
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id"))
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
additional_config = message_data.get("additional_config", {})
if additional_config and "maimcore_reply_probability_gain" in additional_config:
@@ -576,7 +563,6 @@ class HeartFChatting:
if message_data.get("is_emoji") or message_data.get("is_picid"):
reply_probability = 0
# 打印消息信息
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
if reply_probability > 0.1:
@@ -587,21 +573,18 @@ class HeartFChatting:
)
if random.random() < reply_probability:
await self.willing_manager.before_generate_reply_handle(message_data.get("message_id"))
await self._observe(message_data = message_data)
await self.willing_manager.before_generate_reply_handle(message_data.get("message_id", ""))
await self._observe(message_data=message_data)
# 意愿管理器注销当前message信息 (无论是否回复,只要处理过就删除)
self.willing_manager.delete(message_data.get("message_id"))
return True
self.willing_manager.delete(message_data.get("message_id", ""))
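normal_response above boils down to a probabilistic gate: a mention with mentioned_bot_inevitable_reply gives probability 1.0, otherwise the willing manager supplies one, an optional maimcore_reply_probability_gain from additional_config adjusts it (applied additively here as an assumption, since that line is outside the hunk), emoji/picid messages are forced to 0, and a single random draw decides whether _observe runs. A compact sketch of that decision with config values passed in as arguments; the clamp to [0, 1] is added here for safety:

import random


def should_reply(
    is_mentioned: bool,
    mentioned_inevitable: bool,
    willing_probability: float,
    probability_gain: float = 0.0,
    is_emoji_or_picid: bool = False,
) -> bool:
    """Sketch of the reply gate in normal_response."""
    probability = 1.0 if (is_mentioned and mentioned_inevitable) else willing_probability
    probability += probability_gain              # maimcore_reply_probability_gain (assumed additive)
    probability = max(0.0, min(probability, 1.0))
    if is_emoji_or_picid:
        probability = 0.0                        # never reply to bare emoji / picid messages
    return random.random() < probability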
async def _generate_response(
self, message_data: dict, available_actions: Optional[list],reply_to:str
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
) -> Optional[list]:
"""生成普通回复"""
try:
success, reply_set = await generator_api.generate_reply(
success, reply_set, _ = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_to=reply_to,
available_actions=available_actions,
@@ -618,37 +601,33 @@ class HeartFChatting:
except Exception as e:
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
return None
async def _send_response(
self, reply_set, reply_to, thinking_start_time
):
async def _send_response(self, reply_set, reply_to, thinking_start_time):
current_time = time.time()
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
)
need_reply = new_message_count >= random.randint(2, 4)
logger.info(
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复"
)
reply_text = ""
first_replyed = False
for reply_seg in reply_set:
data = reply_seg[1]
if not first_replyed:
if need_reply:
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, typing=False)
first_replyed = True
await send_api.text_to_stream(
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, typing=False
)
else:
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=False)
first_replyed = True
first_replyed = True
else:
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=True)
reply_text += data
return reply_text

View File

@@ -1,11 +1,10 @@
import time
from typing import Optional
from src.common.logger import get_logger
from typing import Dict, Any
from typing import Optional, Dict, Any
from src.config.config import global_config
from src.common.message_repository import count_messages
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -82,11 +81,10 @@ class CycleDetail:
self.loop_action_info = loop_info["loop_action_info"]
def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None) -> dict:
"""
Args:
minutes (int): 检索的分钟数默认30分钟
minutes (float): 检索的分钟数默认30分钟
chat_id (str, optional): 指定的chat_id仅统计该chat下的消息。为None时统计全部。
Returns:
dict: {"bot_reply_count": int, "total_message_count": int}
@@ -96,7 +94,7 @@ def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
start_time = now - minutes * 60
bot_id = global_config.bot.qq_account
filter_base = {"time": {"$gte": start_time}}
filter_base: Dict[str, Any] = {"time": {"$gte": start_time}}
if chat_id is not None:
filter_base["chat_id"] = chat_id
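The annotated filter_base above is a Mongo-style dict handed to count_messages from src.common.message_repository. A hedged sketch of how the stats function assembles its two documented counts; the count_messages signature (one filter dict returning an int) and the user_id key for bot messages are assumptions, so the counter is injected rather than imported:

import time
from typing import Any, Dict, Optional


def recent_message_stats(
    count_messages,                      # injected counter, assumed to take one Mongo-style filter dict
    bot_id: str,
    minutes: float = 30,
    chat_id: Optional[str] = None,
) -> Dict[str, int]:
    """Return {"bot_reply_count": ..., "total_message_count": ...} for the last `minutes`."""
    start_time = time.time() - minutes * 60
    filter_base: Dict[str, Any] = {"time": {"$gte": start_time}}
    if chat_id is not None:
        filter_base["chat_id"] = chat_id
    bot_filter = {**filter_base, "user_id": bot_id}   # assumption: bot replies are keyed by user_id
    return {
        "bot_reply_count": count_messages(bot_filter),
        "total_message_count": count_messages(filter_base),
    }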

View File

@@ -1,8 +1,8 @@
import traceback
from src.chat.heart_flow.sub_heartflow import SubHeartflow
from typing import Any, Optional, Dict
from src.common.logger import get_logger
from typing import Any, Optional
from typing import Dict
from src.chat.heart_flow.sub_heartflow import SubHeartflow
from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("heartflow")
@@ -17,14 +17,11 @@ class Heartflow:
async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]:
"""获取或创建一个新的SubHeartflow实例"""
if subheartflow_id in self.subheartflows:
subflow = self.subheartflows.get(subheartflow_id)
if subflow:
if subflow := self.subheartflows.get(subheartflow_id):
return subflow
try:
new_subflow = SubHeartflow(
subheartflow_id,
)
new_subflow = SubHeartflow(subheartflow_id)
await new_subflow.initialize()

View File

@@ -1,21 +1,23 @@
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.config.config import global_config
import asyncio
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer
from src.common.logger import get_logger
import re
import math
import traceback
from typing import Tuple
from typing import Tuple, TYPE_CHECKING
from src.config.config import global_config
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer
from src.common.logger import get_logger
from src.person_info.relationship_manager import get_relationship_manager
from src.mood.mood_manager import mood_manager
if TYPE_CHECKING:
from src.chat.heart_flow.sub_heartflow import SubHeartflow
logger = get_logger("chat")
@@ -27,16 +29,16 @@ async def _process_relationship(message: MessageRecv) -> None:
message: 消息对象,包含用户信息
"""
platform = message.message_info.platform
user_id = message.message_info.user_info.user_id
nickname = message.message_info.user_info.user_nickname
cardname = message.message_info.user_info.user_cardname or nickname
user_id = message.message_info.user_info.user_id # type: ignore
nickname = message.message_info.user_info.user_nickname # type: ignore
cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
relationship_manager = get_relationship_manager()
is_known = await relationship_manager.is_known_some_one(platform, user_id)
if not is_known:
logger.info(f"首次认识用户: {nickname}")
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname)
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
@@ -96,31 +98,24 @@ class HeartFCMessageReceiver:
"""
try:
# 1. 消息解析与初始化
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
messageinfo = message.message_info
chat = await get_chat_manager().get_or_create_stream(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo,
)
chat = message.chat_stream
# 2. 兴趣度计算与更新
interested_rate, is_mentioned = await _calculate_interest(message)
message.interest_value = interested_rate
message.is_mentioned = is_mentioned
await self.storage.store_message(message, chat)
subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id)
message.update_chat_stream(chat)
subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
# 7. 日志记录
# 3. 日志记录
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id)
@@ -129,11 +124,11 @@ class HeartFCMessageReceiver:
picid_pattern = r"\[picid:([^\]]+)\]"
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
# 8. 关系处理
# 4. 关系处理
if global_config.relationship.enable_relationship:
await _process_relationship(message)

View File

@@ -1,13 +1,9 @@
import asyncio
import time
from typing import Optional
import traceback
from rich.traceback import install
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.focus_chat.heartFC_chat import HeartFChatting
from src.chat.utils.utils import get_chat_type_and_target_info
from src.config.config import global_config
from rich.traceback import install
logger = get_logger("sub_heartflow")
@@ -28,108 +24,105 @@ class SubHeartflow:
self.subheartflow_id = subheartflow_id
self.chat_id = subheartflow_id
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
# focus模式退出冷却时间管理
self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间
# 随便水群 normal_chat 和 认真水群 focus_chat 实例
# CHAT模式激活 随便水群 FOCUS模式激活 认真水群
self.heart_fc_instance: Optional[HeartFChatting] = HeartFChatting(
chat_id=self.subheartflow_id,
) # 该sub_heartflow的HeartFChatting实例
self.heart_fc_instance: HeartFChatting = HeartFChatting(
chat_id=self.subheartflow_id,
) # 该sub_heartflow的HeartFChatting实例
async def initialize(self):
"""异步初始化方法,创建兴趣流并确定聊天类型"""
await self.heart_fc_instance.start()
async def _stop_heart_fc_chat(self):
"""停止并清理 HeartFChatting 实例"""
if self.heart_fc_instance.running:
logger.info(f"{self.log_prefix} 结束专注聊天...")
try:
await self.heart_fc_instance.shutdown()
except Exception as e:
logger.error(f"{self.log_prefix} 关闭 HeartFChatting 实例时出错: {e}")
logger.error(traceback.format_exc())
else:
logger.info(f"{self.log_prefix} 没有专注聊天实例,无需停止专注聊天")
# async def _stop_heart_fc_chat(self):
# """停止并清理 HeartFChatting 实例"""
# if self.heart_fc_instance.running:
# logger.info(f"{self.log_prefix} 结束专注聊天...")
# try:
# await self.heart_fc_instance.shutdown()
# except Exception as e:
# logger.error(f"{self.log_prefix} 关闭 HeartFChatting 实例时出错: {e}")
# logger.error(traceback.format_exc())
# else:
# logger.info(f"{self.log_prefix} 没有专注聊天实例,无需停止专注聊天")
async def _start_heart_fc_chat(self) -> bool:
"""启动 HeartFChatting 实例,确保 NormalChat 已停止"""
try:
# 如果任务已完成或不存在,则尝试重新启动
if self.heart_fc_instance._loop_task is None or self.heart_fc_instance._loop_task.done():
logger.info(f"{self.log_prefix} HeartFChatting 实例存在但循环未运行,尝试启动...")
try:
# 添加超时保护
await asyncio.wait_for(self.heart_fc_instance.start(), timeout=15.0)
logger.info(f"{self.log_prefix} HeartFChatting 循环已启动。")
return True
except Exception as e:
logger.error(f"{self.log_prefix} 尝试启动现有 HeartFChatting 循环时出错: {e}")
logger.error(traceback.format_exc())
# 出错时清理实例,准备重新创建
self.heart_fc_instance = None
else:
# 任务正在运行
logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中。")
return True #在运行
# async def _start_heart_fc_chat(self) -> bool:
# """启动 HeartFChatting 实例,确保 NormalChat 已停止"""
# try:
# # 如果任务已完成或不存在,则尝试重新启动
# if self.heart_fc_instance._loop_task is None or self.heart_fc_instance._loop_task.done():
# logger.info(f"{self.log_prefix} HeartFChatting 实例存在但循环未运行,尝试启动...")
# try:
# # 添加超时保护
# await asyncio.wait_for(self.heart_fc_instance.start(), timeout=15.0)
# logger.info(f"{self.log_prefix} HeartFChatting 循环已启动。")
# return True
# except Exception as e:
# logger.error(f"{self.log_prefix} 尝试启动现有 HeartFChatting 循环时出错: {e}")
# logger.error(traceback.format_exc())
# # 出错时清理实例,准备重新创建
# self.heart_fc_instance = None # type: ignore
# return False
# else:
# # 任务正在运行
# logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中。")
# return True # 已经在运行
except Exception as e:
logger.error(f"{self.log_prefix} _start_heart_fc_chat 执行时出错: {e}")
logger.error(traceback.format_exc())
return False
# except Exception as e:
# logger.error(f"{self.log_prefix} _start_heart_fc_chat 执行时出错: {e}")
# logger.error(traceback.format_exc())
# return False
def is_in_focus_cooldown(self) -> bool:
"""检查是否在focus模式的冷却期内
# def is_in_focus_cooldown(self) -> bool:
# """检查是否在focus模式的冷却期内
Returns:
bool: 如果在冷却期内返回True否则返回False
"""
if self.last_focus_exit_time == 0:
return False
# Returns:
# bool: 如果在冷却期内返回True否则返回False
# """
# if self.last_focus_exit_time == 0:
# return False
# 基础冷却时间10分钟受auto_focus_threshold调控
base_cooldown = 10 * 60 # 10分钟转换为秒
cooldown_duration = base_cooldown / global_config.chat.auto_focus_threshold
# # 基础冷却时间10分钟受auto_focus_threshold调控
# base_cooldown = 10 * 60 # 10分钟转换为秒
# cooldown_duration = base_cooldown / global_config.chat.auto_focus_threshold
current_time = time.time()
elapsed_since_exit = current_time - self.last_focus_exit_time
# current_time = time.time()
# elapsed_since_exit = current_time - self.last_focus_exit_time
is_cooling = elapsed_since_exit < cooldown_duration
# is_cooling = elapsed_since_exit < cooldown_duration
if is_cooling:
remaining_time = cooldown_duration - elapsed_since_exit
remaining_minutes = remaining_time / 60
logger.debug(
f"[{self.log_prefix}] focus冷却中剩余时间: {remaining_minutes:.1f}分钟 (阈值: {global_config.chat.auto_focus_threshold})"
)
# if is_cooling:
# remaining_time = cooldown_duration - elapsed_since_exit
# remaining_minutes = remaining_time / 60
# logger.debug(
# f"[{self.log_prefix}] focus冷却中剩余时间: {remaining_minutes:.1f}分钟 (阈值: {global_config.chat.auto_focus_threshold})"
# )
return is_cooling
# return is_cooling
def get_cooldown_progress(self) -> float:
"""获取冷却进度返回0-1之间的值
# def get_cooldown_progress(self) -> float:
# """获取冷却进度返回0-1之间的值
Returns:
float: 0表示刚开始冷却1表示冷却完成
"""
if self.last_focus_exit_time == 0:
return 1.0 # 没有冷却返回1表示完全恢复
# Returns:
# float: 0表示刚开始冷却1表示冷却完成
# """
# if self.last_focus_exit_time == 0:
# return 1.0 # 没有冷却返回1表示完全恢复
# 基础冷却时间10分钟受auto_focus_threshold调控
base_cooldown = 10 * 60 # 10分钟转换为秒
cooldown_duration = base_cooldown / global_config.chat.auto_focus_threshold
# # 基础冷却时间10分钟受auto_focus_threshold调控
# base_cooldown = 10 * 60 # 10分钟转换为秒
# cooldown_duration = base_cooldown / global_config.chat.auto_focus_threshold
current_time = time.time()
elapsed_since_exit = current_time - self.last_focus_exit_time
# current_time = time.time()
# elapsed_since_exit = current_time - self.last_focus_exit_time
if elapsed_since_exit >= cooldown_duration:
return 1.0 # 冷却完成
# if elapsed_since_exit >= cooldown_duration:
# return 1.0 # 冷却完成
# 计算进度0表示刚开始冷却1表示冷却完成
progress = elapsed_since_exit / cooldown_duration
return progress
# return elapsed_since_exit / cooldown_duration

View File

@@ -42,7 +42,7 @@ def calculate_information_content(text):
return entropy
def cosine_similarity(v1, v2):
def cosine_similarity(v1, v2): # sourcery skip: assign-if-exp, reintroduce-else
"""计算余弦相似度"""
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
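Only the start of cosine_similarity is visible in this hunk. A complete, self-contained version of the cosine formula with an explicit zero-vector guard (returning 0 for zero-length vectors is a common convention, assumed here):

import numpy as np


def cosine_similarity(v1, v2) -> float:
    """cos(theta) = (v1 . v2) / (|v1| * |v2|), with 0.0 for zero-length vectors."""
    dot_product = np.dot(v1, v2)
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    if norm1 == 0 or norm2 == 0:
        return 0.0
    return float(dot_product / (norm1 * norm2))


# cosine_similarity([1, 0], [0, 1]) -> 0.0;  cosine_similarity([1, 2], [2, 4]) -> 1.0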
@@ -89,14 +89,13 @@ class MemoryGraph:
if not isinstance(self.G.nodes[concept]["memory_items"], list):
self.G.nodes[concept]["memory_items"] = [self.G.nodes[concept]["memory_items"]]
self.G.nodes[concept]["memory_items"].append(memory)
# 更新最后修改时间
self.G.nodes[concept]["last_modified"] = current_time
else:
self.G.nodes[concept]["memory_items"] = [memory]
# 如果节点存在但没有memory_items,说明是第一次添加memory,设置created_time
if "created_time" not in self.G.nodes[concept]:
self.G.nodes[concept]["created_time"] = current_time
self.G.nodes[concept]["last_modified"] = current_time
# 更新最后修改时间
self.G.nodes[concept]["last_modified"] = current_time
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(
@@ -108,11 +107,7 @@ class MemoryGraph:
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
# 从图中获取节点数据
node_data = self.G.nodes[concept]
return concept, node_data
return None
return (concept, self.G.nodes[concept]) if concept in self.G else None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
@@ -139,8 +134,7 @@ class MemoryGraph:
if depth >= 2:
# 获取相邻节点的记忆项
for neighbor in neighbors:
node_data = self.get_dot(neighbor)
if node_data:
if node_data := self.get_dot(neighbor):
concept, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
@@ -194,9 +188,9 @@ class MemoryGraph:
class Hippocampus:
def __init__(self):
self.memory_graph = MemoryGraph()
self.model_summary = None
self.entorhinal_cortex = None
self.parahippocampal_gyrus = None
self.model_summary: LLMRequest = None # type: ignore
self.entorhinal_cortex: EntorhinalCortex = None # type: ignore
self.parahippocampal_gyrus: ParahippocampalGyrus = None # type: ignore
def initialize(self):
# 初始化子组件
@@ -218,7 +212,7 @@ class Hippocampus:
memory_items = [memory_items] if memory_items else []
# 使用集合来去重,避免排序
unique_items = set(str(item) for item in memory_items)
unique_items = {str(item) for item in memory_items}
# 使用frozenset来保证顺序一致性
content = f"{concept}:{frozenset(unique_items)}"
return hash(content)
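calculate_node_hash above deduplicates the memory items with a set comprehension and hashes f"{concept}:{frozenset(...)}" so that item order does not affect the result within a run. As an illustration only (not the project's function), a variant that also stays stable across interpreter runs, since the built-in hash() of strings is randomized per process unless PYTHONHASHSEED is fixed:

import hashlib


def stable_node_hash(concept: str, memory_items) -> str:
    """Order-independent, process-independent digest of a concept node and its memory items."""
    unique_items = sorted({str(item) for item in memory_items})   # dedupe, then fix the order
    content = concept + ":" + "|".join(unique_items)
    return hashlib.md5(content.encode("utf-8")).hexdigest()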
@@ -231,6 +225,7 @@ class Hippocampus:
@staticmethod
def find_topic_llm(text, topic_num):
# sourcery skip: inline-immediately-returned-variable
prompt = (
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
@@ -240,6 +235,7 @@ class Hippocampus:
@staticmethod
def topic_what(text, topic):
# sourcery skip: inline-immediately-returned-variable
# 不再需要 time_info 参数
prompt = (
f'这是一段文字:\n{text}\n\n我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,'
@@ -480,9 +476,7 @@ class Hippocampus:
top_memories = memory_similarities[:max_memory_length]
# 添加到结果中
for memory, similarity in top_memories:
all_memories.append((node, [memory], similarity))
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
else:
logger.info("节点没有记忆")
@@ -646,9 +640,7 @@ class Hippocampus:
top_memories = memory_similarities[:max_memory_length]
# 添加到结果中
for memory, similarity in top_memories:
all_memories.append((node, [memory], similarity))
# logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories)
else:
logger.info("节点没有记忆")
@@ -823,11 +815,11 @@ class EntorhinalCortex:
logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = []
for timestamp in timestamps:
# 调用修改后的 random_get_msg_snippet
messages = self.random_get_msg_snippet(
timestamp, global_config.memory.memory_build_sample_length, max_memorized_time_per_msg
)
if messages:
if messages := self.random_get_msg_snippet(
timestamp,
global_config.memory.memory_build_sample_length,
max_memorized_time_per_msg,
):
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
logger.info(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}")
chat_samples.append(messages)
@@ -838,31 +830,30 @@ class EntorhinalCortex:
@staticmethod
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
try_count = 0
time_window_seconds = random.randint(300, 1800) # 随机时间窗口5到30分钟
while try_count < 3:
for _ in range(3):
# 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds
timestamp_start = target_timestamp
timestamp_end = target_timestamp + time_window_seconds
chosen_message = get_raw_msg_by_timestamp(
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
)
if chosen_message := get_raw_msg_by_timestamp(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=1,
limit_mode="earliest",
):
chat_id: str = chosen_message[0].get("chat_id") # type: ignore
if chosen_message:
chat_id = chosen_message[0].get("chat_id")
messages = get_raw_msg_by_timestamp_with_chat(
if messages := get_raw_msg_by_timestamp_with_chat(
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=chat_size,
limit_mode="earliest",
chat_id=chat_id,
)
if messages:
):
# 检查获取到的所有消息是否都未达到最大记忆次数
all_valid = True
for message in messages:
@@ -882,8 +873,6 @@ class EntorhinalCortex:
).execute()
return messages # 直接返回原始的消息列表
# 如果获取失败或消息无效,增加尝试次数
try_count += 1
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
# 三次尝试都失败,返回 None
@@ -975,7 +964,7 @@ class EntorhinalCortex:
).execute()
if nodes_to_delete:
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute()
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
# 处理边的信息
db_edges = list(GraphEdges.select())
@@ -1075,19 +1064,17 @@ class EntorhinalCortex:
try:
memory_items = [str(item) for item in memory_items]
memory_items_json = json.dumps(memory_items, ensure_ascii=False)
if not memory_items_json:
continue
if memory_items_json := json.dumps(memory_items, ensure_ascii=False):
nodes_data.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
"created_time": data.get("created_time", current_time),
"last_modified": data.get("last_modified", current_time),
}
)
nodes_data.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
"created_time": data.get("created_time", current_time),
"last_modified": data.get("last_modified", current_time),
}
)
except Exception as e:
logger.error(f"准备节点 {concept} 数据时发生错误: {e}")
continue
@@ -1114,7 +1101,7 @@ class EntorhinalCortex:
node_start = time.time()
if nodes_data:
batch_size = 500 # 增加批量大小
with GraphNodes._meta.database.atomic():
with GraphNodes._meta.database.atomic(): # type: ignore
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i : i + batch_size]
GraphNodes.insert_many(batch).execute()
@@ -1125,7 +1112,7 @@ class EntorhinalCortex:
edge_start = time.time()
if edges_data:
batch_size = 500 # 增加批量大小
with GraphEdges._meta.database.atomic():
with GraphEdges._meta.database.atomic(): # type: ignore
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
@@ -1279,7 +1266,7 @@ class ParahippocampalGyrus:
# 3. 过滤掉包含禁用关键词的topic
filtered_topics = [
topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words)
topic for topic in topics if all(keyword not in topic for keyword in global_config.memory.memory_ban_words)
]
logger.debug(f"过滤后话题: {filtered_topics}")
@@ -1489,32 +1476,30 @@ class ParahippocampalGyrus:
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
last_modified = node_data.get("last_modified", current_time)
# 条件1检查是否长时间未修改 (超过24小时)
if current_time - last_modified > 3600 * 24:
# 条件2再次确认节点包含记忆项理论上已确认但作为保险
if memory_items:
current_count = len(memory_items)
# 如果列表非空,才进行随机选择
if current_count > 0:
removed_item = random.choice(memory_items)
try:
memory_items.remove(removed_item)
if current_time - last_modified > 3600 * 24 and memory_items:
current_count = len(memory_items)
# 如果列表非空,才进行随机选择
if current_count > 0:
removed_item = random.choice(memory_items)
try:
memory_items.remove(removed_item)
# 条件3检查移除后 memory_items 是否变空
if memory_items: # 如果移除后列表不为空
# self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可
self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间
node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})")
else: # 如果移除后列表为空
# 尝试移除节点,处理可能的错误
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空
logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}")
except ValueError:
# 这个错误理论上不应发生,因为 removed_item 来自 memory_items
logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'")
# 条件3检查移除后 memory_items 是否变空
if memory_items: # 如果移除后列表不为空
# self.memory_graph.G.nodes[node]["memory_items"] = memory_items # 直接修改列表即可
self.memory_graph.G.nodes[node]["last_modified"] = current_time # 更新修改时间
node_changes["reduced"].append(f"{node} (数量: {current_count} -> {len(memory_items)})")
else: # 如果移除后列表为空
# 尝试移除节点,处理可能的错误
try:
self.memory_graph.G.remove_node(node)
node_changes["removed"].append(f"{node}(遗忘清空)") # 标记为遗忘清空
logger.debug(f"[遗忘] 节点 {node} 因移除最后一项而被清空。")
except nx.NetworkXError as e:
logger.warning(f"[遗忘] 尝试移除节点 {node} 时发生错误(可能已被移除):{e}")
except ValueError:
# 这个错误理论上不应发生,因为 removed_item 来自 memory_items
logger.warning(f"[遗忘] 尝试从节点 '{node}' 移除不存在的项目 '{removed_item[:30]}...'")
node_check_end = time.time()
logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}")
@@ -1669,7 +1654,7 @@ class ParahippocampalGyrus:
class HippocampusManager:
def __init__(self):
self._hippocampus = None
self._hippocampus: Hippocampus = None # type: ignore
self._initialized = False
def initialize(self):

View File

@@ -13,7 +13,7 @@ from json_repair import repair_json
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str):
def get_keywords_from_json(json_str) -> List:
"""
从JSON字符串中提取关键词列表
@@ -28,15 +28,8 @@ def get_keywords_from_json(json_str):
fixed_json = repair_json(json_str)
# 如果repair_json返回的是字符串需要解析为Python对象
if isinstance(fixed_json, str):
result = json.loads(fixed_json)
else:
# 如果repair_json直接返回了字典对象直接使用
result = fixed_json
# 提取关键词
keywords = result.get("keywords", [])
return keywords
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
return result.get("keywords", [])
except Exception as e:
logger.error(f"解析关键词JSON失败: {e}")
return []
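get_keywords_from_json above repairs a possibly malformed LLM JSON string with json_repair.repair_json and parses the result when a string comes back, falling back to an empty list on any failure. A small usage sketch, assuming the json-repair package is installed and a "keywords" field in the payload:

import json
from typing import List

from json_repair import repair_json


def keywords_from_llm_output(raw: str) -> List:
    """Extract the 'keywords' list from possibly malformed LLM JSON output."""
    fixed = repair_json(raw)                                    # repaired JSON as a string by default
    result = json.loads(fixed) if isinstance(fixed, str) else fixed
    return result.get("keywords", []) if isinstance(result, dict) else []


# keywords_from_llm_output('{"keywords": ["猫猫", "下雨",]}') -> ["猫猫", "下雨"]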

View File

@@ -1,52 +1,10 @@
import numpy as np
from scipy import stats
from datetime import datetime, timedelta
from rich.traceback import install
install(extra_lines=3)
class DistributionVisualizer:
def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
"""
初始化分布可视化器
参数:
mean (float): 期望均值
std (float): 标准差
skewness (float): 偏度
sample_size (int): 样本大小
"""
self.mean = mean
self.std = std
self.skewness = skewness
self.sample_size = sample_size
self.samples = None
def generate_samples(self):
"""生成具有指定参数的样本"""
if self.skewness == 0:
# 对于无偏度的情况,直接使用正态分布
self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
else:
# 使用 scipy.stats 生成具有偏度的分布
self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size)
def get_weighted_samples(self):
"""获取加权后的样本数列"""
if self.samples is None:
self.generate_samples()
# 将样本值乘以样本大小
return self.samples * self.sample_size
def get_statistics(self):
"""获取分布的统计信息"""
if self.samples is None:
self.generate_samples()
return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)}
class MemoryBuildScheduler:
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
"""
@@ -108,61 +66,61 @@ class MemoryBuildScheduler:
return [int(t.timestamp()) for t in timestamps]
def print_time_samples(timestamps, show_distribution=True):
"""打印时间样本和分布信息"""
print(f"\n生成的{len(timestamps)}个时间点分布:")
print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
print("-" * 50)
# def print_time_samples(timestamps, show_distribution=True):
# """打印时间样本和分布信息"""
# print(f"\n生成的{len(timestamps)}个时间点分布:")
# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
# print("-" * 50)
now = datetime.now()
time_diffs = []
# now = datetime.now()
# time_diffs = []
for i, timestamp in enumerate(timestamps, 1):
hours_diff = (now - timestamp).total_seconds() / 3600
time_diffs.append(hours_diff)
print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
# for i, timestamp in enumerate(timestamps, 1):
# hours_diff = (now - timestamp).total_seconds() / 3600
# time_diffs.append(hours_diff)
# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
# 打印统计信息
print("\n统计信息:")
print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
print(f"标准差:{np.std(time_diffs):.2f}小时")
print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
# # 打印统计信息
# print("\n统计信息")
# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
# print(f"标准差:{np.std(time_diffs):.2f}小时")
# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
if show_distribution:
# 计算时间分布的直方图
hist, bins = np.histogram(time_diffs, bins=40)
print("\n时间分布(每个*代表一个时间点):")
for i in range(len(hist)):
if hist[i] > 0:
print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
# if show_distribution:
# # 计算时间分布的直方图
# hist, bins = np.histogram(time_diffs, bins=40)
# print("\n时间分布(每个*代表一个时间点):")
# for i in range(len(hist)):
# if hist[i] > 0:
# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
# 使用示例
if __name__ == "__main__":
# 创建一个双峰分布的记忆调度器
scheduler = MemoryBuildScheduler(
n_hours1=12, # 第一个分布均值12小时前
std_hours1=8, # 第一个分布标准差
weight1=0.7, # 第一个分布权重 70%
n_hours2=36, # 第二个分布均值36小时前
std_hours2=24, # 第二个分布标准差
weight2=0.3, # 第二个分布权重 30%
total_samples=50, # 总共生成50个时间点
)
# # 使用示例
# if __name__ == "__main__":
# # 创建一个双峰分布的记忆调度器
# scheduler = MemoryBuildScheduler(
# n_hours1=12, # 第一个分布均值12小时前
# std_hours1=8, # 第一个分布标准差
# weight1=0.7, # 第一个分布权重 70%
# n_hours2=36, # 第二个分布均值36小时前
# std_hours2=24, # 第二个分布标准差
# weight2=0.3, # 第二个分布权重 30%
# total_samples=50, # 总共生成50个时间点
# )
# 生成时间分布
timestamps = scheduler.generate_time_samples()
# # 生成时间分布
# timestamps = scheduler.generate_time_samples()
# 打印结果,包含分布可视化
print_time_samples(timestamps, show_distribution=True)
# # 打印结果,包含分布可视化
# print_time_samples(timestamps, show_distribution=True)
# 打印时间戳数组
timestamp_array = scheduler.get_timestamp_array()
print("\n时间戳数组Unix时间戳")
print("[", end="")
for i, ts in enumerate(timestamp_array):
if i > 0:
print(", ", end="")
print(ts, end="")
print("]")
# # 打印时间戳数组
# timestamp_array = scheduler.get_timestamp_array()
# print("\n时间戳数组Unix时间戳")
# print("[", end="")
# for i, ts in enumerate(timestamp_array):
# if i > 0:
# print(", ", end="")
# print(ts, end="")
# print("]")

View File

@@ -1,22 +1,25 @@
import traceback
import os
import re
from typing import Dict, Any
from maim_message import UserInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv
from src.experimental.only_message_process import MessageProcessor
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.config.config import global_config
from src.experimental.only_message_process import MessageProcessor
from src.experimental.PFC.pfc_manager import PFCManager
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
from src.plugin_system.base.base_command import BaseCommand
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import ChatStream
import re
# 定义日志配置
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
@@ -182,8 +185,8 @@ class ChatBot:
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform,
user_info=user_info,
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
@@ -193,8 +196,10 @@ class ChatBot:
await message.process()
# 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex(
message.raw_message, chat, user_info
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
chat,
user_info, # type: ignore
):
return

View File

@@ -3,18 +3,17 @@ import hashlib
import time
import copy
from typing import Dict, Optional, TYPE_CHECKING
from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from rich.traceback import install
from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
@@ -28,7 +27,7 @@ class ChatMessageContext:
def __init__(self, message: "MessageRecv"):
self.message = message
def get_template_name(self) -> str:
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
@@ -39,11 +38,12 @@ class ChatMessageContext:
return self.message
def check_types(self, types: list) -> bool:
# sourcery skip: invert-any-all, use-any, use-next
"""检查消息类型"""
if not self.message.message_info.format_info.accept_format:
if not self.message.message_info.format_info.accept_format: # type: ignore
return False
for t in types:
if t not in self.message.message_info.format_info.accept_format:
if t not in self.message.message_info.format_info.accept_format: # type: ignore
return False
return True
@@ -67,7 +67,7 @@ class ChatStream:
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None,
data: Optional[dict] = None,
):
self.stream_id = stream_id
self.platform = platform
@@ -76,7 +76,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
def to_dict(self) -> dict:
"""转换为字典格式"""
@@ -98,7 +98,7 @@ class ChatStream:
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info,
user_info=user_info, # type: ignore
group_info=group_info,
data=data,
)
@@ -162,8 +162,8 @@ class ChatManager:
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
stream_id = self._generate_stream_id(
message.message_info.platform,
message.message_info.user_info,
message.message_info.platform, # type: ignore
message.message_info.user_info, # type: ignore
message.message_info.group_info,
)
self.last_messages[stream_id] = message
@@ -184,10 +184,7 @@ class ChatManager:
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
"""获取聊天流ID"""
if is_group:
components = [platform, str(id)]
else:
components = [platform, str(id), "private"]
components = [platform, id] if is_group else [platform, id, "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
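A quick illustration of the deterministic stream-id derivation refactored above; this mirrors the logic shown in the hunk rather than adding anything to the commit, and the "qq"/"12345" values are made up:

import hashlib

def stream_id(platform: str, id: str, is_group: bool = True) -> str:
    components = [platform, id] if is_group else [platform, id, "private"]
    return hashlib.md5("_".join(components).encode()).hexdigest()

# Same inputs always yield the same id; a group chat and a private chat sharing the
# same numeric id land in different streams because of the "private" suffix.
assert stream_id("qq", "12345") == stream_id("qq", "12345")
assert stream_id("qq", "12345") != stream_id("qq", "12345", is_group=False)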

View File

@@ -1,17 +1,15 @@
import time
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Any, TYPE_CHECKING
import urllib3
from src.common.logger import get_logger
if TYPE_CHECKING:
from .chat_stream import ChatStream
from ..utils.utils_image import get_image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from abc import abstractmethod
from dataclasses import dataclass
from rich.traceback import install
from typing import Optional, Any
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger
from src.chat.utils.utils_image import get_image_manager
from .chat_stream import ChatStream
install(extra_lines=3)
@@ -27,7 +25,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@dataclass
class Message(MessageBase):
chat_stream: "ChatStream" = None
chat_stream: "ChatStream" = None # type: ignore
reply: Optional["Message"] = None
processed_plain_text: str = ""
memorized_times: int = 0
@@ -55,7 +53,7 @@ class Message(MessageBase):
)
# 调用父类初始化
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None)
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
self.chat_stream = chat_stream
# 文本处理相关属性
@@ -66,6 +64,7 @@ class Message(MessageBase):
self.reply = reply
async def _process_message_segments(self, segment: Seg) -> str:
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
"""递归处理消息段,转换为文字描述
Args:
@@ -78,13 +77,13 @@ class Message(MessageBase):
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
processed = await self._process_message_segments(seg) # type: ignore
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
return await self._process_single_segment(segment) # type: ignore
@abstractmethod
async def _process_single_segment(self, segment):
@@ -113,7 +112,7 @@ class MessageRecv(Message):
self.is_mentioned = None
self.priority_mode = "interest"
self.priority_info = None
self.interest_value = None
self.interest_value: float = None # type: ignore
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
@@ -139,7 +138,7 @@ class MessageRecv(Message):
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
return segment.data
return segment.data # type: ignore
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
@@ -161,7 +160,7 @@ class MessageRecv(Message):
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_mentioned = float(segment.data)
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
@@ -187,7 +186,7 @@ class MessageRecv(Message):
"""生成详细文本,包含时间和用户信息"""
timestamp = self.message_info.time
user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@@ -235,7 +234,7 @@ class MessageProcessBase(Message):
"""
try:
if seg.type == "text":
return seg.data
return seg.data # type: ignore
elif seg.type == "image":
# 如果是base64图片数据
if isinstance(seg.data, str):
@@ -251,7 +250,7 @@ class MessageProcessBase(Message):
if self.reply and hasattr(self.reply, "processed_plain_text"):
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
# print(f"reply: {self.reply}")
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]"
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
return None
else:
return f"[{seg.type}:{str(seg.data)}]"
@@ -265,7 +264,7 @@ class MessageProcessBase(Message):
timestamp = self.message_info.time
user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n"
@@ -314,7 +313,7 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: str = None,
reply_to: str = None, # type: ignore
):
# 调用父类初始化
super().__init__(
@@ -347,7 +346,7 @@ class MessageSending(MessageProcessBase):
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_info.message_id),
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
self.message_segment,
],
)
@@ -367,10 +366,10 @@ class MessageSending(MessageProcessBase):
) -> "MessageSending":
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
message_id=thinking.message_info.message_id, # type: ignore
chat_stream=thinking.chat_stream,
message_segment=message_segment,
bot_user_info=thinking.message_info.user_info,
bot_user_info=thinking.message_info.user_info, # type: ignore
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji,
@@ -402,13 +401,11 @@ class MessageSet:
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time)
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
if 0 <= index < len(self.messages):
return self.messages[index]
return None
return self.messages[index] if 0 <= index < len(self.messages) else None
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
@@ -418,7 +415,7 @@ class MessageSet:
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].message_info.time < target_time:
if self.messages[mid].message_info.time < target_time: # type: ignore
left = mid + 1
else:
right = mid
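The loop above is a lower-bound binary search used to locate the message closest to a target time. An equivalent sketch using the standard bisect module, assuming a pre-sorted list of plain float timestamps (illustrative only, not the project's API):

import bisect
from typing import List, Optional

def closest_index(sorted_times: List[float], target_time: float) -> Optional[int]:
    """Index of the timestamp nearest to target_time, or None if the list is empty."""
    if not sorted_times:
        return None
    pos = bisect.bisect_left(sorted_times, target_time)
    if pos == 0:
        return 0
    if pos == len(sorted_times):
        return len(sorted_times) - 1
    before, after = sorted_times[pos - 1], sorted_times[pos]
    return pos if after - target_time < target_time - before else pos - 1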
@@ -444,11 +441,8 @@ class MessageSet:
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
return MessageRecv(
message_dict
)
return MessageRecv(message_dict)
def message_from_db_dict(db_dict: dict) -> MessageRecv:
"""从数据库字典创建MessageRecv实例"""
@@ -492,4 +486,4 @@ def message_from_db_dict(db_dict: dict) -> MessageRecv:
msg.is_emoji = db_dict.get("is_emoji", False)
msg.is_picid = db_dict.get("is_picid", False)
return msg
return msg

View File

@@ -2,11 +2,10 @@ import re
import traceback
from typing import Union
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from ...common.database.database_model import Messages, RecalledMessages, Images # Import Peewee models
from src.common.database.database_model import Messages, RecalledMessages, Images
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
logger = get_logger("message_storage")
@@ -55,7 +54,7 @@ class MessageStorage:
is_picid = message.is_picid
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
@@ -67,7 +66,7 @@ class MessageStorage:
Messages.create(
message_id=msg_id,
time=float(message.message_info.time),
time=float(message.message_info.time), # type: ignore
chat_id=chat_stream.stream_id,
# Flattened chat_info
reply_to=reply_to,
@@ -121,7 +120,7 @@ class MessageStorage:
try:
# Assuming input 'time' is a string timestamp that can be converted to float
current_time_float = float(time)
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() # type: ignore
except Exception:
logger.exception("删除撤回消息失败")
@@ -133,22 +132,19 @@ class MessageStorage:
"""更新最新一条匹配消息的message_id"""
try:
if message.message_segment.type == "notify":
mmc_message_id = message.message_segment.data.get("echo")
qq_message_id = message.message_segment.data.get("actual_id")
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
else:
logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}")
return
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return
# 查询最新一条匹配消息
matched_message = (
if matched_message := (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
)
if matched_message:
):
# 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute()
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
@@ -173,10 +169,7 @@ class MessageStorage:
image_record = (
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
)
if image_record:
return f"[picid:{image_record.image_id}]"
else:
return match.group(0) # 保持原样
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)

View File

@@ -1,15 +1,16 @@
import asyncio
from src.chat.message_receive.message import MessageSending
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.common.logger import get_logger
from src.chat.utils.utils import calculate_typing_time
from rich.traceback import install
import traceback
install(extra_lines=3)
from rich.traceback import install
from src.common.message.api import get_global_api
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
install(extra_lines=3)
logger = get_logger("sender")
@@ -49,10 +50,10 @@ class HeartFCSender:
"""
if not message.chat_stream:
logger.error("消息缺少 chat_stream无法发送")
raise Exception("消息缺少 chat_stream无法发送")
raise ValueError("消息缺少 chat_stream无法发送")
if not message.message_info or not message.message_info.message_id:
logger.error("消息缺少 message_info 或 message_id无法发送")
raise Exception("消息缺少 message_info 或 message_id无法发送")
raise ValueError("消息缺少 message_info 或 message_id无法发送")
chat_id = message.chat_stream.stream_id
message_id = message.message_info.message_id
@@ -84,4 +85,3 @@ class HeartFCSender:
except Exception as e:
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
raise e

View File

@@ -1,15 +1,12 @@
from typing import Dict, List, Optional, Type, Any
from typing import Dict, List, Optional, Type
from src.plugin_system.base.base_action import BaseAction
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.base.component_types import ComponentType, ActionActivationType, ChatMode, ActionInfo
logger = get_logger("action_manager")
# 定义动作信息类型
ActionInfo = Dict[str, Any]
class ActionManager:
"""
@@ -20,8 +17,8 @@ class ActionManager:
# 类常量
DEFAULT_RANDOM_PROBABILITY = 0.3
DEFAULT_MODE = "all"
DEFAULT_ACTIVATION_TYPE = "always"
DEFAULT_MODE = ChatMode.ALL
DEFAULT_ACTIVATION_TYPE = ActionActivationType.ALWAYS
def __init__(self):
"""初始化动作管理器"""
@@ -30,14 +27,11 @@ class ActionManager:
# 当前正在使用的动作集合,默认加载默认动作
self._using_actions: Dict[str, ActionInfo] = {}
# 默认动作集,仅作为快照,用于恢复默认
self._default_actions: Dict[str, ActionInfo] = {}
# 加载插件动作
self._load_plugin_actions()
# 初始化时将默认动作加载到使用中的动作
self._using_actions = self._default_actions.copy()
self._using_actions = component_registry.get_default_actions()
def _load_plugin_actions(self) -> None:
"""
@@ -54,43 +48,15 @@ class ActionManager:
def _load_plugin_system_actions(self) -> None:
"""从插件系统的component_registry加载Action组件"""
try:
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType
# 获取所有Action组件
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
for action_name, action_info in action_components.items():
if action_name in self._registered_actions:
logger.debug(f"Action组件 {action_name} 已存在,跳过")
continue
# 将插件系统的ActionInfo转换为ActionManager格式
converted_action_info = {
"description": action_info.description,
"parameters": getattr(action_info, "action_parameters", {}),
"require": getattr(action_info, "action_require", []),
"associated_types": getattr(action_info, "associated_types", []),
"enable_plugin": action_info.enabled,
# 激活类型相关
"focus_activation_type": action_info.focus_activation_type.value,
"normal_activation_type": action_info.normal_activation_type.value,
"random_activation_probability": action_info.random_activation_probability,
"llm_judge_prompt": action_info.llm_judge_prompt,
"activation_keywords": action_info.activation_keywords,
"keyword_case_sensitive": action_info.keyword_case_sensitive,
# 模式和并行设置
"mode_enable": action_info.mode_enable.value,
"parallel_action": action_info.parallel_action,
# 插件信息
"_plugin_name": getattr(action_info, "plugin_name", ""),
}
self._registered_actions[action_name] = converted_action_info
# 如果启用,也添加到默认动作集
if action_info.enabled:
self._default_actions[action_name] = converted_action_info
self._registered_actions[action_name] = action_info
logger.debug(
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
@@ -133,7 +99,9 @@ class ActionManager:
"""
try:
# 获取组件类 - 明确指定查询Action类型
component_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
component_class: Type[BaseAction] = component_registry.get_component_class(
action_name, ComponentType.ACTION
) # type: ignore
if not component_class:
logger.warning(f"{log_prefix} 未找到Action组件: {action_name}")
return None
@@ -173,10 +141,6 @@ class ActionManager:
"""获取所有已注册的动作集"""
return self._registered_actions.copy()
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取默认动作集"""
return self._default_actions.copy()
def get_using_actions(self) -> Dict[str, ActionInfo]:
"""获取当前正在使用的动作集合"""
return self._using_actions.copy()
@@ -221,31 +185,31 @@ class ActionManager:
logger.debug(f"已从使用集中移除动作 {action_name}")
return True
def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
"""
添加新的动作到注册集
# def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
# """
# 添加新的动作到注册集
Args:
action_name: 动作名称
description: 动作描述
parameters: 动作参数定义,默认为空字典
require: 动作依赖项,默认为空列表
# Args:
# action_name: 动作名称
# description: 动作描述
# parameters: 动作参数定义,默认为空字典
# require: 动作依赖项,默认为空列表
Returns:
bool: 添加是否成功
"""
if action_name in self._registered_actions:
return False
# Returns:
# bool: 添加是否成功
# """
# if action_name in self._registered_actions:
# return False
if parameters is None:
parameters = {}
if require is None:
require = []
# if parameters is None:
# parameters = {}
# if require is None:
# require = []
action_info = {"description": description, "parameters": parameters, "require": require}
# action_info = {"description": description, "parameters": parameters, "require": require}
self._registered_actions[action_name] = action_info
return True
# self._registered_actions[action_name] = action_info
# return True
def remove_action(self, action_name: str) -> bool:
"""从注册集移除指定动作"""
@@ -264,10 +228,9 @@ class ActionManager:
def restore_actions(self) -> None:
"""恢复到默认动作集"""
logger.debug(
f"恢复动作集: 从 {list(self._using_actions.keys())} 恢复到默认动作集 {list(self._default_actions.keys())}"
)
self._using_actions = self._default_actions.copy()
actions_to_restore = list(self._using_actions.keys())
self._using_actions = component_registry.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
def add_system_action_if_needed(self, action_name: str) -> bool:
"""
@@ -297,4 +260,4 @@ class ActionManager:
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_class(action_name)
return component_registry.get_component_class(action_name) # type: ignore

View File

@@ -1,15 +1,20 @@
from typing import List, Any, Dict
from src.common.logger import get_logger
from src.chat.focus_chat.hfc_utils import CycleDetail
from src.chat.message_receive.chat_stream import get_chat_manager
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
import random
import asyncio
import hashlib
import time
from typing import List, Any, Dict, TYPE_CHECKING
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.chat.focus_chat.hfc_utils import CycleDetail
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.plugin_system.base.component_types import ChatMode, ActionInfo, ActionActivationType
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("action_manager")
@@ -25,7 +30,7 @@ class ActionModifier:
def __init__(self, action_manager: ActionManager, chat_id: str):
"""初始化动作处理器"""
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.action_manager = action_manager
@@ -45,7 +50,7 @@ class ActionModifier:
self,
history_loop=None,
message_content: str = "",
):
): # sourcery skip: use-named-expression
"""
动作修改流程,整合传统观察处理和新的激活类型判定
@@ -82,9 +87,9 @@ class ActionModifier:
# === 第一阶段:传统观察处理 ===
# if history_loop:
# removals_from_loop = await self.analyze_loop_actions(history_loop)
# if removals_from_loop:
# removals_s1.extend(removals_from_loop)
# removals_from_loop = await self.analyze_loop_actions(history_loop)
# if removals_from_loop:
# removals_s1.extend(removals_from_loop)
# 检查动作的关联类型
chat_context = self.chat_stream.context
@@ -125,15 +130,14 @@ class ActionModifier:
f"{self.log_prefix} 动作修改流程结束,最终可用动作: {list(self.action_manager.get_using_actions().keys())}||移除记录: {removals_summary}"
)
def _check_action_associated_types(self, all_actions, chat_context):
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
type_mismatched_actions = []
for action_name, data in all_actions.items():
if data.get("associated_types"):
if not chat_context.check_types(data["associated_types"]):
associated_types_str = ", ".join(data["associated_types"])
reason = f"适配器不支持(需要: {associated_types_str}"
type_mismatched_actions.append((action_name, reason))
logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}")
for action_name, action_info in all_actions.items():
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
associated_types_str = ", ".join(action_info.associated_types)
reason = f"适配器不支持(需要: {associated_types_str}"
type_mismatched_actions.append((action_name, reason))
logger.debug(f"{self.log_prefix}决定移除动作: {action_name},原因: {reason}")
return type_mismatched_actions
async def _get_deactivated_actions_by_type(
@@ -167,28 +171,28 @@ class ActionModifier:
if activation_type == "always":
continue # 总是激活,无需处理
elif activation_type == "random":
probability = action_info.get("random_activation_probability", ActionManager.DEFAULT_RANDOM_PROBABILITY)
if not (random.random() < probability):
elif activation_type == ActionActivationType.RANDOM:
probability = action_info.random_activation_probability or ActionManager.DEFAULT_RANDOM_PROBABILITY
if random.random() >= probability:
reason = f"RANDOM类型未触发概率{probability}"
deactivated_actions.append((action_name, reason))
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
elif activation_type == "keyword":
elif activation_type == ActionActivationType.KEYWORD:
if not self._check_keyword_activation(action_name, action_info, chat_content):
keywords = action_info.get("activation_keywords", [])
keywords = action_info.activation_keywords
reason = f"关键词未匹配(关键词: {keywords}"
deactivated_actions.append((action_name, reason))
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: {reason}")
elif activation_type == "llm_judge":
elif activation_type == ActionActivationType.LLM_JUDGE:
llm_judge_actions[action_name] = action_info
elif activation_type == "never":
reason = "激活类型为never"
deactivated_actions.append((action_name, reason))
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: 激活类型为never")
else:
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
@@ -273,7 +277,7 @@ class ActionModifier:
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果并更新缓存
for _, (action_name, result) in enumerate(zip(task_names, task_results, strict=False)):
for action_name, result in zip(task_names, task_results, strict=False):
if isinstance(result, Exception):
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
results[action_name] = False
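The parallel judging above leans on asyncio.gather(..., return_exceptions=True) plus zip, so one failed LLM call cannot cancel the batch. A self-contained sketch of that pattern with a made-up judge coroutine (not the project's API):

import asyncio

async def judge(action_name: str) -> bool:
    if action_name == "broken_action":
        raise RuntimeError("LLM judge failed")  # simulate one failing judgement
    return True

async def run_judges(names):
    results = await asyncio.gather(*(judge(n) for n in names), return_exceptions=True)
    # Exceptions come back as values, so each one can be mapped to a safe default of False.
    return {n: False if isinstance(r, Exception) else bool(r) for n, r in zip(names, results)}

# asyncio.run(run_judges(["emoji_action", "broken_action"]))
# -> {'emoji_action': True, 'broken_action': False}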
@@ -289,7 +293,7 @@ class ActionModifier:
except Exception as e:
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
# 如果并行执行失败为所有任务返回False
for action_name in tasks_to_run.keys():
for action_name in tasks_to_run:
results[action_name] = False
# 清理过期缓存
@@ -300,10 +304,11 @@ class ActionModifier:
def _cleanup_expired_cache(self, current_time: float):
"""清理过期的缓存条目"""
expired_keys = []
for cache_key, cache_data in self._llm_judge_cache.items():
if current_time - cache_data["timestamp"] > self._cache_expiry_time:
expired_keys.append(cache_key)
expired_keys.extend(
cache_key
for cache_key, cache_data in self._llm_judge_cache.items()
if current_time - cache_data["timestamp"] > self._cache_expiry_time
)
for key in expired_keys:
del self._llm_judge_cache[key]
@@ -382,7 +387,7 @@ class ActionModifier:
def _check_keyword_activation(
self,
action_name: str,
action_info: Dict[str, Any],
action_info: ActionInfo,
chat_content: str = "",
) -> bool:
"""
@@ -399,8 +404,8 @@ class ActionModifier:
bool: 是否应该激活此action
"""
activation_keywords = action_info.get("activation_keywords", [])
case_sensitive = action_info.get("keyword_case_sensitive", False)
activation_keywords = action_info.activation_keywords
case_sensitive = action_info.keyword_case_sensitive
if not activation_keywords:
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
@@ -500,13 +505,13 @@ class ActionModifier:
return removals
def get_available_actions_count(self,mode:str = "focus") -> int:
def get_available_actions_count(self, mode: str = "focus") -> int:
"""获取当前可用动作数量排除默认的no_action"""
current_actions = self.action_manager.get_using_actions_for_mode(mode)
# 排除no_action如果存在
filtered_actions = {k: v for k, v in current_actions.items() if k != "no_action"}
return len(filtered_actions)
def should_skip_planning_for_no_reply(self) -> bool:
"""判断是否应该跳过规划过程"""
current_actions = self.action_manager.get_using_actions_for_mode("focus")

View File

@@ -1,23 +1,26 @@
import json # <--- 确保导入 json
import json
import time
import traceback
from typing import Dict, Any, Optional
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.planner_actions.action_manager import ActionManager
from json_repair import repair_json
from src.chat.utils.utils import get_chat_type_and_target_info
from datetime import datetime
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
build_readable_messages,
get_actions_by_timestamp_with_chat,
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
)
import time
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.plugin_system.base.component_types import ChatMode, ActionInfo
logger = get_logger("planner")
@@ -28,7 +31,7 @@ def init_prompt():
Prompt(
"""
{time_block}
{indentify_block}
{identity_block}
你现在需要根据聊天内容选择的合适的action来参与聊天。
{chat_context_description},以下是具体的聊天内容:
{chat_content_block}
@@ -76,7 +79,7 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0
async def plan(self,mode:str = "focus") -> Dict[str, Any]:
async def plan(self, mode: str = "focus") -> Dict[str, Dict[str, Any]]: # sourcery skip: dict-comprehension
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
"""
@@ -84,6 +87,7 @@ class ActionPlanner:
action = "no_reply" # 默认动作
reasoning = "规划器初始化默认"
action_data = {}
current_available_actions: Dict[str, ActionInfo] = {}
try:
is_group_chat = True
@@ -95,7 +99,7 @@ class ActionPlanner:
# 获取完整的动作信息
all_registered_actions = self.action_manager.get_registered_actions()
current_available_actions = {}
for action_name in current_available_actions_dict.keys():
if action_name in all_registered_actions:
current_available_actions[action_name] = all_registered_actions[action_name]
@@ -111,7 +115,11 @@ class ActionPlanner:
reasoning = "没有可用的动作"
logger.info(f"{self.log_prefix}{reasoning}")
return {
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning},
"action_result": {
"action_type": action,
"action_data": action_data,
"reasoning": reasoning,
},
}
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
@@ -140,7 +148,7 @@ class ActionPlanner:
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
reasoning = f"LLM 请求失败,你的模型出现问题: {req_e}"
reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
action = "no_reply"
if llm_content:
@@ -162,7 +170,6 @@ class ActionPlanner:
reasoning = parsed_json.get("reasoning", "未提供原因")
# 将所有其他属性添加到action_data
action_data = {}
for key, value in parsed_json.items():
if key not in ["action", "reasoning"]:
action_data[key] = value
@@ -173,8 +180,8 @@ class ActionPlanner:
logger.warning(
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'"
)
action = "no_reply"
reasoning = f"LLM 返回了当前不可用的动作 '{action}' (可用: {list(current_available_actions.keys())})。原始理由: {reasoning}"
action = "no_reply"
except Exception as json_e:
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
@@ -190,8 +197,7 @@ class ActionPlanner:
is_parallel = False
if action in current_available_actions:
action_info = current_available_actions[action]
is_parallel = action_info.get("parallel_action", False)
is_parallel = current_available_actions[action].parallel_action
action_result = {
"action_type": action,
@@ -201,20 +207,18 @@ class ActionPlanner:
"is_parallel": is_parallel,
}
plan_result = {
return {
"action_result": action_result,
"action_prompt": prompt,
}
return plan_result
async def build_planner_prompt(
self,
is_group_chat: bool, # Now passed as argument
chat_target_info: Optional[dict], # Now passed as argument
current_available_actions,
current_available_actions: Dict[str, ActionInfo],
mode: str = "focus",
) -> str:
) -> str: # sourcery skip: use-join
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@@ -279,23 +283,23 @@ class ActionPlanner:
action_options_block = ""
for using_actions_name, using_actions_info in current_available_actions.items():
if using_actions_info["parameters"]:
if using_actions_info.action_parameters:
param_text = "\n"
for param_name, param_description in using_actions_info["parameters"].items():
for param_name, param_description in using_actions_info.action_parameters.items():
param_text += f' "{param_name}":"{param_description}"\n'
param_text = param_text.rstrip("\n")
else:
param_text = ""
require_text = ""
for require_item in using_actions_info["require"]:
for require_item in using_actions_info.action_require:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip("\n")
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
using_action_prompt = using_action_prompt.format(
action_name=using_actions_name,
action_description=using_actions_info["description"],
action_description=using_actions_info.description,
action_parameters=param_text,
action_require=require_text,
)
@@ -312,10 +316,10 @@ class ActionPlanner:
else:
bot_nickname = ""
bot_core_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}"
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}"
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
return planner_prompt_template.format(
time_block=time_block,
by_what=by_what,
chat_context_description=chat_context_description,
@@ -324,10 +328,8 @@ class ActionPlanner:
no_action_block=no_action_block,
action_options_text=action_options_block,
moderation_prompt=moderation_prompt_block,
indentify_block=indentify_block,
identity_block=identity_block,
)
return prompt
except Exception as e:
logger.error(f"构建 Planner 提示词时出错: {e}")
logger.error(traceback.format_exc())

View File

@@ -1,31 +1,31 @@
import traceback
from typing import List, Optional, Dict, Any, Tuple
from src.chat.message_receive.message import MessageRecv, MessageSending
from src.chat.message_receive.message import Seg # Local import needed after move
from src.chat.message_receive.message import UserInfo
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
import asyncio
from src.chat.express.expression_selector import expression_selector
from src.mood.mood_manager import mood_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager
import random
import ast
from src.person_info.person_info import get_person_info_manager
from datetime import datetime
import re
from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageThinking, MessageSending
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.express.expression_selector import expression_selector
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.memory_system.memory_activator import MemoryActivator
from src.mood.mood_manager import mood_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager
from src.person_info.person_info import get_person_info_manager
from src.tools.tool_executor import ToolExecutor
from src.plugin_system.base.component_types import ActionInfo
logger = get_logger("replyer")
@@ -132,25 +132,23 @@ class DefaultReplyer:
# 提取权重,如果模型配置中没有'weight'键则默认为1.0
weights = [config.get("weight", 1.0) for config in configs]
# random.choices 返回一个列表,我们取第一个元素
selected_config = random.choices(population=configs, weights=weights, k=1)[0]
return selected_config
return random.choices(population=configs, weights=weights, k=1)[0]
async def generate_reply_with_context(
self,
reply_data: Dict[str, Any] = None,
reply_data: Optional[Dict[str, Any]] = None,
reply_to: str = "",
extra_info: str = "",
available_actions: List[str] = None,
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = True,
enable_timeout: bool = False,
) -> Tuple[bool, Optional[str]]:
) -> Tuple[bool, Optional[str], Optional[str]]:
"""
回复器 (Replier): 核心逻辑,负责生成回复文本。
(已整合原 HeartFCGenerator 的功能)
"""
if available_actions is None:
available_actions = []
available_actions = {}
if reply_data is None:
reply_data = {}
try:
@@ -202,14 +200,14 @@ class DefaultReplyer:
except Exception as llm_e:
# 精简报错信息
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
return False, None # LLM 调用失败则无法生成回复
return False, None, prompt # LLM 调用失败则无法生成回复
return True, content, prompt
except Exception as e:
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
traceback.print_exc()
return False, None
return False, None, prompt
async def rewrite_reply_with_context(
self,
@@ -289,15 +287,14 @@ class DefaultReplyer:
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID跳过信息提取")
return f"你完全不认识{sender}不理解ta的相关信息。"
relation_info = await relationship_fetcher.build_relation_info(person_id, text, chat_history)
return relation_info
return await relationship_fetcher.build_relation_info(person_id, text, chat_history)
async def build_expression_habits(self, chat_history, target):
if not global_config.expression.enable_expression:
return ""
style_habbits = []
grammar_habbits = []
style_habits = []
grammar_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
@@ -311,22 +308,22 @@ class DefaultReplyer:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_type = expr.get("type", "style")
if expr_type == "grammar":
grammar_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
grammar_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
style_habbits.append(f"{expr['situation']}时,使用 {expr['style']}")
style_habits.append(f"{expr['situation']}时,使用 {expr['style']}")
else:
logger.debug(f"{self.log_prefix} 没有从处理器获得表达方式,将使用空的表达方式")
# 不再在replyer中进行随机选择全部交给处理器处理
style_habbits_str = "\n".join(style_habbits)
grammar_habbits_str = "\n".join(grammar_habbits)
style_habits_str = "\n".join(style_habits)
grammar_habits_str = "\n".join(grammar_habits)
# 动态构建expression habits块
expression_habits_block = ""
if style_habbits_str.strip():
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habbits_str}\n\n"
if grammar_habbits_str.strip():
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habbits_str}\n"
if style_habits_str.strip():
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
if grammar_habits_str.strip():
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n"
return expression_habits_block
@@ -334,21 +331,19 @@ class DefaultReplyer:
if not global_config.memory.enable_memory:
return ""
running_memorys = await self.memory_activator.activate_memory_with_chat_history(
running_memories = await self.memory_activator.activate_memory_with_chat_history(
target_message=target, chat_history_prompt=chat_history
)
if running_memorys:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memorys:
memory_str += f"- {running_memory['content']}\n"
memory_block = memory_str
else:
memory_block = ""
if not running_memories:
return ""
return memory_block
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories:
memory_str += f"- {running_memory['content']}\n"
return memory_str
async def build_tool_info(self, reply_data=None, chat_history=None, enable_tool: bool = True):
async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True):
"""构建工具信息块
Args:
@@ -373,7 +368,7 @@ class DefaultReplyer:
try:
# 使用工具执行器获取信息
tool_results = await self.tool_executor.execute_from_chat_message(
tool_results, _, _ = await self.tool_executor.execute_from_chat_message(
sender=sender, target_message=text, chat_history=chat_history, return_details=False
)
@@ -428,7 +423,7 @@ class DefaultReplyer:
for name, content in result.groupdict().items():
reaction = reaction.replace(f"[{name}]", content)
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
keywords_reaction_prompt += reaction + ""
keywords_reaction_prompt += f"{reaction}"
break
except re.error as e:
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
@@ -438,21 +433,21 @@ class DefaultReplyer:
return keywords_reaction_prompt
async def _time_and_run_task(self, coro, name: str):
async def _time_and_run_task(self, coroutine, name: str):
"""一个简单的帮助函数,用于计时和运行异步任务,返回任务名、结果和耗时"""
start_time = time.time()
result = await coro
result = await coroutine
end_time = time.time()
duration = end_time - start_time
return name, result, duration
async def build_prompt_reply_context(
self,
reply_data=None,
available_actions: List[str] = None,
reply_data: Dict[str, Any],
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_timeout: bool = False,
enable_tool: bool = True,
) -> str:
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
"""
构建回复器上下文
@@ -468,7 +463,7 @@ class DefaultReplyer:
str: 构建好的上下文
"""
if available_actions is None:
available_actions = []
available_actions = {}
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
person_info_manager = get_person_info_manager()
@@ -487,10 +482,9 @@ class DefaultReplyer:
if available_actions:
action_descriptions = "你有以下的动作能力,但执行这些动作不由你决定,由另外一个模型同步决定,因此你只需要知道有如下能力即可:\n"
for action_name, action_info in available_actions.items():
action_description = action_info.get("description", "")
action_description = action_info.description
action_descriptions += f"- {action_name}: {action_description}\n"
action_descriptions += "\n"
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
@@ -506,7 +500,6 @@ class DefaultReplyer:
show_actions=True,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
@@ -531,7 +524,7 @@ class DefaultReplyer:
),
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "build_memory_block"),
self._time_and_run_task(
self.build_tool_info(reply_data, chat_talking_prompt_short, enable_tool=enable_tool), "build_tool_info"
self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "build_tool_info"
),
)
@@ -589,8 +582,8 @@ class DefaultReplyer:
short_impression = ["友好活泼", "人类"]
personality = short_impression[0]
identity = short_impression[1]
prompt_personality = personality + "" + identity
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
prompt_personality = f"{personality}{identity}"
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
moderation_prompt_block = (
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
@@ -637,7 +630,7 @@ class DefaultReplyer:
"chat_target_private2", sender_name=chat_target_name
)
prompt = await global_prompt_manager.format_prompt(
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
chat_target=chat_target_1,
@@ -651,7 +644,7 @@ class DefaultReplyer:
reply_target_block=reply_target_block,
moderation_prompt=moderation_prompt_block,
keywords_reaction_prompt=keywords_reaction_prompt,
identity=indentify_block,
identity=identity_block,
target_message=target,
sender_name=sender,
config_expression_style=global_config.expression.expression_style,
@@ -660,8 +653,6 @@ class DefaultReplyer:
mood_state=mood_prompt,
)
return prompt
async def build_prompt_rewrite_context(
self,
reply_data: Dict[str, Any],
@@ -722,8 +713,8 @@ class DefaultReplyer:
short_impression = ["友好活泼", "人类"]
personality = short_impression[0]
identity = short_impression[1]
prompt_personality = personality + "" + identity
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
prompt_personality = f"{personality}{identity}"
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
moderation_prompt_block = (
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
@@ -767,14 +758,14 @@ class DefaultReplyer:
template_name = "default_expressor_prompt"
prompt = await global_prompt_manager.format_prompt(
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
relation_info_block=relation_info,
chat_target=chat_target_1,
time_block=time_block,
chat_info=chat_talking_prompt_half,
identity=indentify_block,
identity=identity_block,
chat_target_2=chat_target_2,
reply_target_block=reply_target_block,
raw_reply=raw_reply,
@@ -784,8 +775,6 @@ class DefaultReplyer:
moderation_prompt=moderation_prompt_block,
)
return prompt
async def _build_single_sending_message(
self,
message_id: str,
@@ -794,7 +783,7 @@ class DefaultReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
anchor_message: MessageRecv = None,
anchor_message: Optional[MessageRecv] = None,
) -> MessageSending:
"""构建单个发送消息"""
@@ -805,12 +794,9 @@ class DefaultReplyer:
)
# await anchor_message.process()
if anchor_message:
sender_info = anchor_message.message_info.user_info
else:
sender_info = None
sender_info = anchor_message.message_info.user_info if anchor_message else None
bot_message = MessageSending(
return MessageSending(
message_id=message_id, # 使用片段的唯一ID
chat_stream=self.chat_stream,
bot_user_info=bot_user_info,
@@ -823,8 +809,6 @@ class DefaultReplyer:
display_message=display_message,
)
return bot_message
def weighted_sample_no_replacement(items, weights, k) -> list:
"""

View File

@@ -1,14 +1,15 @@
from typing import Dict, Any, Optional, List
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer
from src.common.logger import get_logger
logger = get_logger("ReplyerManager")
class ReplyerManager:
def __init__(self):
self._replyers: Dict[str, DefaultReplyer] = {}
self._repliers: Dict[str, DefaultReplyer] = {}
def get_replyer(
self,
@@ -29,17 +30,16 @@ class ReplyerManager:
return None
# 如果已有缓存实例,直接返回
if stream_id in self._replyers:
if stream_id in self._repliers:
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 返回已存在的回复器实例。")
return self._replyers[stream_id]
return self._repliers[stream_id]
# 如果没有缓存,则创建新实例(首次初始化)
logger.debug(f"[ReplyerManager] 为 stream_id '{stream_id}' 创建新的回复器实例并缓存。")
target_stream = chat_stream
if not target_stream:
chat_manager = get_chat_manager()
if chat_manager:
if chat_manager := get_chat_manager():
target_stream = chat_manager.get_stream(stream_id)
if not target_stream:
@@ -52,7 +52,7 @@ class ReplyerManager:
model_configs=model_configs, # 可以是None此时使用默认模型
request_type=request_type,
)
self._replyers[stream_id] = replyer
self._repliers[stream_id] = replyer
return replyer

View File

@@ -1,14 +1,16 @@
from src.config.config import global_config
from typing import List, Dict, Any, Tuple # 确保类型提示被导入
import time # 导入 time 模块以获取当前时间
import random
import re
from src.common.message_repository import find_messages, count_messages
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable
from typing import List, Dict, Any, Tuple, Optional
from rich.traceback import install
from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable
install(extra_lines=3)
@@ -28,7 +30,12 @@ def get_raw_msg_by_timestamp(
def get_raw_msg_by_timestamp_with_chat(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest", fliter_bot = False
chat_id: str,
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
filter_bot=False,
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@@ -38,11 +45,18 @@ def get_raw_msg_by_timestamp_with_chat(
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot)
return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
)
def get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest", fliter_bot = False
chat_id: str,
timestamp_start: float,
timestamp_end: float,
limit: int = 0,
limit_mode: str = "latest",
filter_bot=False,
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@@ -52,8 +66,10 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot)
return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
)
def get_raw_msg_by_timestamp_with_chat_users(
@@ -88,8 +104,8 @@ def get_actions_by_timestamp_with_chat(
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time > timestamp_start)
& (ActionRecords.time < timestamp_end)
& (ActionRecords.time > timestamp_start) # type: ignore
& (ActionRecords.time < timestamp_end) # type: ignore
)
if limit > 0:
@@ -113,8 +129,8 @@ def get_actions_by_timestamp_with_chat_inclusive(
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time >= timestamp_start)
& (ActionRecords.time <= timestamp_end)
& (ActionRecords.time >= timestamp_start) # type: ignore
& (ActionRecords.time <= timestamp_end) # type: ignore
)
if limit > 0:
@@ -190,7 +206,7 @@ def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list,
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float = None) -> int:
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -227,7 +243,7 @@ def _build_readable_messages_internal(
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
pic_id_mapping: Dict[str, str] = None,
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1,
show_pic: bool = True,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
@@ -249,7 +265,7 @@ def _build_readable_messages_internal(
if not messages:
return "", [], pic_id_mapping or {}, pic_counter
message_details_raw: List[Tuple[float, str, str]] = []
message_details_raw: List[Tuple[float, str, str, bool]] = []
# 使用传入的映射字典,如果没有则创建新的
if pic_id_mapping is None:
@@ -280,7 +296,7 @@ def _build_readable_messages_internal(
# 检查是否是动作记录
if msg.get("is_action_record", False):
is_action = True
timestamp = msg.get("time")
timestamp: float = msg.get("time") # type: ignore
content = msg.get("display_message", "")
# 对于动作记录也处理图片ID
content = process_pic_ids(content)
@@ -304,9 +320,10 @@ def _build_readable_messages_internal(
user_nickname = user_info.get("user_nickname")
user_cardname = user_info.get("user_cardname")
timestamp = msg.get("time")
timestamp: float = msg.get("time") # type: ignore
content: str
if msg.get("display_message"):
content = msg.get("display_message")
content = msg.get("display_message", "")
else:
content = msg.get("processed_plain_text", "") # 默认空字符串
@@ -326,10 +343,11 @@ def _build_readable_messages_internal(
person_id = PersonInfoManager.get_person_id(platform, user_id)
person_info_manager = get_person_info_manager()
# 根据 replace_bot_name 参数决定是否替换机器人名称
person_name: str
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = person_info_manager.get_value_sync(person_id, "person_name")
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
@@ -344,12 +362,10 @@ def _build_readable_messages_internal(
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, content)
if match:
aaa = match.group(1)
bbb = match.group(2)
aaa: str = match[1]
bbb: str = match[2]
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
if not reply_person_name:
reply_person_name = aaa
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa
# 在内容前加上回复信息
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
@@ -364,18 +380,15 @@ def _build_readable_messages_internal(
aaa = m.group(1)
bbb = m.group(2)
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
if not at_person_name:
at_person_name = aaa
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa
new_content += f"@{at_person_name}"
last_end = m.end()
new_content += content[last_end:]
content = new_content
target_str = "这是QQ的一个功能用于提及某人但没那么明显"
if target_str in content:
if random.random() < 0.6:
content = content.replace(target_str, "")
if target_str in content and random.random() < 0.6:
content = content.replace(target_str, "")
if content != "":
message_details_raw.append((timestamp, person_name, content, False))
@@ -525,6 +538,7 @@ def _build_readable_messages_internal(
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress
"""
构建图片映射信息字符串,显示图片的具体描述内容
@@ -583,8 +597,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
action_name = action.get("action_name", "未知动作")
if action_name == "no_action" or action_name == "no_reply":
continue
action_prompt_display = action.get("action_prompt_display", "无具体内容")
time_diff_seconds = current_time - action_time
@@ -616,9 +629,7 @@ async def build_readable_messages_with_list(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
)
# 生成图片映射信息并添加到最前面
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info:
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, details_list
@@ -633,7 +644,7 @@ def build_readable_messages(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
) -> str:
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。
如果提供了 read_mark则在相应位置插入已读标记。
@@ -756,9 +767,7 @@ def build_readable_messages(
# 组合结果
result_parts = []
if pic_mapping_info:
result_parts.append(pic_mapping_info)
result_parts.append("\n")
result_parts.extend((pic_mapping_info, "\n"))
if formatted_before and formatted_after:
result_parts.extend([formatted_before, read_mark_line, formatted_after])
elif formatted_before:
@@ -831,8 +840,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
platform = msg.get("chat_info_platform")
user_id = msg.get("user_id")
_timestamp = msg.get("time")
content: str = ""
if msg.get("display_message"):
content = msg.get("display_message")
content = msg.get("display_message", "")
else:
content = msg.get("processed_plain_text", "")
@@ -920,17 +930,14 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set = set() # 使用集合来自动去重
for msg in messages:
platform = msg.get("user_platform")
user_id = msg.get("user_id")
platform: str = msg.get("user_platform") # type: ignore
user_id: str = msg.get("user_id") # type: ignore
# 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
person_id = PersonInfoManager.get_person_id(platform, user_id)
# 只有当获取到有效 person_id 时才添加
if person_id:
if person_id := PersonInfoManager.get_person_id(platform, user_id):
person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回

View File

@@ -1,7 +1,8 @@
import ast
import json
import logging
from typing import Any, Dict, TypeVar, List, Union, Tuple
import ast
from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
# 定义类型变量用于泛型类型提示
T = TypeVar("T")
@@ -30,18 +31,14 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
# 尝试标准的 JSON 解析
return json.loads(json_str)
except json.JSONDecodeError:
# 如果标准解析失败,尝试将单引号替换为双引号再解析
# (注意:这种替换可能不安全,如果字符串内容本身包含引号)
# 更安全的方式是用 ast.literal_eval
# 如果标准解析失败,尝试用 ast.literal_eval 解析
try:
# logger.debug(f"标准JSON解析失败尝试用 ast.literal_eval 解析: {json_str[:100]}...")
result = ast.literal_eval(json_str)
# 确保结果是字典(因为我们通常期望参数是字典)
if isinstance(result, dict):
return result
else:
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
return default_value
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
return default_value
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
return default_value
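A self-contained sketch of the fallback strategy described above: strict JSON first, then ast.literal_eval for Python-literal style strings (single quotes, True/None), returning the caller's default when the parsed value is not a dict:

import ast
import json
from typing import Any, Optional

def parse_args_loosely(raw: str, default: Optional[dict] = None) -> Any:
    default = default if default is not None else {}
    try:
        return json.loads(raw)  # strict JSON first
    except json.JSONDecodeError:
        try:
            value = ast.literal_eval(raw)  # tolerates single quotes, True/False/None
        except (ValueError, SyntaxError, MemoryError, RecursionError):
            return default
        return value if isinstance(value, dict) else default

# parse_args_loosely('{"q": "hi"}')  -> {'q': 'hi'}
# parse_args_loosely("{'q': 'hi'}")  -> {'q': 'hi'} via ast.literal_eval
# parse_args_loosely("(1, 2)")       -> {} (non-dict literal falls back to the default)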
@@ -53,7 +50,9 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
return default_value
def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
def extract_tool_call_arguments(
tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
从LLM工具调用对象中提取参数
@@ -77,14 +76,12 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[s
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
return default_result
# 提取arguments
arguments_str = function_data.get("arguments", "{}")
if not arguments_str:
if arguments_str := function_data.get("arguments", "{}"):
# 解析JSON
return safe_json_loads(arguments_str, default_result)
else:
return default_result
# 解析JSON
return safe_json_loads(arguments_str, default_result)
except Exception as e:
logger.error(f"提取工具调用参数时出错: {e}")
return default_result
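A hedged usage sketch of the flow above, built on an OpenAI-style tool_call dict; the nested shape ("function" carrying "arguments" as a JSON string) is the common convention and matches the fields the function reads:

tool_call = {
    "id": "call_0",
    "type": "function",
    "function": {
        "name": "search_web",
        "arguments": '{"query": "天气", "limit": 3}',
    },
}

args = extract_tool_call_arguments(tool_call, default_value={})
# -> {"query": "天气", "limit": 3}; an empty or malformed arguments string
#    yields the supplied default instead.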

View File

@@ -1,12 +1,12 @@
from typing import Dict, Any, Optional, List, Union
import re
from contextlib import asynccontextmanager
import asyncio
import contextvars
from src.common.logger import get_logger
# import traceback
from rich.traceback import install
from contextlib import asynccontextmanager
from typing import Dict, Any, Optional, List, Union
from src.common.logger import get_logger
install(extra_lines=3)
@@ -32,6 +32,7 @@ class PromptContext:
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
# sourcery skip: hoist-statement-from-if, use-contextlib-suppress
"""创建一个异步的临时提示模板作用域"""
# 保存当前上下文并设置新上下文
if context_id is not None:
@@ -88,8 +89,7 @@ class PromptContext:
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
target_context = context_id or self._current_context
if target_context:
if target_context := context_id or self._current_context:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
@@ -151,7 +151,7 @@ class Prompt(str):
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号,将 \{\} 替换为临时标记"""
"""处理模板中的转义花括号,将 \{\} 替换为临时标记""" # type: ignore
# 如果传入的是列表,将其转换为字符串
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
@@ -195,14 +195,8 @@ class Prompt(str):
obj._kwargs = kwargs
# 修改自动注册逻辑
if should_register:
if global_prompt_manager._context._current_context:
# 如果存在当前上下文,则注册到上下文中
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
pass
else:
# 否则注册到全局管理器
global_prompt_manager.register(obj)
if should_register and not global_prompt_manager._context._current_context:
global_prompt_manager.register(obj)
return obj
@classmethod
@@ -276,15 +270,13 @@ class Prompt(str):
self.name,
args=list(args) if args else self._args,
_should_register=False,
**kwargs if kwargs else self._kwargs,
**kwargs or self._kwargs,
)
# print(f"prompt build result: {ret} name: {ret.name} ")
return str(ret)
def __str__(self) -> str:
if self._kwargs or self._args:
return super().__str__()
return self.template
return super().__str__() if self._kwargs or self._args else self.template
def __repr__(self) -> str:
return f"Prompt(template='{self.template}', name='{self.name}')"

View File

@@ -1,18 +1,17 @@
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
import asyncio
import concurrent.futures
import json
import os
import glob
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
from src.manager.async_task_manager import AsyncTask
from ...common.database.database import db # This db is the Peewee database instance
from ...common.database.database_model import OnlineTime, LLMUsage, Messages # Import the Peewee model
from src.manager.local_store_manager import local_storage
logger = get_logger("maibot_statistic")
@@ -76,14 +75,14 @@ class OnlineTimeRecordTask(AsyncTask):
with db.atomic(): # Use atomic operations for schema changes
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
async def run(self):
async def run(self): # sourcery skip: use-named-expression
try:
current_time = datetime.now()
extended_end_time = current_time + timedelta(minutes=1)
if self.record_id:
# 如果有记录,则更新结束时间
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id)
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
updated_rows = query.execute()
if updated_rows == 0:
# Record might have been deleted or ID is stale, try to find/create
@@ -94,7 +93,7 @@ class OnlineTimeRecordTask(AsyncTask):
# Look for a record whose end_timestamp is recent enough to be considered ongoing
recent_record = (
OnlineTime.select()
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1)))
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
.order_by(OnlineTime.end_timestamp.desc())
.first()
)
@@ -123,15 +122,15 @@ def _format_online_time(online_seconds: int) -> str:
:param online_seconds: 在线时间(秒)
:return: 格式化后的在线时间字符串
"""
total_oneline_time = timedelta(seconds=online_seconds)
total_online_time = timedelta(seconds=online_seconds)
days = total_oneline_time.days
hours = total_oneline_time.seconds // 3600
minutes = (total_oneline_time.seconds // 60) % 60
seconds = total_oneline_time.seconds % 60
days = total_online_time.days
hours = total_online_time.seconds // 3600
minutes = (total_online_time.seconds // 60) % 60
seconds = total_online_time.seconds % 60
if days > 0:
# 如果在线时间超过1天则格式化为"X天X小时X分钟X秒"
return f"{total_oneline_time.days}{hours}小时{minutes}分钟{seconds}"
return f"{total_online_time.days}{hours}小时{minutes}分钟{seconds}"
elif hours > 0:
# 如果在线时间超过1小时则格式化为"X小时X分钟X秒"
return f"{hours}小时{minutes}分钟{seconds}"
@@ -163,7 +162,7 @@ class StatisticOutputTask(AsyncTask):
now = datetime.now()
if "deploy_time" in local_storage:
# 如果存在部署时间,则使用该时间作为全量统计的起始时间
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"])
deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
else:
# 否则,使用最大时间范围,并记录部署时间为当前时间
deploy_time = datetime(2000, 1, 1)
@@ -252,7 +251,7 @@ class StatisticOutputTask(AsyncTask):
# 创建后台任务,不等待完成
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now)
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
@@ -260,8 +259,8 @@ class StatisticOutputTask(AsyncTask):
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成
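A minimal sketch of the offloading pattern used in this task: blocking collection and rendering run in a ThreadPoolExecutor while the event loop stays responsive; the function names are placeholders, not the real methods:

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def run_report_cycle(collect, render_console, render_html, now):
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=2) as executor:
        # Blocking statistics collection runs off the event loop.
        stats = await loop.run_in_executor(executor, collect, now)
        # Both outputs render concurrently from the same snapshot.
        await asyncio.gather(
            loop.run_in_executor(executor, render_console, stats, now),
            loop.run_in_executor(executor, render_html, stats, now),
        )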
@@ -320,7 +319,7 @@ class StatisticOutputTask(AsyncTask):
# 以最早的时间戳为起始时间获取记录
# Assuming LLMUsage.timestamp is a DateTimeField
query_start_time = collect_period[-1][1]
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_timestamp = record.timestamp # This is already a datetime object
for idx, (_, period_start) in enumerate(collect_period):
if record_timestamp >= period_start:
@@ -388,7 +387,7 @@ class StatisticOutputTask(AsyncTask):
query_start_time = collect_period[-1][1]
# Assuming OnlineTime.end_timestamp is a DateTimeField
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time):
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
# record.end_timestamp and record.start_timestamp are datetime objects
record_end_timestamp = record.end_timestamp
record_start_timestamp = record.start_timestamp
@@ -428,7 +427,7 @@ class StatisticOutputTask(AsyncTask):
}
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp):
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time # This is a float timestamp
chat_id = None
@@ -661,7 +660,7 @@ class StatisticOutputTask(AsyncTask):
if "last_full_statistics" in local_storage:
# 如果存在上次完整统计数据,则使用该数据进行增量统计
last_stat = local_storage["last_full_statistics"] # 上次完整统计数据
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
@@ -727,6 +726,7 @@ class StatisticOutputTask(AsyncTask):
return stat
def _convert_defaultdict_to_dict(self, data):
# sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks
"""递归转换defaultdict为普通dict"""
if isinstance(data, defaultdict):
# 转换defaultdict为普通dict
@@ -812,8 +812,7 @@ class StatisticOutputTask(AsyncTask):
# 全局阶段平均时间
if stats[FOCUS_AVG_TIMES_BY_STAGE]:
output.append("全局阶段平均时间:")
for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items():
output.append(f" {stage}: {avg_time:.3f}")
output.extend(f" {stage}: {avg_time:.3f}" for stage, avg_time in stats[FOCUS_AVG_TIMES_BY_STAGE].items())
output.append("")
# Action类型比例
@@ -1050,7 +1049,7 @@ class StatisticOutputTask(AsyncTask):
]
tab_content_list.append(
_format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"]))
_format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) # type: ignore
)
# 添加Focus统计内容
@@ -1212,6 +1211,7 @@ class StatisticOutputTask(AsyncTask):
f.write(html_template)
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: for-append-to-extend, list-comprehension, use-any
"""生成Focus统计独立分页的HTML内容"""
# 为每个时间段准备Focus数据
@@ -1313,12 +1313,11 @@ class StatisticOutputTask(AsyncTask):
# 聊天流Action选择比例对比表横向表格
focus_chat_action_ratios_rows = ""
if stat_data.get("focus_action_ratios_by_chat"):
# 获取所有action类型按全局频率排序
all_action_types_for_ratio = sorted(
stat_data[FOCUS_ACTION_RATIOS].keys(), key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x], reverse=True
)
if all_action_types_for_ratio:
if all_action_types_for_ratio := sorted(
stat_data[FOCUS_ACTION_RATIOS].keys(),
key=lambda x: stat_data[FOCUS_ACTION_RATIOS][x],
reverse=True,
):
# 为每个聊天流生成数据行(按循环数排序)
chat_ratio_rows = []
for chat_id in sorted(
@@ -1379,16 +1378,11 @@ class StatisticOutputTask(AsyncTask):
if period_name == "all_time":
from src.manager.local_store_manager import local_storage
start_time = datetime.fromtimestamp(local_storage["deploy_time"])
time_range = (
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
else:
start_time = datetime.now() - period_delta
time_range = (
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 生成该时间段的Focus统计HTML
section_html = f"""
<div class="focus-period-section">
@@ -1681,16 +1675,10 @@ class StatisticOutputTask(AsyncTask):
if period_name == "all_time":
from src.manager.local_store_manager import local_storage
start_time = datetime.fromtimestamp(local_storage["deploy_time"])
time_range = (
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
start_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore
else:
start_time = datetime.now() - period_delta
time_range = (
f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
time_range = f"{start_time.strftime('%Y-%m-%d %H:%M:%S')} ~ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 生成该时间段的版本对比HTML
section_html = f"""
<div class="version-period-section">
@@ -1865,7 +1853,7 @@ class StatisticOutputTask(AsyncTask):
# 查询LLM使用记录
query_start_time = start_time
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time):
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
record_time = record.timestamp
# 找到对应的时间间隔索引
@@ -1875,7 +1863,7 @@ class StatisticOutputTask(AsyncTask):
if 0 <= interval_index < len(time_points):
# 累加总花费数据
cost = record.cost or 0.0
total_cost_data[interval_index] += cost
total_cost_data[interval_index] += cost # type: ignore
# 累加按模型分类的花费
model_name = record.model_name or "unknown"
@@ -1892,7 +1880,7 @@ class StatisticOutputTask(AsyncTask):
# 查询消息记录
query_start_timestamp = start_time.timestamp()
for message in Messages.select().where(Messages.time >= query_start_timestamp):
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time
# 找到对应的时间间隔索引
@@ -1982,6 +1970,7 @@ class StatisticOutputTask(AsyncTask):
}
def _generate_chart_tab(self, chart_data: dict) -> str:
# sourcery skip: extract-duplicate-method, move-assign-in-block
"""生成图表选项卡HTML内容"""
# 生成不同颜色的调色板
@@ -2293,7 +2282,7 @@ class AsyncStatisticOutputTask(AsyncTask):
# 数据收集任务
collect_task = asyncio.create_task(
loop.run_in_executor(executor, self._collect_all_statistics, now)
loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore
)
stats = await collect_task
@@ -2301,8 +2290,8 @@ class AsyncStatisticOutputTask(AsyncTask):
# 创建并发的输出任务
output_tasks = [
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)),
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)),
asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore
asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore
]
# 等待所有输出任务完成

View File

@@ -1,7 +1,8 @@
import asyncio
from time import perf_counter
from functools import wraps
from typing import Optional, Dict, Callable
import asyncio
from rich.traceback import install
install(extra_lines=3)
@@ -88,10 +89,10 @@ class Timer:
self.name = name
self.storage = storage
self.elapsed = None
self.elapsed: float = None # type: ignore
self.auto_unit = auto_unit
self.start = None
self.start: float = None # type: ignore
@staticmethod
def _validate_types(name, storage):
@@ -120,7 +121,7 @@ class Timer:
return None
wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
wrapper.__timer__ = self # 保留计时器引用
wrapper.__timer__ = self # 保留计时器引用 # type: ignore
return wrapper
def __enter__(self):

View File

@@ -7,10 +7,10 @@ import math
import os
import random
import time
import jieba
from collections import defaultdict
from pathlib import Path
import jieba
from pypinyin import Style, pinyin
from src.common.logger import get_logger
@@ -104,7 +104,7 @@ class ChineseTypoGenerator:
try:
return "\u4e00" <= char <= "\u9fff"
except Exception as e:
logger.debug(e)
logger.debug(str(e))
return False
def _get_pinyin(self, sentence):
@@ -138,7 +138,7 @@ class ChineseTypoGenerator:
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1
return py + "1"
return f"{py}1"
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
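For context, a small sketch of the tone handling above using pypinyin's TONE3 style, which appends the tone digit; neutral-tone syllables come back without one and get a trailing "1", matching the normalization in this hunk:

from pypinyin import Style, pinyin

def toned_pinyin(text: str) -> list[str]:
    # TONE3 renders "你好" as ["ni3", "hao3"]; neutral tones (e.g. "吗") carry no digit.
    result = []
    for (py,) in pinyin(text, style=Style.TONE3):
        result.append(py if py[-1].isdigit() else f"{py}1")
    return result

# toned_pinyin("你好吗") -> ['ni3', 'hao3', 'ma1']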

View File

@@ -1,23 +1,21 @@
import random
import re
import time
from collections import Counter
import jieba
import numpy as np
from collections import Counter
from maim_message import UserInfo
from typing import Optional, Tuple, Dict
from src.common.logger import get_logger
# from src.mood.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
from ...config.config import global_config
from ...common.message_repository import find_messages, count_messages
from typing import Optional, Tuple, Dict
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from .typo_generator import ChineseTypoGenerator
logger = get_logger("chat_utils")
@@ -31,11 +29,7 @@ def db_message_to_str(message_dict: dict) -> str:
logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try:
name = "[(%s)%s]%s" % (
message_dict["user_id"],
message_dict.get("user_nickname", ""),
message_dict.get("user_cardname", ""),
)
name = f"[({message_dict['user_id']}){message_dict.get('user_nickname', '')}]{message_dict.get('user_cardname', '')}"
except Exception:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "")
@@ -58,11 +52,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
and message.message_info.additional_config.get("is_mentioned") is not None
):
try:
reply_probability = float(message.message_info.additional_config.get("is_mentioned"))
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True
return is_mentioned, reply_probability
except Exception as e:
logger.warning(e)
logger.warning(str(e))
logger.warning(
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
)
@@ -135,20 +129,17 @@ def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, c
if not recent_messages:
return []
message_detailed_plain_text = ""
message_detailed_plain_text_list = []
# 反转消息列表,使最新的消息在最后
recent_messages.reverse()
if combine:
for msg_db_data in recent_messages:
message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text
else:
for msg_db_data in recent_messages:
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text_list
return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages)
message_detailed_plain_text_list = []
for msg_db_data in recent_messages:
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text_list
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
@@ -204,10 +195,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
len_text = len(text)
if len_text < 3:
if random.random() < 0.01:
return list(text) # 如果文本很短且触发随机条件,直接按字符分割
else:
return [text]
return list(text) if random.random() < 0.01 else [text]
# 定义分隔符
separators = {"", ",", " ", "", ";"}
@@ -352,10 +340,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
max_length = global_config.response_splitter.max_length * 2
max_sentence_num = global_config.response_splitter.max_sentence_num
# 如果基本上是中文,则进行长度过滤
if get_western_ratio(cleaned_text) < 0.1:
if len(cleaned_text) > max_length:
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
return ["懒得说"]
if get_western_ratio(cleaned_text) < 0.1 and len(cleaned_text) > max_length:
logger.warning(f"回复过长 ({len(cleaned_text)} 字符),返回默认回复")
return ["懒得说"]
typo_generator = ChineseTypoGenerator(
error_rate=global_config.chinese_typo.error_rate,
@@ -420,7 +407,7 @@ def calculate_typing_time(
# chinese_time *= 1 / typing_speed_multiplier
# english_time *= 1 / typing_speed_multiplier
# 计算中文字符数
chinese_chars = sum(1 for char in input_string if "\u4e00" <= char <= "\u9fff")
chinese_chars = sum("\u4e00" <= char <= "\u9fff" for char in input_string)
# 如果只有一个中文字符使用3倍时间
if chinese_chars == 1 and len(input_string.strip()) == 1:
@@ -429,11 +416,7 @@ def calculate_typing_time(
# 正常计算所有字符的输入时间
total_time = 0.0
for char in input_string:
if "\u4e00" <= char <= "\u9fff": # 判断是否为中文字符
total_time += chinese_time
else: # 其他字符(如英文)
total_time += english_time
total_time += chinese_time if "\u4e00" <= char <= "\u9fff" else english_time
if is_emoji:
total_time = 1
@@ -453,18 +436,14 @@ def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
if norm1 == 0 or norm2 == 0:
return 0
return dot_product / (norm1 * norm2)
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
def text_to_vector(text):
"""将文本转换为词频向量"""
# 分词
words = jieba.lcut(text)
# 统计词频
word_freq = Counter(words)
return word_freq
return Counter(words)
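A hedged sketch of how cosine_similarity and text_to_vector above can be combined on word-frequency Counters; the exact pairing inside find_similar_topics_simple is not shown in this hunk, so this is illustrative only:

import math
from collections import Counter

import jieba

def counter_cosine(a: Counter, b: Counter) -> float:
    # Cosine similarity over the shared words; 0.0 when either vector is empty.
    dot = sum(a[w] * b[w] for w in set(a) & set(b))
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return 0.0 if norm_a == 0 or norm_b == 0 else dot / (norm_a * norm_b)

v1 = Counter(jieba.lcut("今天天气很好"))
v2 = Counter(jieba.lcut("今天天气不错"))
score = counter_cosine(v1, v2)  # > 0, since both texts share 今天/天气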
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
@@ -491,9 +470,7 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
def truncate_message(message: str, max_length=20) -> str:
"""截断消息,使其不超过指定长度"""
if len(message) > max_length:
return message[:max_length] + "..."
return message
return f"{message[:max_length]}..." if len(message) > max_length else message
def protect_kaomoji(sentence):
@@ -522,7 +499,7 @@ def protect_kaomoji(sentence):
placeholder_to_kaomoji = {}
for idx, match in enumerate(kaomoji_matches):
kaomoji = match[0] if match[0] else match[1]
kaomoji = match[0] or match[1]
placeholder = f"__KAOMOJI_{idx}__"
sentence = sentence.replace(kaomoji, placeholder, 1)
placeholder_to_kaomoji[placeholder] = kaomoji
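A reduced sketch of the protect-and-restore round trip; the real kaomoji regex is far broader, so a toy pattern stands in here:

import re

KAOMOJI_TOY_PATTERN = re.compile(r"\([^()]{1,12}\)")  # toy stand-in for the real pattern

def protect(sentence: str) -> tuple[str, dict]:
    mapping = {}
    for idx, kaomoji in enumerate(KAOMOJI_TOY_PATTERN.findall(sentence)):
        placeholder = f"__KAOMOJI_{idx}__"
        sentence = sentence.replace(kaomoji, placeholder, 1)
        mapping[placeholder] = kaomoji
    return sentence, mapping

def restore(sentence: str, mapping: dict) -> str:
    for placeholder, kaomoji in mapping.items():
        sentence = sentence.replace(placeholder, kaomoji)
    return sentence

text, saved = protect("好耶 (｡♥‿♥｡) 出发")
# text == "好耶 __KAOMOJI_0__ 出发", and restore(text, saved) round-trips it.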
@@ -563,7 +540,7 @@ def get_western_ratio(paragraph):
if not alnum_chars:
return 0.0
western_count = sum(1 for char in alnum_chars if is_english_letter(char))
western_count = sum(bool(is_english_letter(char)) for char in alnum_chars)
return western_count / len(alnum_chars)
@@ -610,6 +587,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
"""将时间戳转换为人类可读的时间格式
Args:
@@ -621,7 +599,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
"""
if mode == "normal":
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
if mode == "normal_no_YMD":
elif mode == "normal_no_YMD":
return time.strftime("%H:%M:%S", time.localtime(timestamp))
elif mode == "relative":
now = time.time()
@@ -640,7 +618,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
else:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
else: # mode = "lite" or unknown
# 只返回时分秒格式,喵~
# 只返回时分秒格式
return time.strftime("%H:%M:%S", time.localtime(timestamp))
@@ -670,8 +648,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
elif chat_stream.user_info: # It's a private chat
is_group_chat = False
user_info = chat_stream.user_info
platform = chat_stream.platform
user_id = user_info.user_id
platform: str = chat_stream.platform # type: ignore
user_id: str = user_info.user_id # type: ignore
# Initialize target_info with basic info
target_info = {

View File

@@ -3,21 +3,20 @@ import os
import time
import hashlib
import uuid
import io
import asyncio
import numpy as np
from typing import Optional, Tuple
from PIL import Image
import io
import numpy as np
import asyncio
from rich.traceback import install
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_image")
@@ -103,7 +102,7 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "emoji")
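The caching key derivation in isolation: the md5 of the raw bytes is the stable lookup key and Pillow reports the container format; the database lookup itself is elided here:

import base64
import hashlib
import io

from PIL import Image

def image_cache_key(image_base64: str) -> tuple[str, str]:
    # Strip non-ASCII noise before decoding, as in the hunk above.
    image_bytes = base64.b64decode(image_base64.encode("ascii", errors="ignore"))
    image_hash = hashlib.md5(image_bytes).hexdigest()
    image_format = (Image.open(io.BytesIO(image_bytes)).format or "").lower()
    return image_hash, image_format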
@@ -111,7 +110,7 @@ class ImageManager:
return f"[表情包,含义看起来是:{cached_description}]"
# 调用AI获取描述
if image_format == "gif" or image_format == "GIF":
if image_format in ["gif", "GIF"]:
image_base64_processed = self.transform_gif(image_base64)
if image_base64_processed is None:
logger.warning("GIF转换失败无法获取描述")
@@ -154,7 +153,7 @@ class ImageManager:
img_obj.description = description
img_obj.timestamp = current_timestamp
img_obj.save()
except Images.DoesNotExist:
except Images.DoesNotExist: # type: ignore
Images.create(
emoji_hash=image_hash,
path=file_path,
@@ -204,7 +203,7 @@ class ImageManager:
return f"[图片:{cached_description}]"
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来请留意其主题直观感受输出为一段平文本最多50字"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
@@ -258,6 +257,7 @@ class ImageManager:
@staticmethod
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
# sourcery skip: use-contextlib-suppress
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
Args:
@@ -351,7 +351,7 @@ class ImageManager:
# 创建拼接图像
total_width = target_width * len(resized_frames)
# 防止总宽度为0
if total_width == 0 and len(resized_frames) > 0:
if total_width == 0 and resized_frames:
logger.warning("计算出的总宽度为0但有选中帧可能目标宽度太小")
# 至少给点宽度吧
total_width = len(resized_frames)
@@ -368,10 +368,7 @@ class ImageManager:
# 转换为base64
buffer = io.BytesIO()
combined_image.save(buffer, format="JPEG", quality=85) # 保存为JPEG
result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return result_base64
return base64.b64encode(buffer.getvalue()).decode("utf-8")
except MemoryError:
logger.error("GIF转换失败: 内存不足可能是GIF太大或帧数太多")
return None # 内存不够啦
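A condensed sketch of the frame-dedup-and-concatenate idea behind transform_gif, assuming mean absolute pixel difference as the similarity measure (the production threshold of 1000.0 uses its own scale, so the cutoff below is illustrative):

import base64
import io

import numpy as np
from PIL import Image, ImageSequence

def gif_to_strip(gif_base64: str, mean_diff_cutoff: float = 10.0,
                 max_frames: int = 15, target_height: int = 200) -> str:
    gif = Image.open(io.BytesIO(base64.b64decode(gif_base64)))
    kept, last = [], None
    for frame in ImageSequence.Iterator(gif):
        rgb = frame.convert("RGB")
        arr = np.asarray(rgb, dtype=np.float32)
        # Keep a frame only if it differs enough from the last kept one.
        if last is None or np.abs(arr - last).mean() > mean_diff_cutoff:
            kept.append(rgb)
            last = arr
        if len(kept) >= max_frames:
            break
    # Resize to a common height and paste the frames side by side.
    resized = [f.resize((max(1, f.width * target_height // f.height), target_height)) for f in kept]
    strip = Image.new("RGB", (sum(f.width for f in resized), target_height))
    x = 0
    for f in resized:
        strip.paste(f, (x, 0))
        x += f.width
    buffer = io.BytesIO()
    strip.save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")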
@@ -380,6 +377,7 @@ class ImageManager:
return None # 其他错误也返回None
async def process_image(self, image_base64: str) -> Tuple[str, str]:
# sourcery skip: hoist-if-from-if
"""处理图片并返回图片ID和描述
Args:
@@ -418,17 +416,9 @@ class ImageManager:
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片已存在: {existing_image.image_id}")
# print(f"图片描述: {existing_image.description}")
# print(f"图片计数: {existing_image.count}")
# 更新计数
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
@@ -491,7 +481,7 @@ class ImageManager:
return
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 构建prompt
prompt = """请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本"""

View File

@@ -35,9 +35,7 @@ class ClassicalWillingManager(BaseWillingManager):
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
return reply_probability
return min(max((current_willing - 0.5), 0.01) * 2, 1)
async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id

View File

@@ -1,14 +1,16 @@
from src.common.logger import get_logger
import importlib
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Optional
from rich.traceback import install
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from abc import ABC, abstractmethod
import importlib
from typing import Dict, Optional
import asyncio
from rich.traceback import install
install(extra_lines=3)
@@ -92,8 +94,8 @@ class BaseWillingManager(ABC):
self.logger = logger
def setup(self, message: dict, chat: ChatStream):
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
self.ongoing_messages[message.get("message_id")] = WillingInfo(
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore
self.ongoing_messages[message.message_info.message_id] = WillingInfo( # type: ignore
message=message,
chat=chat,
person_info_manager=get_person_info_manager(),