refactor of focus_chat
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from collections import deque
|
import random
|
||||||
from typing import List, Optional, Dict, Any, Deque, Callable, Awaitable
|
from typing import List, Optional, Dict, Any
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -10,19 +10,18 @@ from src.common.logger import get_logger
|
|||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
||||||
from src.chat.planner_actions.planner import ActionPlanner
|
from src.chat.planner_actions.planner import ActionPlanner
|
||||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.chat.planner_actions.action_manager import ActionManager
|
||||||
from src.chat.focus_chat.hfc_utils import CycleDetail
|
from src.chat.focus_chat.hfc_utils import CycleDetail
|
||||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
|
||||||
from src.plugin_system.base.component_types import ChatMode
|
|
||||||
import random
|
|
||||||
from src.chat.focus_chat.hfc_utils import get_recent_message_stats
|
from src.chat.focus_chat.hfc_utils import get_recent_message_stats
|
||||||
|
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
|
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||||
from src.plugin_system.apis import generator_api, send_api, message_api
|
from src.plugin_system.apis import generator_api, send_api, message_api
|
||||||
from src.chat.willing.willing_manager import get_willing_manager
|
from src.chat.willing.willing_manager import get_willing_manager
|
||||||
from .priority_manager import PriorityManager
|
from .priority_manager import PriorityManager
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
|
||||||
|
|
||||||
|
|
||||||
ERROR_LOOP_INFO = {
|
ERROR_LOOP_INFO = {
|
||||||
@@ -107,7 +106,7 @@ class HeartFChatting:
|
|||||||
# 添加循环信息管理相关的属性
|
# 添加循环信息管理相关的属性
|
||||||
self.history_loop: List[CycleDetail] = []
|
self.history_loop: List[CycleDetail] = []
|
||||||
self._cycle_counter = 0
|
self._cycle_counter = 0
|
||||||
self._current_cycle_detail: Optional[CycleDetail] = None
|
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||||
|
|
||||||
self.reply_timeout_count = 0
|
self.reply_timeout_count = 0
|
||||||
self.plan_timeout_count = 0
|
self.plan_timeout_count = 0
|
||||||
@@ -169,7 +168,7 @@ class HeartFChatting:
|
|||||||
def start_cycle(self):
|
def start_cycle(self):
|
||||||
self._cycle_counter += 1
|
self._cycle_counter += 1
|
||||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||||
self._current_cycle_detail.thinking_id = "tid" + str(round(time.time(), 2))
|
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||||
cycle_timers = {}
|
cycle_timers = {}
|
||||||
return cycle_timers, self._current_cycle_detail.thinking_id
|
return cycle_timers, self._current_cycle_detail.thinking_id
|
||||||
|
|
||||||
@@ -230,13 +229,15 @@ class HeartFChatting:
|
|||||||
async def build_reply_to_str(self, message_data: dict):
|
async def build_reply_to_str(self, message_data: dict):
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_id = person_info_manager.get_person_id(
|
person_id = person_info_manager.get_person_id(
|
||||||
message_data.get("chat_info_platform"), message_data.get("user_id")
|
message_data.get("chat_info_platform"), # type: ignore
|
||||||
|
message_data.get("user_id"), # type: ignore
|
||||||
)
|
)
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
reply_to_str = f"{person_name}:{message_data.get('processed_plain_text')}"
|
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||||
return reply_to_str
|
|
||||||
|
|
||||||
async def _observe(self, message_data: dict = None):
|
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
|
||||||
|
if not message_data:
|
||||||
|
message_data = {}
|
||||||
# 创建新的循环信息
|
# 创建新的循环信息
|
||||||
cycle_timers, thinking_id = self.start_cycle()
|
cycle_timers, thinking_id = self.start_cycle()
|
||||||
|
|
||||||
@@ -339,7 +340,7 @@ class HeartFChatting:
|
|||||||
self.print_cycle_info(cycle_timers)
|
self.print_cycle_info(cycle_timers)
|
||||||
|
|
||||||
if self.loop_mode == "normal":
|
if self.loop_mode == "normal":
|
||||||
await self.willing_manager.after_generate_reply_handle(message_data.get("message_id"))
|
await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -425,39 +426,39 @@ class HeartFChatting:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, "", ""
|
return False, "", ""
|
||||||
|
|
||||||
async def shutdown(self):
|
# async def shutdown(self):
|
||||||
"""优雅关闭HeartFChatting实例,取消活动循环任务"""
|
# """优雅关闭HeartFChatting实例,取消活动循环任务"""
|
||||||
logger.info(f"{self.log_prefix} 正在关闭HeartFChatting...")
|
# logger.info(f"{self.log_prefix} 正在关闭HeartFChatting...")
|
||||||
self.running = False # <-- 在开始关闭时设置标志位
|
# self.running = False # <-- 在开始关闭时设置标志位
|
||||||
|
|
||||||
# 记录最终的消息统计
|
# # 记录最终的消息统计
|
||||||
if self._message_count > 0:
|
# if self._message_count > 0:
|
||||||
logger.info(f"{self.log_prefix} 本次focus会话共发送了 {self._message_count} 条消息")
|
# logger.info(f"{self.log_prefix} 本次focus会话共发送了 {self._message_count} 条消息")
|
||||||
if self._fatigue_triggered:
|
# if self._fatigue_triggered:
|
||||||
logger.info(f"{self.log_prefix} 因疲惫而退出focus模式")
|
# logger.info(f"{self.log_prefix} 因疲惫而退出focus模式")
|
||||||
|
|
||||||
# 取消循环任务
|
# # 取消循环任务
|
||||||
if self._loop_task and not self._loop_task.done():
|
# if self._loop_task and not self._loop_task.done():
|
||||||
logger.info(f"{self.log_prefix} 正在取消HeartFChatting循环任务")
|
# logger.info(f"{self.log_prefix} 正在取消HeartFChatting循环任务")
|
||||||
self._loop_task.cancel()
|
# self._loop_task.cancel()
|
||||||
try:
|
# try:
|
||||||
await asyncio.wait_for(self._loop_task, timeout=1.0)
|
# await asyncio.wait_for(self._loop_task, timeout=1.0)
|
||||||
logger.info(f"{self.log_prefix} HeartFChatting循环任务已取消")
|
# logger.info(f"{self.log_prefix} HeartFChatting循环任务已取消")
|
||||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
# except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||||
pass
|
# pass
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 取消循环任务出错: {e}")
|
# logger.error(f"{self.log_prefix} 取消循环任务出错: {e}")
|
||||||
else:
|
# else:
|
||||||
logger.info(f"{self.log_prefix} 没有活动的HeartFChatting循环任务")
|
# logger.info(f"{self.log_prefix} 没有活动的HeartFChatting循环任务")
|
||||||
|
|
||||||
# 清理状态
|
# # 清理状态
|
||||||
self.running = False
|
# self.running = False
|
||||||
self._loop_task = None
|
# self._loop_task = None
|
||||||
|
|
||||||
# 重置消息计数器,为下次启动做准备
|
# # 重置消息计数器,为下次启动做准备
|
||||||
self.reset_message_count()
|
# self.reset_message_count()
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} HeartFChatting关闭完成")
|
# logger.info(f"{self.log_prefix} HeartFChatting关闭完成")
|
||||||
|
|
||||||
def adjust_reply_frequency(self):
|
def adjust_reply_frequency(self):
|
||||||
"""
|
"""
|
||||||
@@ -549,7 +550,7 @@ class HeartFChatting:
|
|||||||
# 仅在未被提及或基础概率不为1时查询意愿概率
|
# 仅在未被提及或基础概率不为1时查询意愿概率
|
||||||
if reply_probability < 1: # 简化逻辑,如果未提及 (reply_probability 为 0),则获取意愿概率
|
if reply_probability < 1: # 简化逻辑,如果未提及 (reply_probability 为 0),则获取意愿概率
|
||||||
# is_willing = True
|
# is_willing = True
|
||||||
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id"))
|
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
|
||||||
|
|
||||||
additional_config = message_data.get("additional_config", {})
|
additional_config = message_data.get("additional_config", {})
|
||||||
if additional_config and "maimcore_reply_probability_gain" in additional_config:
|
if additional_config and "maimcore_reply_probability_gain" in additional_config:
|
||||||
@@ -570,20 +571,18 @@ class HeartFChatting:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if random.random() < reply_probability:
|
if random.random() < reply_probability:
|
||||||
await self.willing_manager.before_generate_reply_handle(message_data.get("message_id"))
|
await self.willing_manager.before_generate_reply_handle(message_data.get("message_id", ""))
|
||||||
await self._observe(message_data=message_data)
|
await self._observe(message_data=message_data)
|
||||||
|
|
||||||
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
||||||
self.willing_manager.delete(message_data.get("message_id"))
|
self.willing_manager.delete(message_data.get("message_id", ""))
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _generate_response(
|
async def _generate_response(
|
||||||
self, message_data: dict, available_actions: Optional[list], reply_to: str
|
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
||||||
) -> Optional[list]:
|
) -> Optional[list]:
|
||||||
"""生成普通回复"""
|
"""生成普通回复"""
|
||||||
try:
|
try:
|
||||||
success, reply_set = await generator_api.generate_reply(
|
success, reply_set, _ = await generator_api.generate_reply(
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream,
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
@@ -622,7 +621,6 @@ class HeartFChatting:
|
|||||||
await send_api.text_to_stream(
|
await send_api.text_to_stream(
|
||||||
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, typing=False
|
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, typing=False
|
||||||
)
|
)
|
||||||
first_replyed = True
|
|
||||||
else:
|
else:
|
||||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=False)
|
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=False)
|
||||||
first_replyed = True
|
first_replyed = True
|
||||||
|
|||||||
@@ -1,14 +1,10 @@
|
|||||||
import time
|
import time
|
||||||
import json
|
|
||||||
|
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.message_repository import count_messages
|
from src.common.message_repository import count_messages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
from src.chat.message_receive.message import UserInfo
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -85,10 +81,10 @@ class CycleDetail:
|
|||||||
self.loop_action_info = loop_info["loop_action_info"]
|
self.loop_action_info = loop_info["loop_action_info"]
|
||||||
|
|
||||||
|
|
||||||
def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
|
def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
minutes (int): 检索的分钟数,默认30分钟
|
minutes (float): 检索的分钟数,默认30分钟
|
||||||
chat_id (str, optional): 指定的chat_id,仅统计该chat下的消息。为None时统计全部。
|
chat_id (str, optional): 指定的chat_id,仅统计该chat下的消息。为None时统计全部。
|
||||||
Returns:
|
Returns:
|
||||||
dict: {"bot_reply_count": int, "total_message_count": int}
|
dict: {"bot_reply_count": int, "total_message_count": int}
|
||||||
@@ -98,7 +94,7 @@ def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
|
|||||||
start_time = now - minutes * 60
|
start_time = now - minutes * 60
|
||||||
bot_id = global_config.bot.qq_account
|
bot_id = global_config.bot.qq_account
|
||||||
|
|
||||||
filter_base = {"time": {"$gte": start_time}}
|
filter_base: Dict[str, Any] = {"time": {"$gte": start_time}}
|
||||||
if chat_id is not None:
|
if chat_id is not None:
|
||||||
filter_base["chat_id"] = chat_id
|
filter_base["chat_id"] = chat_id
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,7 @@ class PrioritizedMessage:
|
|||||||
"""
|
"""
|
||||||
age = time.time() - self.arrival_time
|
age = time.time() - self.arrival_time
|
||||||
decay_factor = math.exp(-decay_rate * age)
|
decay_factor = math.exp(-decay_rate * age)
|
||||||
priority = sum(self.interest_scores) + decay_factor
|
return sum(self.interest_scores) + decay_factor
|
||||||
return priority
|
|
||||||
|
|
||||||
def __lt__(self, other: "PrioritizedMessage") -> bool:
|
def __lt__(self, other: "PrioritizedMessage") -> bool:
|
||||||
"""用于堆排序的比较函数,我们想要一个最大堆,所以用 >"""
|
"""用于堆排序的比较函数,我们想要一个最大堆,所以用 >"""
|
||||||
@@ -43,7 +42,7 @@ class PriorityManager:
|
|||||||
self.normal_queue: List[PrioritizedMessage] = [] # 普通消息队列 (最大堆)
|
self.normal_queue: List[PrioritizedMessage] = [] # 普通消息队列 (最大堆)
|
||||||
self.normal_queue_max_size = normal_queue_max_size
|
self.normal_queue_max_size = normal_queue_max_size
|
||||||
|
|
||||||
def add_message(self, message_data: dict, interest_score: Optional[float] = None):
|
def add_message(self, message_data: dict, interest_score: float = 0):
|
||||||
"""
|
"""
|
||||||
添加新消息到合适的队列中。
|
添加新消息到合适的队列中。
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ class ActionPlanner:
|
|||||||
|
|
||||||
self.last_obs_time_mark = 0.0
|
self.last_obs_time_mark = 0.0
|
||||||
|
|
||||||
async def plan(self, mode: str = "focus") -> Dict[str, Any]: # sourcery skip: dict-comprehension
|
async def plan(self, mode: str = "focus") -> Dict[str, Dict[str, Any]]: # sourcery skip: dict-comprehension
|
||||||
"""
|
"""
|
||||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user