心流模式完成缓冲功能
This commit is contained in:
@@ -1,11 +1,13 @@
|
|||||||
from ..person_info import person_info
|
from ..person_info.person_info import person_info_manager
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from .message import MessageRecv
|
from .message import MessageRecv
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import List, Dict
|
from typing import Dict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from collections import OrderedDict
|
||||||
|
import random
|
||||||
|
|
||||||
logger = get_module_logger("message_buffer")
|
logger = get_module_logger("message_buffer")
|
||||||
|
|
||||||
@@ -18,7 +20,7 @@ class CacheMessages:
|
|||||||
|
|
||||||
class MassageBuffer:
|
class MassageBuffer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.buffer_pool: Dict[str, List[CacheMessages]] = {}
|
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
def get_person_id_(self, platform:str, user_id:str, group_id:str):
|
def get_person_id_(self, platform:str, user_id:str, group_id:str):
|
||||||
@@ -28,62 +30,121 @@ class MassageBuffer:
|
|||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
async def start_caching_messages(self, message:MessageRecv):
|
async def start_caching_messages(self, message:MessageRecv):
|
||||||
"""添加消息并重置缓冲计时器"""
|
"""添加消息,启动缓冲"""
|
||||||
person_id_ = self.get_person_id_(message.chat_info.platform,
|
person_id_ = self.get_person_id_(message.message_info.platform,
|
||||||
message.chat_info.user_info.user_id,
|
message.message_info.user_info.user_id,
|
||||||
message.chat_info.group_info.group_id)
|
message.message_info.group_info.group_id)
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
# 清空该用户之前的未处理消息
|
if person_id_ not in self.buffer_pool:
|
||||||
if person_id_ in self.buffer_pool:
|
self.buffer_pool[person_id_] = OrderedDict()
|
||||||
for old_msg in self.buffer_pool[person_id_]:
|
|
||||||
if old_msg.result == "U":
|
# 查找最近的处理成功消息(T)
|
||||||
old_msg.cache_determination.set()
|
last_T_msg = None
|
||||||
old_msg.result = "F" # 标记旧消息为失败
|
recent_F_count = 0
|
||||||
logger.debug(f"被新消息覆盖信息id: {message.message_id}")
|
for msg_id in reversed(self.buffer_pool[person_id_]):
|
||||||
|
msg = self.buffer_pool[person_id_][msg_id]
|
||||||
|
if msg.result == "T":
|
||||||
|
last_T_msg = msg
|
||||||
|
break
|
||||||
|
elif msg.result == "F":
|
||||||
|
recent_F_count += 1
|
||||||
|
|
||||||
|
# 判断条件:最近T之后有超过3条F
|
||||||
|
if (recent_F_count >= random.randint(3, 5)):
|
||||||
|
new_msg = CacheMessages(message=message, result="T")
|
||||||
|
new_msg.cache_determination.set()
|
||||||
|
self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
|
||||||
|
logger.debug(f"快速处理消息(已堆积{recent_F_count}条F): {message.message_info.message_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 标记该用户之前的未处理消息
|
||||||
|
for msg_id, cache_msg in self.buffer_pool[person_id_].items():
|
||||||
|
if cache_msg.result == "U":
|
||||||
|
cache_msg.result = "F"
|
||||||
|
cache_msg.cache_determination.set()
|
||||||
|
logger.debug(f"被新消息覆盖信息id: {message.message_info.message_id}")
|
||||||
|
|
||||||
# 添加新消息
|
# 添加新消息
|
||||||
cache_msg = CacheMessages(message=message, result="U")
|
self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message)
|
||||||
self.buffer_pool[person_id_] = [cache_msg] # 只保留最新消息
|
|
||||||
|
|
||||||
# 启动3秒缓冲计时器
|
# 启动3秒缓冲计时器
|
||||||
asyncio.create_task(self._debounce_processor(person_id_, cache_msg))
|
person_id = person_info_manager.get_person_id(message.message_info.user_info.platform,
|
||||||
|
message.message_info.user_info.user_id)
|
||||||
|
asyncio.create_task(self._debounce_processor(person_id_,
|
||||||
|
message.message_info.message_id,
|
||||||
|
person_id))
|
||||||
|
|
||||||
async def _debounce_processor(self, person_id_:str, cache_msg:CacheMessages):
|
async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str):
|
||||||
"""等待3秒无新消息"""
|
"""等待3秒无新消息"""
|
||||||
await asyncio.sleep(3)
|
interval_time = await person_info_manager.get_value(person_id, "msg_interval")
|
||||||
|
if not isinstance(interval_time, (int, str)) or not str(interval_time).isdigit():
|
||||||
|
logger.debug("debounce_processor无效的时间")
|
||||||
|
return
|
||||||
|
interval_time = max(0.5, int(interval_time) / 1000)
|
||||||
|
await asyncio.sleep(interval_time)
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
# 检查消息是否仍未被覆盖
|
if (person_id_ not in self.buffer_pool or
|
||||||
if (person_id_ in self.buffer_pool and
|
message_id not in self.buffer_pool[person_id_]):
|
||||||
cache_msg in self.buffer_pool[person_id_] and
|
logger.debug(f"消息异常被清理,msgid: {message_id}")
|
||||||
cache_msg.result == "U"):
|
return
|
||||||
|
|
||||||
cache_msg.result = "T" # 标记为成功处理
|
cache_msg = self.buffer_pool[person_id_][message_id]
|
||||||
|
if cache_msg.result == "U":
|
||||||
|
cache_msg.result = "T"
|
||||||
cache_msg.cache_determination.set()
|
cache_msg.cache_determination.set()
|
||||||
|
|
||||||
|
|
||||||
async def query_buffer_result(self, message:MessageRecv) -> bool:
|
async def query_buffer_result(self, message:MessageRecv) -> bool:
|
||||||
"""查询缓冲结果"""
|
"""查询缓冲结果,并清理"""
|
||||||
person_id_ = self.get_person_id_(message.chat_info.platform,
|
person_id_ = self.get_person_id_(message.message_info.platform,
|
||||||
message.chat_info.user_info.user_id,
|
message.message_info.user_info.user_id,
|
||||||
message.chat_info.group_info.group_id)
|
message.message_info.group_info.group_id)
|
||||||
|
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
if person_id_ not in self.buffer_pool or not self.buffer_pool[person_id_]:
|
user_msgs = self.buffer_pool.get(person_id_, {})
|
||||||
return False
|
cache_msg = user_msgs.get(message.message_info.message_id)
|
||||||
|
|
||||||
cache_msg = self.buffer_pool[person_id_][-1] # 获取最新消息
|
if not cache_msg:
|
||||||
if cache_msg.message.message_id != message.message_id:
|
logger.debug(f"查询异常,消息不存在,msgid: {message.message_info.message_id}")
|
||||||
return False
|
return False # 消息不存在或已清理
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(cache_msg.cache_determination.wait(), timeout=10)
|
await asyncio.wait_for(cache_msg.cache_determination.wait(), timeout=10)
|
||||||
return cache_msg.result == "T"
|
result = cache_msg.result == "T"
|
||||||
|
|
||||||
|
if result:
|
||||||
|
async with self.lock: # 再次加锁
|
||||||
|
# 清理所有早于当前消息的已处理消息, 收集所有早于当前消息的F消息的processed_plain_text
|
||||||
|
keep_msgs = OrderedDict()
|
||||||
|
combined_text = []
|
||||||
|
found = False
|
||||||
|
for msg_id, msg in self.buffer_pool[person_id_].items():
|
||||||
|
if msg_id == message.message_info.message_id:
|
||||||
|
found = True
|
||||||
|
combined_text.append(msg.message.processed_plain_text)
|
||||||
|
continue
|
||||||
|
if found:
|
||||||
|
keep_msgs[msg_id] = msg
|
||||||
|
elif msg.result == "F":
|
||||||
|
# 收集F消息的文本内容
|
||||||
|
if hasattr(msg.message, 'processed_plain_text') and msg.message.processed_plain_text:
|
||||||
|
combined_text.append(msg.message.processed_plain_text)
|
||||||
|
elif msg.result == "U":
|
||||||
|
logger.debug(f"异常未处理信息id: {msg.message.message_info.message_id}")
|
||||||
|
|
||||||
|
# 更新当前消息的processed_plain_text
|
||||||
|
if combined_text and combined_text[0] != message.processed_plain_text:
|
||||||
|
message.processed_plain_text = "".join(combined_text)
|
||||||
|
logger.debug(f"整合了{len(combined_text)-1}条F消息的内容到当前消息")
|
||||||
|
|
||||||
|
self.buffer_pool[person_id_] = keep_msgs
|
||||||
|
return result
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.debug(f"查询超时消息id: {message.message_id}")
|
logger.debug(f"查询超时消息id: {message.message_info.message_id}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
message_buffer = MassageBuffer()
|
message_buffer = MassageBuffer()
|
||||||
@@ -18,6 +18,7 @@ from src.heart_flow.heartflow import heartflow
|
|||||||
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
||||||
from ...chat.chat_stream import chat_manager
|
from ...chat.chat_stream import chat_manager
|
||||||
from ...person_info.relationship_manager import relationship_manager
|
from ...person_info.relationship_manager import relationship_manager
|
||||||
|
from ...chat.message_buffer import message_buffer
|
||||||
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
chat_config = LogConfig(
|
chat_config = LogConfig(
|
||||||
@@ -161,6 +162,8 @@ class ThinkFlowChat:
|
|||||||
userinfo = message.message_info.user_info
|
userinfo = message.message_info.user_info
|
||||||
messageinfo = message.message_info
|
messageinfo = message.message_info
|
||||||
|
|
||||||
|
# 消息加入缓冲池
|
||||||
|
await message_buffer.start_caching_messages(message)
|
||||||
|
|
||||||
# 创建聊天流
|
# 创建聊天流
|
||||||
chat = await chat_manager.get_or_create_stream(
|
chat = await chat_manager.get_or_create_stream(
|
||||||
@@ -192,8 +195,15 @@ class ThinkFlowChat:
|
|||||||
timing_results["记忆激活"] = timer2 - timer1
|
timing_results["记忆激活"] = timer2 - timer1
|
||||||
logger.debug(f"记忆激活: {interested_rate}")
|
logger.debug(f"记忆激活: {interested_rate}")
|
||||||
|
|
||||||
|
# 查询缓冲器结果,会整合前面跳过的消息,改变processed_plain_text
|
||||||
|
buffer_result = await message_buffer.query_buffer_result(message)
|
||||||
|
if not buffer_result:
|
||||||
|
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
|
||||||
|
return
|
||||||
|
|
||||||
is_mentioned = is_mentioned_bot_in_message(message)
|
is_mentioned = is_mentioned_bot_in_message(message)
|
||||||
|
|
||||||
|
|
||||||
# 计算回复意愿
|
# 计算回复意愿
|
||||||
current_willing_old = willing_manager.get_willing(chat_stream=chat)
|
current_willing_old = willing_manager.get_willing(chat_stream=chat)
|
||||||
# current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4
|
# current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ person_info_default = {
|
|||||||
# "impression" : None,
|
# "impression" : None,
|
||||||
# "gender" : Unkown,
|
# "gender" : Unkown,
|
||||||
"konw_time" : 0,
|
"konw_time" : 0,
|
||||||
|
"msg_interval": 3000
|
||||||
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
|
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
|
||||||
|
|
||||||
class PersonInfoManager:
|
class PersonInfoManager:
|
||||||
|
|||||||
Reference in New Issue
Block a user