chore: format code and remove redundant blank lines
This commit applies automated code formatting across the project. The changes primarily involve removing unnecessary blank lines and ensuring consistent code style, improving readability and maintainability without altering functionality.
This commit is contained in:
@@ -37,6 +37,7 @@ def get_classes_in_module(module):
|
|||||||
classes.append(member)
|
classes.append(member)
|
||||||
return classes
|
return classes
|
||||||
|
|
||||||
|
|
||||||
async def message_recv(server_connection: Server.ServerConnection):
|
async def message_recv(server_connection: Server.ServerConnection):
|
||||||
await message_handler.set_server_connection(server_connection)
|
await message_handler.set_server_connection(server_connection)
|
||||||
asyncio.create_task(notice_handler.set_server_connection(server_connection))
|
asyncio.create_task(notice_handler.set_server_connection(server_connection))
|
||||||
@@ -47,7 +48,7 @@ async def message_recv(server_connection: Server.ServerConnection):
|
|||||||
try:
|
try:
|
||||||
# 首先尝试解析原始消息
|
# 首先尝试解析原始消息
|
||||||
decoded_raw_message: dict = json.loads(raw_message)
|
decoded_raw_message: dict = json.loads(raw_message)
|
||||||
|
|
||||||
# 检查是否是切片消息 (来自 MMC)
|
# 检查是否是切片消息 (来自 MMC)
|
||||||
if chunker.is_chunk_message(decoded_raw_message):
|
if chunker.is_chunk_message(decoded_raw_message):
|
||||||
logger.debug("接收到切片消息,尝试重组")
|
logger.debug("接收到切片消息,尝试重组")
|
||||||
@@ -61,14 +62,14 @@ async def message_recv(server_connection: Server.ServerConnection):
|
|||||||
# 切片尚未完整,继续等待更多切片
|
# 切片尚未完整,继续等待更多切片
|
||||||
logger.debug("等待更多切片...")
|
logger.debug("等待更多切片...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
|
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
|
||||||
post_type = decoded_raw_message.get("post_type")
|
post_type = decoded_raw_message.get("post_type")
|
||||||
if post_type in ["meta_event", "message", "notice"]:
|
if post_type in ["meta_event", "message", "notice"]:
|
||||||
await message_queue.put(decoded_raw_message)
|
await message_queue.put(decoded_raw_message)
|
||||||
elif post_type is None:
|
elif post_type is None:
|
||||||
await put_response(decoded_raw_message)
|
await put_response(decoded_raw_message)
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"消息解析失败: {e}")
|
logger.error(f"消息解析失败: {e}")
|
||||||
logger.debug(f"原始消息: {raw_message[:500]}...")
|
logger.debug(f"原始消息: {raw_message[:500]}...")
|
||||||
@@ -76,6 +77,7 @@ async def message_recv(server_connection: Server.ServerConnection):
|
|||||||
logger.error(f"处理消息时出错: {e}")
|
logger.error(f"处理消息时出错: {e}")
|
||||||
logger.debug(f"原始消息: {raw_message[:500]}...")
|
logger.debug(f"原始消息: {raw_message[:500]}...")
|
||||||
|
|
||||||
|
|
||||||
async def message_process():
|
async def message_process():
|
||||||
"""消息处理主循环"""
|
"""消息处理主循环"""
|
||||||
logger.info("消息处理器已启动")
|
logger.info("消息处理器已启动")
|
||||||
@@ -84,7 +86,7 @@ async def message_process():
|
|||||||
try:
|
try:
|
||||||
# 使用超时等待,以便能够响应取消请求
|
# 使用超时等待,以便能够响应取消请求
|
||||||
message = await asyncio.wait_for(message_queue.get(), timeout=1.0)
|
message = await asyncio.wait_for(message_queue.get(), timeout=1.0)
|
||||||
|
|
||||||
post_type = message.get("post_type")
|
post_type = message.get("post_type")
|
||||||
if post_type == "message":
|
if post_type == "message":
|
||||||
await message_handler.handle_raw_message(message)
|
await message_handler.handle_raw_message(message)
|
||||||
@@ -94,10 +96,10 @@ async def message_process():
|
|||||||
await notice_handler.handle_notice(message)
|
await notice_handler.handle_notice(message)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的post_type: {post_type}")
|
logger.warning(f"未知的post_type: {post_type}")
|
||||||
|
|
||||||
message_queue.task_done()
|
message_queue.task_done()
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# 超时是正常的,继续循环
|
# 超时是正常的,继续循环
|
||||||
continue
|
continue
|
||||||
@@ -112,7 +114,7 @@ async def message_process():
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("消息处理器已停止")
|
logger.info("消息处理器已停止")
|
||||||
raise
|
raise
|
||||||
@@ -132,6 +134,7 @@ async def message_process():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"清理消息队列时出错: {e}")
|
logger.debug(f"清理消息队列时出错: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def napcat_server():
|
async def napcat_server():
|
||||||
"""启动 Napcat WebSocket 连接(支持正向和反向连接)"""
|
"""启动 Napcat WebSocket 连接(支持正向和反向连接)"""
|
||||||
mode = global_config.napcat_server.mode
|
mode = global_config.napcat_server.mode
|
||||||
@@ -143,63 +146,61 @@ async def napcat_server():
|
|||||||
logger.error(f"启动 WebSocket 连接失败: {e}")
|
logger.error(f"启动 WebSocket 连接失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def graceful_shutdown():
|
async def graceful_shutdown():
|
||||||
"""优雅关闭所有组件"""
|
"""优雅关闭所有组件"""
|
||||||
try:
|
try:
|
||||||
logger.info("正在关闭adapter...")
|
logger.info("正在关闭adapter...")
|
||||||
|
|
||||||
# 停止消息重组器的清理任务
|
# 停止消息重组器的清理任务
|
||||||
try:
|
try:
|
||||||
await reassembler.stop_cleanup_task()
|
await reassembler.stop_cleanup_task()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"停止消息重组器清理任务时出错: {e}")
|
logger.warning(f"停止消息重组器清理任务时出错: {e}")
|
||||||
|
|
||||||
# 停止功能管理器文件监控
|
# 停止功能管理器文件监控
|
||||||
try:
|
try:
|
||||||
await features_manager.stop_file_watcher()
|
await features_manager.stop_file_watcher()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"停止功能管理器文件监控时出错: {e}")
|
logger.warning(f"停止功能管理器文件监控时出错: {e}")
|
||||||
|
|
||||||
# 关闭消息处理器(包括消息缓冲器)
|
# 关闭消息处理器(包括消息缓冲器)
|
||||||
try:
|
try:
|
||||||
await message_handler.shutdown()
|
await message_handler.shutdown()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"关闭消息处理器时出错: {e}")
|
logger.warning(f"关闭消息处理器时出错: {e}")
|
||||||
|
|
||||||
# 关闭 WebSocket 连接
|
# 关闭 WebSocket 连接
|
||||||
try:
|
try:
|
||||||
await websocket_manager.stop_connection()
|
await websocket_manager.stop_connection()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"关闭WebSocket连接时出错: {e}")
|
logger.warning(f"关闭WebSocket连接时出错: {e}")
|
||||||
|
|
||||||
# 关闭 MaiBot 连接
|
# 关闭 MaiBot 连接
|
||||||
try:
|
try:
|
||||||
await mmc_stop_com()
|
await mmc_stop_com()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"关闭MaiBot连接时出错: {e}")
|
logger.warning(f"关闭MaiBot连接时出错: {e}")
|
||||||
|
|
||||||
# 取消所有剩余任务
|
# 取消所有剩余任务
|
||||||
current_task = asyncio.current_task()
|
current_task = asyncio.current_task()
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not current_task and not t.done()]
|
tasks = [t for t in asyncio.all_tasks() if t is not current_task and not t.done()]
|
||||||
|
|
||||||
if tasks:
|
if tasks:
|
||||||
logger.info(f"正在取消 {len(tasks)} 个剩余任务...")
|
logger.info(f"正在取消 {len(tasks)} 个剩余任务...")
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
# 等待任务取消完成,忽略 CancelledError
|
# 等待任务取消完成,忽略 CancelledError
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||||
asyncio.gather(*tasks, return_exceptions=True),
|
|
||||||
timeout=10
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning("部分任务取消超时")
|
logger.warning("部分任务取消超时")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"任务取消过程中的异常(可忽略): {e}")
|
logger.debug(f"任务取消过程中的异常(可忽略): {e}")
|
||||||
|
|
||||||
logger.info("Adapter已成功关闭")
|
logger.info("Adapter已成功关闭")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Adapter关闭中出现错误: {e}")
|
logger.error(f"Adapter关闭中出现错误: {e}")
|
||||||
finally:
|
finally:
|
||||||
@@ -214,6 +215,7 @@ async def graceful_shutdown():
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LauchNapcatAdapterHandler(BaseEventHandler):
|
class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||||
"""自动启动Adapter"""
|
"""自动启动Adapter"""
|
||||||
|
|
||||||
@@ -245,6 +247,7 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
|||||||
asyncio.create_task(message_process())
|
asyncio.create_task(message_process())
|
||||||
asyncio.create_task(check_timeout_response())
|
asyncio.create_task(check_timeout_response())
|
||||||
|
|
||||||
|
|
||||||
class StopNapcatAdapterHandler(BaseEventHandler):
|
class StopNapcatAdapterHandler(BaseEventHandler):
|
||||||
"""关闭Adapter"""
|
"""关闭Adapter"""
|
||||||
|
|
||||||
@@ -257,7 +260,7 @@ class StopNapcatAdapterHandler(BaseEventHandler):
|
|||||||
async def execute(self, kwargs):
|
async def execute(self, kwargs):
|
||||||
await graceful_shutdown()
|
await graceful_shutdown()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
@register_plugin
|
||||||
class NapcatAdapterPlugin(BasePlugin):
|
class NapcatAdapterPlugin(BasePlugin):
|
||||||
@@ -295,7 +298,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
|||||||
|
|
||||||
def get_plugin_components(self):
|
def get_plugin_components(self):
|
||||||
self.register_events()
|
self.register_events()
|
||||||
|
|
||||||
components = []
|
components = []
|
||||||
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
|
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
|
||||||
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
|
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
|
||||||
|
|||||||
@@ -58,14 +58,16 @@ class VoiceConfig(ConfigBase):
|
|||||||
use_tts: bool = False
|
use_tts: bool = False
|
||||||
"""是否启用TTS功能"""
|
"""是否启用TTS功能"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SlicingConfig(ConfigBase):
|
class SlicingConfig(ConfigBase):
|
||||||
max_frame_size: int = 64
|
max_frame_size: int = 64
|
||||||
"""WebSocket帧的最大大小,单位为字节,默认64KB"""
|
"""WebSocket帧的最大大小,单位为字节,默认64KB"""
|
||||||
|
|
||||||
delay_ms: int = 10
|
delay_ms: int = 10
|
||||||
"""切片发送间隔时间,单位为毫秒"""
|
"""切片发送间隔时间,单位为毫秒"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DebugConfig(ConfigBase):
|
class DebugConfig(ConfigBase):
|
||||||
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
用于在 Ada 发送给 MMC 时进行消息切片,利用 WebSocket 协议的自动重组特性
|
用于在 Ada 发送给 MMC 时进行消息切片,利用 WebSocket 协议的自动重组特性
|
||||||
仅在 Ada -> MMC 方向进行切片,其他方向(MMC -> Ada,Ada <-> Napcat)不切片
|
仅在 Ada -> MMC 方向进行切片,其他方向(MMC -> Ada,Ada <-> Napcat)不切片
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -15,10 +16,9 @@ from src.common.logger import get_logger
|
|||||||
logger = get_logger("napcat_adapter")
|
logger = get_logger("napcat_adapter")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MessageChunker:
|
class MessageChunker:
|
||||||
"""消息切片器,用于处理大消息的分片发送"""
|
"""消息切片器,用于处理大消息的分片发送"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.max_chunk_size = global_config.slicing.max_frame_size * 1024
|
self.max_chunk_size = global_config.slicing.max_frame_size * 1024
|
||||||
|
|
||||||
@@ -29,19 +29,21 @@ class MessageChunker:
|
|||||||
message_str = json.dumps(message, ensure_ascii=False)
|
message_str = json.dumps(message, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
message_str = message
|
message_str = message
|
||||||
return len(message_str.encode('utf-8')) > self.max_chunk_size
|
return len(message_str.encode("utf-8")) > self.max_chunk_size
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"检查消息大小时出错: {e}")
|
logger.error(f"检查消息大小时出错: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def chunk_message(self, message: Union[str, Dict[str, Any]], chunk_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
def chunk_message(
|
||||||
|
self, message: Union[str, Dict[str, Any]], chunk_id: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
将消息切片
|
将消息切片
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 要切片的消息(字符串或字典)
|
message: 要切片的消息(字符串或字典)
|
||||||
chunk_id: 切片组ID,如果不提供则自动生成
|
chunk_id: 切片组ID,如果不提供则自动生成
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
切片后的消息字典列表
|
切片后的消息字典列表
|
||||||
"""
|
"""
|
||||||
@@ -51,30 +53,30 @@ class MessageChunker:
|
|||||||
message_str = json.dumps(message, ensure_ascii=False)
|
message_str = json.dumps(message, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
message_str = message
|
message_str = message
|
||||||
|
|
||||||
if not self.should_chunk_message(message_str):
|
if not self.should_chunk_message(message_str):
|
||||||
# 不需要切片的情况,如果输入是字典则返回字典,如果是字符串则包装成非切片标记的字典
|
# 不需要切片的情况,如果输入是字典则返回字典,如果是字符串则包装成非切片标记的字典
|
||||||
if isinstance(message, dict):
|
if isinstance(message, dict):
|
||||||
return [message]
|
return [message]
|
||||||
else:
|
else:
|
||||||
return [{"_original_message": message_str}]
|
return [{"_original_message": message_str}]
|
||||||
|
|
||||||
if chunk_id is None:
|
if chunk_id is None:
|
||||||
chunk_id = str(uuid.uuid4())
|
chunk_id = str(uuid.uuid4())
|
||||||
|
|
||||||
message_bytes = message_str.encode('utf-8')
|
message_bytes = message_str.encode("utf-8")
|
||||||
total_size = len(message_bytes)
|
total_size = len(message_bytes)
|
||||||
|
|
||||||
# 计算需要多少个切片
|
# 计算需要多少个切片
|
||||||
num_chunks = (total_size + self.max_chunk_size - 1) // self.max_chunk_size
|
num_chunks = (total_size + self.max_chunk_size - 1) // self.max_chunk_size
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
for i in range(num_chunks):
|
for i in range(num_chunks):
|
||||||
start_pos = i * self.max_chunk_size
|
start_pos = i * self.max_chunk_size
|
||||||
end_pos = min(start_pos + self.max_chunk_size, total_size)
|
end_pos = min(start_pos + self.max_chunk_size, total_size)
|
||||||
|
|
||||||
chunk_data = message_bytes[start_pos:end_pos]
|
chunk_data = message_bytes[start_pos:end_pos]
|
||||||
|
|
||||||
# 构建切片消息
|
# 构建切片消息
|
||||||
chunk_message = {
|
chunk_message = {
|
||||||
"__mmc_chunk_info__": {
|
"__mmc_chunk_info__": {
|
||||||
@@ -83,17 +85,17 @@ class MessageChunker:
|
|||||||
"total_chunks": num_chunks,
|
"total_chunks": num_chunks,
|
||||||
"chunk_size": len(chunk_data),
|
"chunk_size": len(chunk_data),
|
||||||
"total_size": total_size,
|
"total_size": total_size,
|
||||||
"timestamp": time.time()
|
"timestamp": time.time(),
|
||||||
},
|
},
|
||||||
"__mmc_chunk_data__": chunk_data.decode('utf-8', errors='ignore'),
|
"__mmc_chunk_data__": chunk_data.decode("utf-8", errors="ignore"),
|
||||||
"__mmc_is_chunked__": True
|
"__mmc_is_chunked__": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks.append(chunk_message)
|
chunks.append(chunk_message)
|
||||||
|
|
||||||
logger.debug(f"消息切片完成: {total_size} bytes -> {num_chunks} chunks (ID: {chunk_id})")
|
logger.debug(f"消息切片完成: {total_size} bytes -> {num_chunks} chunks (ID: {chunk_id})")
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"消息切片时出错: {e}")
|
logger.error(f"消息切片时出错: {e}")
|
||||||
# 出错时返回原消息
|
# 出错时返回原消息
|
||||||
@@ -101,7 +103,7 @@ class MessageChunker:
|
|||||||
return [message]
|
return [message]
|
||||||
else:
|
else:
|
||||||
return [{"_original_message": message}]
|
return [{"_original_message": message}]
|
||||||
|
|
||||||
def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
|
def is_chunk_message(self, message: Union[str, Dict[str, Any]]) -> bool:
|
||||||
"""判断是否是切片消息"""
|
"""判断是否是切片消息"""
|
||||||
try:
|
try:
|
||||||
@@ -109,12 +111,12 @@ class MessageChunker:
|
|||||||
data = json.loads(message)
|
data = json.loads(message)
|
||||||
else:
|
else:
|
||||||
data = message
|
data = message
|
||||||
|
|
||||||
return (
|
return (
|
||||||
isinstance(data, dict) and
|
isinstance(data, dict)
|
||||||
"__mmc_chunk_info__" in data and
|
and "__mmc_chunk_info__" in data
|
||||||
"__mmc_chunk_data__" in data and
|
and "__mmc_chunk_data__" in data
|
||||||
"__mmc_is_chunked__" in data
|
and "__mmc_is_chunked__" in data
|
||||||
)
|
)
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
return False
|
return False
|
||||||
@@ -122,17 +124,17 @@ class MessageChunker:
|
|||||||
|
|
||||||
class MessageReassembler:
|
class MessageReassembler:
|
||||||
"""消息重组器,用于重组接收到的切片消息"""
|
"""消息重组器,用于重组接收到的切片消息"""
|
||||||
|
|
||||||
def __init__(self, timeout: int = 30):
|
def __init__(self, timeout: int = 30):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.chunk_buffers: Dict[str, Dict[str, Any]] = {}
|
self.chunk_buffers: Dict[str, Dict[str, Any]] = {}
|
||||||
self._cleanup_task = None
|
self._cleanup_task = None
|
||||||
|
|
||||||
async def start_cleanup_task(self):
|
async def start_cleanup_task(self):
|
||||||
"""启动清理任务"""
|
"""启动清理任务"""
|
||||||
if self._cleanup_task is None:
|
if self._cleanup_task is None:
|
||||||
self._cleanup_task = asyncio.create_task(self._cleanup_expired_chunks())
|
self._cleanup_task = asyncio.create_task(self._cleanup_expired_chunks())
|
||||||
|
|
||||||
async def stop_cleanup_task(self):
|
async def stop_cleanup_task(self):
|
||||||
"""停止清理任务"""
|
"""停止清理任务"""
|
||||||
if self._cleanup_task:
|
if self._cleanup_task:
|
||||||
@@ -142,35 +144,35 @@ class MessageReassembler:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
self._cleanup_task = None
|
self._cleanup_task = None
|
||||||
|
|
||||||
async def _cleanup_expired_chunks(self):
|
async def _cleanup_expired_chunks(self):
|
||||||
"""清理过期的切片缓冲区"""
|
"""清理过期的切片缓冲区"""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(10) # 每10秒检查一次
|
await asyncio.sleep(10) # 每10秒检查一次
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
expired_chunks = []
|
expired_chunks = []
|
||||||
for chunk_id, buffer_info in self.chunk_buffers.items():
|
for chunk_id, buffer_info in self.chunk_buffers.items():
|
||||||
if current_time - buffer_info['timestamp'] > self.timeout:
|
if current_time - buffer_info["timestamp"] > self.timeout:
|
||||||
expired_chunks.append(chunk_id)
|
expired_chunks.append(chunk_id)
|
||||||
|
|
||||||
for chunk_id in expired_chunks:
|
for chunk_id in expired_chunks:
|
||||||
logger.warning(f"清理过期的切片缓冲区: {chunk_id}")
|
logger.warning(f"清理过期的切片缓冲区: {chunk_id}")
|
||||||
del self.chunk_buffers[chunk_id]
|
del self.chunk_buffers[chunk_id]
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"清理过期切片时出错: {e}")
|
logger.error(f"清理过期切片时出错: {e}")
|
||||||
|
|
||||||
async def add_chunk(self, message: Union[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
async def add_chunk(self, message: Union[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
添加切片,如果切片完整则返回重组后的消息
|
添加切片,如果切片完整则返回重组后的消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 切片消息(字符串或字典)
|
message: 切片消息(字符串或字典)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
如果切片完整则返回重组后的原始消息字典,否则返回None
|
如果切片完整则返回重组后的原始消息字典,否则返回None
|
||||||
"""
|
"""
|
||||||
@@ -180,7 +182,7 @@ class MessageReassembler:
|
|||||||
chunk_data = json.loads(message)
|
chunk_data = json.loads(message)
|
||||||
else:
|
else:
|
||||||
chunk_data = message
|
chunk_data = message
|
||||||
|
|
||||||
# 检查是否是切片消息
|
# 检查是否是切片消息
|
||||||
if not chunker.is_chunk_message(chunk_data):
|
if not chunker.is_chunk_message(chunk_data):
|
||||||
# 不是切片消息,直接返回
|
# 不是切片消息,直接返回
|
||||||
@@ -192,38 +194,38 @@ class MessageReassembler:
|
|||||||
return {"text_message": chunk_data["_original_message"]}
|
return {"text_message": chunk_data["_original_message"]}
|
||||||
else:
|
else:
|
||||||
return chunk_data
|
return chunk_data
|
||||||
|
|
||||||
chunk_info = chunk_data["__mmc_chunk_info__"]
|
chunk_info = chunk_data["__mmc_chunk_info__"]
|
||||||
chunk_content = chunk_data["__mmc_chunk_data__"]
|
chunk_content = chunk_data["__mmc_chunk_data__"]
|
||||||
|
|
||||||
chunk_id = chunk_info["chunk_id"]
|
chunk_id = chunk_info["chunk_id"]
|
||||||
chunk_index = chunk_info["chunk_index"]
|
chunk_index = chunk_info["chunk_index"]
|
||||||
total_chunks = chunk_info["total_chunks"]
|
total_chunks = chunk_info["total_chunks"]
|
||||||
chunk_timestamp = chunk_info.get("timestamp", time.time())
|
chunk_timestamp = chunk_info.get("timestamp", time.time())
|
||||||
|
|
||||||
# 初始化缓冲区
|
# 初始化缓冲区
|
||||||
if chunk_id not in self.chunk_buffers:
|
if chunk_id not in self.chunk_buffers:
|
||||||
self.chunk_buffers[chunk_id] = {
|
self.chunk_buffers[chunk_id] = {
|
||||||
"chunks": {},
|
"chunks": {},
|
||||||
"total_chunks": total_chunks,
|
"total_chunks": total_chunks,
|
||||||
"received_chunks": 0,
|
"received_chunks": 0,
|
||||||
"timestamp": chunk_timestamp
|
"timestamp": chunk_timestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer = self.chunk_buffers[chunk_id]
|
buffer = self.chunk_buffers[chunk_id]
|
||||||
|
|
||||||
# 检查切片是否已经接收过
|
# 检查切片是否已经接收过
|
||||||
if chunk_index in buffer["chunks"]:
|
if chunk_index in buffer["chunks"]:
|
||||||
logger.warning(f"重复接收切片: {chunk_id}#{chunk_index}")
|
logger.warning(f"重复接收切片: {chunk_id}#{chunk_index}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 添加切片
|
# 添加切片
|
||||||
buffer["chunks"][chunk_index] = chunk_content
|
buffer["chunks"][chunk_index] = chunk_content
|
||||||
buffer["received_chunks"] += 1
|
buffer["received_chunks"] += 1
|
||||||
buffer["timestamp"] = time.time() # 更新时间戳
|
buffer["timestamp"] = time.time() # 更新时间戳
|
||||||
|
|
||||||
logger.debug(f"接收切片: {chunk_id}#{chunk_index} ({buffer['received_chunks']}/{total_chunks})")
|
logger.debug(f"接收切片: {chunk_id}#{chunk_index} ({buffer['received_chunks']}/{total_chunks})")
|
||||||
|
|
||||||
# 检查是否接收完整
|
# 检查是否接收完整
|
||||||
if buffer["received_chunks"] == total_chunks:
|
if buffer["received_chunks"] == total_chunks:
|
||||||
# 重组消息
|
# 重组消息
|
||||||
@@ -233,25 +235,25 @@ class MessageReassembler:
|
|||||||
logger.error(f"切片 {chunk_id}#{i} 缺失,无法重组")
|
logger.error(f"切片 {chunk_id}#{i} 缺失,无法重组")
|
||||||
return None
|
return None
|
||||||
reassembled_message += buffer["chunks"][i]
|
reassembled_message += buffer["chunks"][i]
|
||||||
|
|
||||||
# 清理缓冲区
|
# 清理缓冲区
|
||||||
del self.chunk_buffers[chunk_id]
|
del self.chunk_buffers[chunk_id]
|
||||||
|
|
||||||
logger.debug(f"消息重组完成: {chunk_id} ({len(reassembled_message)} chars)")
|
logger.debug(f"消息重组完成: {chunk_id} ({len(reassembled_message)} chars)")
|
||||||
|
|
||||||
# 尝试反序列化重组后的消息
|
# 尝试反序列化重组后的消息
|
||||||
try:
|
try:
|
||||||
return json.loads(reassembled_message)
|
return json.loads(reassembled_message)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# 如果不能反序列化为JSON,则作为文本消息返回
|
# 如果不能反序列化为JSON,则作为文本消息返回
|
||||||
return {"text_message": reassembled_message}
|
return {"text_message": reassembled_message}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||||
logger.error(f"处理切片消息时出错: {e}")
|
logger.error(f"处理切片消息时出错: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_pending_chunks_info(self) -> Dict[str, Any]:
|
def get_pending_chunks_info(self) -> Dict[str, Any]:
|
||||||
"""获取待处理切片信息"""
|
"""获取待处理切片信息"""
|
||||||
info = {}
|
info = {}
|
||||||
@@ -260,11 +262,11 @@ class MessageReassembler:
|
|||||||
"received": buffer["received_chunks"],
|
"received": buffer["received_chunks"],
|
||||||
"total": buffer["total_chunks"],
|
"total": buffer["total_chunks"],
|
||||||
"progress": f"{buffer['received_chunks']}/{buffer['total_chunks']}",
|
"progress": f"{buffer['received_chunks']}/{buffer['total_chunks']}",
|
||||||
"age_seconds": time.time() - buffer["timestamp"]
|
"age_seconds": time.time() - buffer["timestamp"],
|
||||||
}
|
}
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
# 全局实例
|
# 全局实例
|
||||||
chunker = MessageChunker()
|
chunker = MessageChunker()
|
||||||
reassembler = MessageReassembler()
|
reassembler = MessageReassembler()
|
||||||
|
|||||||
@@ -743,31 +743,31 @@ class MessageHandler:
|
|||||||
"""
|
"""
|
||||||
message_data: dict = raw_message.get("data", {})
|
message_data: dict = raw_message.get("data", {})
|
||||||
json_data = message_data.get("data", "")
|
json_data = message_data.get("data", "")
|
||||||
|
|
||||||
# 检查JSON消息格式
|
# 检查JSON消息格式
|
||||||
if not message_data or "data" not in message_data:
|
if not message_data or "data" not in message_data:
|
||||||
logger.warning("JSON消息格式不正确")
|
logger.warning("JSON消息格式不正确")
|
||||||
return Seg(type="json", data=json.dumps(message_data))
|
return Seg(type="json", data=json.dumps(message_data))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
nested_data = json.loads(json_data)
|
nested_data = json.loads(json_data)
|
||||||
|
|
||||||
# 检查是否是QQ小程序分享消息
|
# 检查是否是QQ小程序分享消息
|
||||||
if "app" in nested_data and "com.tencent.miniapp" in str(nested_data.get("app", "")):
|
if "app" in nested_data and "com.tencent.miniapp" in str(nested_data.get("app", "")):
|
||||||
logger.debug("检测到QQ小程序分享消息,开始提取信息")
|
logger.debug("检测到QQ小程序分享消息,开始提取信息")
|
||||||
|
|
||||||
# 提取目标字段
|
# 提取目标字段
|
||||||
extracted_info = {}
|
extracted_info = {}
|
||||||
|
|
||||||
# 提取 meta.detail_1 中的信息
|
# 提取 meta.detail_1 中的信息
|
||||||
meta = nested_data.get("meta", {})
|
meta = nested_data.get("meta", {})
|
||||||
detail_1 = meta.get("detail_1", {})
|
detail_1 = meta.get("detail_1", {})
|
||||||
|
|
||||||
if detail_1:
|
if detail_1:
|
||||||
extracted_info["title"] = detail_1.get("title", "")
|
extracted_info["title"] = detail_1.get("title", "")
|
||||||
extracted_info["desc"] = detail_1.get("desc", "")
|
extracted_info["desc"] = detail_1.get("desc", "")
|
||||||
qqdocurl = detail_1.get("qqdocurl", "")
|
qqdocurl = detail_1.get("qqdocurl", "")
|
||||||
|
|
||||||
# 从qqdocurl中提取b23.tv短链接
|
# 从qqdocurl中提取b23.tv短链接
|
||||||
if qqdocurl and "b23.tv" in qqdocurl:
|
if qqdocurl and "b23.tv" in qqdocurl:
|
||||||
# 查找b23.tv链接的起始位置
|
# 查找b23.tv链接的起始位置
|
||||||
@@ -785,26 +785,29 @@ class MessageHandler:
|
|||||||
extracted_info["short_url"] = qqdocurl
|
extracted_info["short_url"] = qqdocurl
|
||||||
else:
|
else:
|
||||||
extracted_info["short_url"] = qqdocurl
|
extracted_info["short_url"] = qqdocurl
|
||||||
|
|
||||||
# 如果成功提取到关键信息,返回格式化的文本
|
# 如果成功提取到关键信息,返回格式化的文本
|
||||||
if extracted_info.get("title") or extracted_info.get("desc") or extracted_info.get("short_url"):
|
if extracted_info.get("title") or extracted_info.get("desc") or extracted_info.get("short_url"):
|
||||||
content_parts = []
|
content_parts = []
|
||||||
|
|
||||||
if extracted_info.get("title"):
|
if extracted_info.get("title"):
|
||||||
content_parts.append(f"来源: {extracted_info['title']}")
|
content_parts.append(f"来源: {extracted_info['title']}")
|
||||||
|
|
||||||
if extracted_info.get("desc"):
|
if extracted_info.get("desc"):
|
||||||
content_parts.append(f"标题: {extracted_info['desc']}")
|
content_parts.append(f"标题: {extracted_info['desc']}")
|
||||||
|
|
||||||
if extracted_info.get("short_url"):
|
if extracted_info.get("short_url"):
|
||||||
content_parts.append(f"链接: {extracted_info['short_url']}")
|
content_parts.append(f"链接: {extracted_info['short_url']}")
|
||||||
|
|
||||||
formatted_content = "\n".join(content_parts)
|
formatted_content = "\n".join(content_parts)
|
||||||
return Seg(type="text", data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}")
|
return Seg(
|
||||||
|
type="text",
|
||||||
|
data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}",
|
||||||
|
)
|
||||||
|
|
||||||
# 如果没有提取到关键信息,返回None
|
# 如果没有提取到关键信息,返回None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"解析JSON消息失败: {e}")
|
logger.error(f"解析JSON消息失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -28,36 +28,36 @@ class MessageSending:
|
|||||||
try:
|
try:
|
||||||
# 检查是否需要切片发送
|
# 检查是否需要切片发送
|
||||||
message_dict = message_base.to_dict()
|
message_dict = message_base.to_dict()
|
||||||
|
|
||||||
if chunker.should_chunk_message(message_dict):
|
if chunker.should_chunk_message(message_dict):
|
||||||
logger.info(f"消息过大,进行切片发送到 MaiBot")
|
logger.info(f"消息过大,进行切片发送到 MaiBot")
|
||||||
|
|
||||||
# 切片消息
|
# 切片消息
|
||||||
chunks = chunker.chunk_message(message_dict)
|
chunks = chunker.chunk_message(message_dict)
|
||||||
|
|
||||||
# 逐个发送切片
|
# 逐个发送切片
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
logger.debug(f"发送切片 {i+1}/{len(chunks)} 到 MaiBot")
|
logger.debug(f"发送切片 {i + 1}/{len(chunks)} 到 MaiBot")
|
||||||
|
|
||||||
# 获取对应的客户端并发送切片
|
# 获取对应的客户端并发送切片
|
||||||
platform = message_base.message_info.platform
|
platform = message_base.message_info.platform
|
||||||
if platform not in self.maibot_router.clients:
|
if platform not in self.maibot_router.clients:
|
||||||
logger.error(f"平台 {platform} 未连接")
|
logger.error(f"平台 {platform} 未连接")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
client = self.maibot_router.clients[platform]
|
client = self.maibot_router.clients[platform]
|
||||||
send_status = await client.send_message(chunk)
|
send_status = await client.send_message(chunk)
|
||||||
|
|
||||||
if not send_status:
|
if not send_status:
|
||||||
logger.error(f"发送切片 {i+1}/{len(chunks)} 失败")
|
logger.error(f"发送切片 {i + 1}/{len(chunks)} 失败")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 使用配置中的延迟时间
|
# 使用配置中的延迟时间
|
||||||
if i < len(chunks) - 1:
|
if i < len(chunks) - 1:
|
||||||
delay_seconds = global_config.slicing.delay_ms / 1000.0
|
delay_seconds = global_config.slicing.delay_ms / 1000.0
|
||||||
logger.debug(f"切片发送延迟: {global_config.slicing.delay_ms}毫秒")
|
logger.debug(f"切片发送延迟: {global_config.slicing.delay_ms}毫秒")
|
||||||
await asyncio.sleep(delay_seconds)
|
await asyncio.sleep(delay_seconds)
|
||||||
|
|
||||||
logger.debug("所有切片发送完成")
|
logger.debug("所有切片发送完成")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
@@ -66,7 +66,7 @@ class MessageSending:
|
|||||||
if not send_status:
|
if not send_status:
|
||||||
raise RuntimeError("可能是路由未正确配置或连接异常")
|
raise RuntimeError("可能是路由未正确配置或连接异常")
|
||||||
return send_status
|
return send_status
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送消息失败: {str(e)}")
|
logger.error(f"发送消息失败: {str(e)}")
|
||||||
logger.error("请检查与MaiBot之间的连接")
|
logger.error("请检查与MaiBot之间的连接")
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class CycleProcessor:
|
|||||||
context: HFC聊天上下文对象,包含聊天流、能量值等信息
|
context: HFC聊天上下文对象,包含聊天流、能量值等信息
|
||||||
response_handler: 响应处理器,负责生成和发送回复
|
response_handler: 响应处理器,负责生成和发送回复
|
||||||
cycle_tracker: 循环跟踪器,负责记录和管理每次思考循环的信息
|
cycle_tracker: 循环跟踪器,负责记录和管理每次思考循环的信息
|
||||||
"""
|
"""
|
||||||
self.context = context
|
self.context = context
|
||||||
self.response_handler = response_handler
|
self.response_handler = response_handler
|
||||||
self.cycle_tracker = cycle_tracker
|
self.cycle_tracker = cycle_tracker
|
||||||
@@ -57,12 +57,12 @@ class CycleProcessor:
|
|||||||
|
|
||||||
# 存储reply action信息
|
# 存储reply action信息
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
|
|
||||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||||
platform = action_message.get("chat_info_platform")
|
platform = action_message.get("chat_info_platform")
|
||||||
if platform is None:
|
if platform is None:
|
||||||
platform = getattr(self.context.chat_stream, "platform", "unknown")
|
platform = getattr(self.context.chat_stream, "platform", "unknown")
|
||||||
|
|
||||||
person_id = person_info_manager.get_person_id(
|
person_id = person_info_manager.get_person_id(
|
||||||
platform,
|
platform,
|
||||||
action_message.get("user_id", ""),
|
action_message.get("user_id", ""),
|
||||||
@@ -94,8 +94,8 @@ class CycleProcessor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
return loop_info, reply_text, cycle_timers
|
return loop_info, reply_text, cycle_timers
|
||||||
|
|
||||||
async def observe(self,interest_value:float = 0.0) -> bool:
|
async def observe(self, interest_value: float = 0.0) -> bool:
|
||||||
"""
|
"""
|
||||||
观察和处理单次思考循环的核心方法
|
观察和处理单次思考循环的核心方法
|
||||||
|
|
||||||
@@ -114,7 +114,7 @@ class CycleProcessor:
|
|||||||
"""
|
"""
|
||||||
action_type = "no_action"
|
action_type = "no_action"
|
||||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||||
|
|
||||||
# 使用sigmoid函数将interest_value转换为概率
|
# 使用sigmoid函数将interest_value转换为概率
|
||||||
# 当interest_value为0时,概率接近0(使用Focus模式)
|
# 当interest_value为0时,概率接近0(使用Focus模式)
|
||||||
# 当interest_value很高时,概率接近1(使用Normal模式)
|
# 当interest_value很高时,概率接近1(使用Normal模式)
|
||||||
@@ -127,16 +127,24 @@ class CycleProcessor:
|
|||||||
k = 2.0 # 控制曲线陡峭程度
|
k = 2.0 # 控制曲线陡峭程度
|
||||||
x0 = 1.0 # 控制曲线中心点
|
x0 = 1.0 # 控制曲线中心点
|
||||||
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
||||||
|
|
||||||
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 0.5 / global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
normal_mode_probability = (
|
||||||
|
calculate_normal_mode_probability(interest_value)
|
||||||
|
* 0.5
|
||||||
|
/ global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
||||||
|
)
|
||||||
|
|
||||||
# 根据概率决定使用哪种模式
|
# 根据概率决定使用哪种模式
|
||||||
if random.random() < normal_mode_probability:
|
if random.random() < normal_mode_probability:
|
||||||
mode = ChatMode.NORMAL
|
mode = ChatMode.NORMAL
|
||||||
logger.info(f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Normal planner模式")
|
logger.info(
|
||||||
|
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Normal planner模式"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
mode = ChatMode.FOCUS
|
mode = ChatMode.FOCUS
|
||||||
logger.info(f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Focus planner模式")
|
logger.info(
|
||||||
|
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Focus planner模式"
|
||||||
|
)
|
||||||
|
|
||||||
cycle_timers, thinking_id = self.cycle_tracker.start_cycle()
|
cycle_timers, thinking_id = self.cycle_tracker.start_cycle()
|
||||||
logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考")
|
logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考")
|
||||||
@@ -165,12 +173,14 @@ class CycleProcessor:
|
|||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
from src.plugin_system import EventType
|
from src.plugin_system import EventType
|
||||||
|
|
||||||
result = await event_manager.trigger_event(EventType.ON_PLAN,plugin_name="SYSTEM", stream_id=self.context.chat_stream)
|
result = await event_manager.trigger_event(
|
||||||
|
EventType.ON_PLAN, plugin_name="SYSTEM", stream_id=self.context.chat_stream
|
||||||
|
)
|
||||||
if not result.all_continue_process():
|
if not result.all_continue_process():
|
||||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成")
|
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成")
|
||||||
|
|
||||||
with Timer("规划器", cycle_timers):
|
with Timer("规划器", cycle_timers):
|
||||||
actions, _= await self.action_planner.plan(
|
actions, _ = await self.action_planner.plan(
|
||||||
mode=mode,
|
mode=mode,
|
||||||
loop_start_time=loop_start_time,
|
loop_start_time=loop_start_time,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
@@ -183,7 +193,7 @@ class CycleProcessor:
|
|||||||
# 直接处理no_reply逻辑,不再通过动作系统
|
# 直接处理no_reply逻辑,不再通过动作系统
|
||||||
reason = action_info.get("reasoning", "选择不回复")
|
reason = action_info.get("reasoning", "选择不回复")
|
||||||
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||||
|
|
||||||
# 存储no_reply信息到数据库
|
# 存储no_reply信息到数据库
|
||||||
await database_api.store_action_info(
|
await database_api.store_action_info(
|
||||||
chat_stream=self.context.chat_stream,
|
chat_stream=self.context.chat_stream,
|
||||||
@@ -194,13 +204,8 @@ class CycleProcessor:
|
|||||||
action_data={"reason": reason},
|
action_data={"reason": reason},
|
||||||
action_name="no_reply",
|
action_name="no_reply",
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||||
"action_type": "no_reply",
|
|
||||||
"success": True,
|
|
||||||
"reply_text": "",
|
|
||||||
"command": ""
|
|
||||||
}
|
|
||||||
elif action_info["action_type"] != "reply":
|
elif action_info["action_type"] != "reply":
|
||||||
# 执行普通动作
|
# 执行普通动作
|
||||||
with Timer("动作执行", cycle_timers):
|
with Timer("动作执行", cycle_timers):
|
||||||
@@ -210,40 +215,32 @@ class CycleProcessor:
|
|||||||
action_info["action_data"],
|
action_info["action_data"],
|
||||||
cycle_timers,
|
cycle_timers,
|
||||||
thinking_id,
|
thinking_id,
|
||||||
action_info["action_message"]
|
action_info["action_message"],
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"action_type": action_info["action_type"],
|
"action_type": action_info["action_type"],
|
||||||
"success": success,
|
"success": success,
|
||||||
"reply_text": reply_text,
|
"reply_text": reply_text,
|
||||||
"command": command
|
"command": command,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
success, response_set, _ = await generator_api.generate_reply(
|
success, response_set, _ = await generator_api.generate_reply(
|
||||||
chat_stream=self.context.chat_stream,
|
chat_stream=self.context.chat_stream,
|
||||||
reply_message = action_info["action_message"],
|
reply_message=action_info["action_message"],
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
enable_tool=global_config.tool.enable_tool,
|
enable_tool=global_config.tool.enable_tool,
|
||||||
request_type="chat.replyer",
|
request_type="chat.replyer",
|
||||||
from_plugin=False,
|
from_plugin=False,
|
||||||
)
|
)
|
||||||
if not success or not response_set:
|
if not success or not response_set:
|
||||||
logger.info(f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败")
|
logger.info(
|
||||||
return {
|
f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
|
||||||
"action_type": "reply",
|
)
|
||||||
"success": False,
|
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||||
"reply_text": "",
|
|
||||||
"loop_info": None
|
|
||||||
}
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||||
return {
|
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||||
"action_type": "reply",
|
|
||||||
"success": False,
|
|
||||||
"reply_text": "",
|
|
||||||
"loop_info": None
|
|
||||||
}
|
|
||||||
|
|
||||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||||
response_set,
|
response_set,
|
||||||
@@ -253,12 +250,7 @@ class CycleProcessor:
|
|||||||
thinking_id,
|
thinking_id,
|
||||||
actions,
|
actions,
|
||||||
)
|
)
|
||||||
return {
|
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
||||||
"action_type": "reply",
|
|
||||||
"success": True,
|
|
||||||
"reply_text": reply_text,
|
|
||||||
"loop_info": loop_info
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||||
@@ -267,9 +259,9 @@ class CycleProcessor:
|
|||||||
"success": False,
|
"success": False,
|
||||||
"reply_text": "",
|
"reply_text": "",
|
||||||
"loop_info": None,
|
"loop_info": None,
|
||||||
"error": str(e)
|
"error": str(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 创建所有动作的后台任务
|
# 创建所有动作的后台任务
|
||||||
action_tasks = [asyncio.create_task(execute_action(action)) for action in actions]
|
action_tasks = [asyncio.create_task(execute_action(action)) for action in actions]
|
||||||
|
|
||||||
@@ -282,12 +274,12 @@ class CycleProcessor:
|
|||||||
action_success = False
|
action_success = False
|
||||||
action_reply_text = ""
|
action_reply_text = ""
|
||||||
action_command = ""
|
action_command = ""
|
||||||
|
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
if isinstance(result, BaseException):
|
if isinstance(result, BaseException):
|
||||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
action_info = actions[i]
|
action_info = actions[i]
|
||||||
if result["action_type"] != "reply":
|
if result["action_type"] != "reply":
|
||||||
action_success = result["success"]
|
action_success = result["success"]
|
||||||
@@ -327,7 +319,7 @@ class CycleProcessor:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
reply_text = action_reply_text
|
reply_text = action_reply_text
|
||||||
|
|
||||||
if ENABLE_S4U:
|
if ENABLE_S4U:
|
||||||
await stop_typing()
|
await stop_typing()
|
||||||
|
|
||||||
@@ -342,7 +334,7 @@ class CycleProcessor:
|
|||||||
self.context.no_reply_consecutive = 0
|
self.context.no_reply_consecutive = 0
|
||||||
logger.debug(f"{self.log_prefix} 执行了{action_type}动作,重置no_reply计数器")
|
logger.debug(f"{self.log_prefix} 执行了{action_type}动作,重置no_reply计数器")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if action_type == "no_reply":
|
if action_type == "no_reply":
|
||||||
self.context.no_reply_consecutive += 1
|
self.context.no_reply_consecutive += 1
|
||||||
self.context.chat_instance._determine_form_type()
|
self.context.chat_instance._determine_form_type()
|
||||||
|
|||||||
@@ -91,25 +91,24 @@ class CycleTracker:
|
|||||||
|
|
||||||
# 获取动作类型,兼容新旧格式
|
# 获取动作类型,兼容新旧格式
|
||||||
action_type = "未知动作"
|
action_type = "未知动作"
|
||||||
if hasattr(self, '_current_cycle_detail') and self._current_cycle_detail:
|
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
|
||||||
loop_plan_info = self._current_cycle_detail.loop_plan_info
|
loop_plan_info = self._current_cycle_detail.loop_plan_info
|
||||||
if isinstance(loop_plan_info, dict):
|
if isinstance(loop_plan_info, dict):
|
||||||
action_result = loop_plan_info.get('action_result', {})
|
action_result = loop_plan_info.get("action_result", {})
|
||||||
if isinstance(action_result, dict):
|
if isinstance(action_result, dict):
|
||||||
# 旧格式:action_result是字典
|
# 旧格式:action_result是字典
|
||||||
action_type = action_result.get('action_type', '未知动作')
|
action_type = action_result.get("action_type", "未知动作")
|
||||||
elif isinstance(action_result, list) and action_result:
|
elif isinstance(action_result, list) and action_result:
|
||||||
# 新格式:action_result是actions列表
|
# 新格式:action_result是actions列表
|
||||||
action_type = action_result[0].get('action_type', '未知动作')
|
action_type = action_result[0].get("action_type", "未知动作")
|
||||||
elif isinstance(loop_plan_info, list) and loop_plan_info:
|
elif isinstance(loop_plan_info, list) and loop_plan_info:
|
||||||
# 直接是actions列表的情况
|
# 直接是actions列表的情况
|
||||||
action_type = loop_plan_info[0].get('action_type', '未知动作')
|
action_type = loop_plan_info[0].get("action_type", "未知动作")
|
||||||
|
|
||||||
if self.context.current_cycle_detail.end_time and self.context.current_cycle_detail.start_time:
|
if self.context.current_cycle_detail.end_time and self.context.current_cycle_detail.start_time:
|
||||||
duration = self.context.current_cycle_detail.end_time - self.context.current_cycle_detail.start_time
|
duration = self.context.current_cycle_detail.end_time - self.context.current_cycle_detail.start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.context.log_prefix} 第{self.context.current_cycle_detail.cycle_id}次思考,"
|
f"{self.context.log_prefix} 第{self.context.current_cycle_detail.cycle_id}次思考,"
|
||||||
f"耗时: {duration:.1f}秒, "
|
f"耗时: {duration:.1f}秒, "
|
||||||
f"选择动作: {action_type}"
|
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
||||||
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class HeartFChatting:
|
|||||||
|
|
||||||
self._loop_task: Optional[asyncio.Task] = None
|
self._loop_task: Optional[asyncio.Task] = None
|
||||||
self._proactive_monitor_task: Optional[asyncio.Task] = None
|
self._proactive_monitor_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
# 记录最近3次的兴趣度
|
# 记录最近3次的兴趣度
|
||||||
self.recent_interest_records: deque = deque(maxlen=3)
|
self.recent_interest_records: deque = deque(maxlen=3)
|
||||||
self._initialize_chat_mode()
|
self._initialize_chat_mode()
|
||||||
@@ -183,7 +183,7 @@ class HeartFChatting:
|
|||||||
event = ProactiveTriggerEvent(
|
event = ProactiveTriggerEvent(
|
||||||
source="silence_monitor",
|
source="silence_monitor",
|
||||||
reason=f"聊天已沉默 {formatted_time}",
|
reason=f"聊天已沉默 {formatted_time}",
|
||||||
metadata={"silence_duration": silence_duration}
|
metadata={"silence_duration": silence_duration},
|
||||||
)
|
)
|
||||||
await self.proactive_thinker.think(event)
|
await self.proactive_thinker.think(event)
|
||||||
self.context.last_message_time = current_time
|
self.context.last_message_time = current_time
|
||||||
@@ -205,21 +205,30 @@ class HeartFChatting:
|
|||||||
stream_parts = self.context.stream_id.split(":")
|
stream_parts = self.context.stream_id.split(":")
|
||||||
current_chat_identifier = f"{stream_parts}:{stream_parts}" if len(stream_parts) >= 2 else self.context.stream_id
|
current_chat_identifier = f"{stream_parts}:{stream_parts}" if len(stream_parts) >= 2 else self.context.stream_id
|
||||||
|
|
||||||
enable_list = getattr(global_config.chat, "proactive_thinking_enable_in_groups" if is_group_chat else "proactive_thinking_enable_in_private", [])
|
enable_list = getattr(
|
||||||
|
global_config.chat,
|
||||||
|
"proactive_thinking_enable_in_groups" if is_group_chat else "proactive_thinking_enable_in_private",
|
||||||
|
[],
|
||||||
|
)
|
||||||
return not enable_list or current_chat_identifier in enable_list
|
return not enable_list or current_chat_identifier in enable_list
|
||||||
|
|
||||||
def _get_dynamic_thinking_interval(self) -> float:
|
def _get_dynamic_thinking_interval(self) -> float:
|
||||||
try:
|
try:
|
||||||
from src.utils.timing_utils import get_normal_distributed_interval
|
from src.utils.timing_utils import get_normal_distributed_interval
|
||||||
|
|
||||||
base_interval = global_config.chat.proactive_thinking_interval
|
base_interval = global_config.chat.proactive_thinking_interval
|
||||||
delta_sigma = getattr(global_config.chat, "delta_sigma", 120)
|
delta_sigma = getattr(global_config.chat, "delta_sigma", 120)
|
||||||
|
|
||||||
if base_interval <= 0: base_interval = abs(base_interval)
|
if base_interval <= 0:
|
||||||
if delta_sigma < 0: delta_sigma = abs(delta_sigma)
|
base_interval = abs(base_interval)
|
||||||
|
if delta_sigma < 0:
|
||||||
|
delta_sigma = abs(delta_sigma)
|
||||||
|
|
||||||
|
if base_interval == 0 and delta_sigma == 0:
|
||||||
|
return 300
|
||||||
|
if delta_sigma == 0:
|
||||||
|
return base_interval
|
||||||
|
|
||||||
if base_interval == 0 and delta_sigma == 0: return 300
|
|
||||||
if delta_sigma == 0: return base_interval
|
|
||||||
|
|
||||||
sigma_percentage = delta_sigma / base_interval if base_interval > 0 else delta_sigma / 1000
|
sigma_percentage = delta_sigma / base_interval if base_interval > 0 else delta_sigma / 1000
|
||||||
return get_normal_distributed_interval(base_interval, sigma_percentage, 1, 86400, use_3sigma_rule=True)
|
return get_normal_distributed_interval(base_interval, sigma_percentage, 1, 86400, use_3sigma_rule=True)
|
||||||
|
|
||||||
@@ -335,29 +344,30 @@ class HeartFChatting:
|
|||||||
|
|
||||||
# 根据聊天模式处理新消息
|
# 根据聊天模式处理新消息
|
||||||
# 统一使用 _should_process_messages 判断是否应该处理
|
# 统一使用 _should_process_messages 判断是否应该处理
|
||||||
should_process,interest_value = await self._should_process_messages(recent_messages)
|
should_process, interest_value = await self._should_process_messages(recent_messages)
|
||||||
if should_process:
|
if should_process:
|
||||||
self.context.last_read_time = time.time()
|
self.context.last_read_time = time.time()
|
||||||
await self.cycle_processor.observe(interest_value = interest_value)
|
await self.cycle_processor.observe(interest_value=interest_value)
|
||||||
else:
|
else:
|
||||||
# Normal模式:消息数量不足,等待
|
# Normal模式:消息数量不足,等待
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if not await self._should_process_messages(recent_messages):
|
if not await self._should_process_messages(recent_messages):
|
||||||
return has_new_messages
|
return has_new_messages
|
||||||
|
|
||||||
# 处理新消息
|
# 处理新消息
|
||||||
for message in recent_messages:
|
for message in recent_messages:
|
||||||
await self.cycle_processor.observe(interest_value = interest_value)
|
await self.cycle_processor.observe(interest_value=interest_value)
|
||||||
|
|
||||||
# 如果成功观察,增加能量值并重置累积兴趣值
|
# 如果成功观察,增加能量值并重置累积兴趣值
|
||||||
if has_new_messages:
|
if has_new_messages:
|
||||||
self.context.energy_value += 1 / global_config.chat.focus_value
|
self.context.energy_value += 1 / global_config.chat.focus_value
|
||||||
# 重置累积兴趣值,因为消息已经被成功处理
|
# 重置累积兴趣值,因为消息已经被成功处理
|
||||||
self.context.breaking_accumulated_interest = 0.0
|
self.context.breaking_accumulated_interest = 0.0
|
||||||
logger.info(f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值")
|
logger.info(
|
||||||
|
f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值"
|
||||||
|
)
|
||||||
|
|
||||||
# 更新上一帧的睡眠状态
|
# 更新上一帧的睡眠状态
|
||||||
self.context.was_sleeping = is_sleeping
|
self.context.was_sleeping = is_sleeping
|
||||||
@@ -378,7 +388,6 @@ class HeartFChatting:
|
|||||||
|
|
||||||
return has_new_messages
|
return has_new_messages
|
||||||
|
|
||||||
|
|
||||||
def _handle_wakeup_messages(self, messages):
|
def _handle_wakeup_messages(self, messages):
|
||||||
"""
|
"""
|
||||||
处理休眠状态下的消息,累积唤醒度
|
处理休眠状态下的消息,累积唤醒度
|
||||||
@@ -421,7 +430,7 @@ class HeartFChatting:
|
|||||||
logger.info(f"{self.context.log_prefix} breaking模式已禁用,使用waiting形式")
|
logger.info(f"{self.context.log_prefix} breaking模式已禁用,使用waiting形式")
|
||||||
self.context.focus_energy = 1
|
self.context.focus_energy = 1
|
||||||
return "waiting"
|
return "waiting"
|
||||||
|
|
||||||
# 如果连续no_reply次数少于3次,使用waiting形式
|
# 如果连续no_reply次数少于3次,使用waiting形式
|
||||||
if self.context.no_reply_consecutive <= 3:
|
if self.context.no_reply_consecutive <= 3:
|
||||||
self.context.focus_energy = 1
|
self.context.focus_energy = 1
|
||||||
@@ -429,12 +438,14 @@ class HeartFChatting:
|
|||||||
else:
|
else:
|
||||||
# 使用累积兴趣值而不是最近3次的记录
|
# 使用累积兴趣值而不是最近3次的记录
|
||||||
total_interest = self.context.breaking_accumulated_interest
|
total_interest = self.context.breaking_accumulated_interest
|
||||||
|
|
||||||
# 计算调整后的阈值
|
# 计算调整后的阈值
|
||||||
adjusted_threshold = 1 / global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
adjusted_threshold = 1 / global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
||||||
|
|
||||||
logger.info(f"{self.context.log_prefix} 累积兴趣值: {total_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
|
logger.info(
|
||||||
|
f"{self.context.log_prefix} 累积兴趣值: {total_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
# 如果累积兴趣值小于阈值,进入breaking形式
|
# 如果累积兴趣值小于阈值,进入breaking形式
|
||||||
if total_interest < adjusted_threshold:
|
if total_interest < adjusted_threshold:
|
||||||
logger.info(f"{self.context.log_prefix} 累积兴趣度不足,进入breaking形式")
|
logger.info(f"{self.context.log_prefix} 累积兴趣度不足,进入breaking形式")
|
||||||
@@ -445,7 +456,7 @@ class HeartFChatting:
|
|||||||
self.context.focus_energy = 1
|
self.context.focus_energy = 1
|
||||||
return "waiting"
|
return "waiting"
|
||||||
|
|
||||||
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool,float]:
|
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool, float]:
|
||||||
"""
|
"""
|
||||||
统一判断是否应该处理消息的函数
|
统一判断是否应该处理消息的函数
|
||||||
根据当前循环模式和消息内容决定是否继续处理
|
根据当前循环模式和消息内容决定是否继续处理
|
||||||
@@ -459,37 +470,39 @@ class HeartFChatting:
|
|||||||
|
|
||||||
modified_exit_count_threshold = self.context.focus_energy * 0.5 / talk_frequency
|
modified_exit_count_threshold = self.context.focus_energy * 0.5 / talk_frequency
|
||||||
modified_exit_interest_threshold = 1.5 / talk_frequency
|
modified_exit_interest_threshold = 1.5 / talk_frequency
|
||||||
|
|
||||||
# 计算当前批次消息的兴趣值
|
# 计算当前批次消息的兴趣值
|
||||||
batch_interest = 0.0
|
batch_interest = 0.0
|
||||||
for msg_dict in new_message:
|
for msg_dict in new_message:
|
||||||
interest_value = msg_dict.get("interest_value", 0.0)
|
interest_value = msg_dict.get("interest_value", 0.0)
|
||||||
if msg_dict.get("processed_plain_text", ""):
|
if msg_dict.get("processed_plain_text", ""):
|
||||||
batch_interest += interest_value
|
batch_interest += interest_value
|
||||||
|
|
||||||
# 在breaking形式下累积所有消息的兴趣值
|
# 在breaking形式下累积所有消息的兴趣值
|
||||||
if new_message_count > 0:
|
if new_message_count > 0:
|
||||||
self.context.breaking_accumulated_interest += batch_interest
|
self.context.breaking_accumulated_interest += batch_interest
|
||||||
total_interest = self.context.breaking_accumulated_interest
|
total_interest = self.context.breaking_accumulated_interest
|
||||||
else:
|
else:
|
||||||
total_interest = self.context.breaking_accumulated_interest
|
total_interest = self.context.breaking_accumulated_interest
|
||||||
|
|
||||||
if new_message_count >= modified_exit_count_threshold:
|
if new_message_count >= modified_exit_count_threshold:
|
||||||
# 记录兴趣度到列表
|
# 记录兴趣度到列表
|
||||||
self.recent_interest_records.append(total_interest)
|
self.recent_interest_records.append(total_interest)
|
||||||
# 重置累积兴趣值,因为已经达到了消息数量阈值
|
# 重置累积兴趣值,因为已经达到了消息数量阈值
|
||||||
self.context.breaking_accumulated_interest = 0.0
|
self.context.breaking_accumulated_interest = 0.0
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.context.log_prefix} 累计消息数量达到{new_message_count}条(>{modified_exit_count_threshold:.1f}),结束等待,累积兴趣值: {total_interest:.2f}"
|
f"{self.context.log_prefix} 累计消息数量达到{new_message_count}条(>{modified_exit_count_threshold:.1f}),结束等待,累积兴趣值: {total_interest:.2f}"
|
||||||
)
|
)
|
||||||
return True,total_interest/new_message_count
|
return True, total_interest / new_message_count
|
||||||
|
|
||||||
# 检查累计兴趣值
|
# 检查累计兴趣值
|
||||||
if new_message_count > 0:
|
if new_message_count > 0:
|
||||||
# 只在兴趣值变化时输出log
|
# 只在兴趣值变化时输出log
|
||||||
if not hasattr(self, "_last_accumulated_interest") or total_interest != self._last_accumulated_interest:
|
if not hasattr(self, "_last_accumulated_interest") or total_interest != self._last_accumulated_interest:
|
||||||
logger.info(f"{self.context.log_prefix} breaking形式当前累积兴趣值: {total_interest:.2f}, 专注度: {global_config.chat.focus_value:.1f}")
|
logger.info(
|
||||||
|
f"{self.context.log_prefix} breaking形式当前累积兴趣值: {total_interest:.2f}, 专注度: {global_config.chat.focus_value:.1f}"
|
||||||
|
)
|
||||||
self._last_accumulated_interest = total_interest
|
self._last_accumulated_interest = total_interest
|
||||||
if total_interest >= modified_exit_interest_threshold:
|
if total_interest >= modified_exit_interest_threshold:
|
||||||
# 记录兴趣度到列表
|
# 记录兴趣度到列表
|
||||||
@@ -499,13 +512,16 @@ class HeartFChatting:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"{self.context.log_prefix} 累计兴趣值达到{total_interest:.2f}(>{modified_exit_interest_threshold:.1f}),结束等待"
|
f"{self.context.log_prefix} 累计兴趣值达到{total_interest:.2f}(>{modified_exit_interest_threshold:.1f}),结束等待"
|
||||||
)
|
)
|
||||||
return True,total_interest/new_message_count
|
return True, total_interest / new_message_count
|
||||||
|
|
||||||
# 每10秒输出一次等待状态
|
# 每10秒输出一次等待状态
|
||||||
if int(time.time() - self.context.last_read_time) > 0 and int(time.time() - self.context.last_read_time) % 10 == 0:
|
if (
|
||||||
|
int(time.time() - self.context.last_read_time) > 0
|
||||||
|
and int(time.time() - self.context.last_read_time) % 10 == 0
|
||||||
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.context.log_prefix} 已等待{time.time() - self.context.last_read_time:.0f}秒,累计{new_message_count}条消息,累积兴趣{total_interest:.1f},继续等待..."
|
f"{self.context.log_prefix} 已等待{time.time() - self.context.last_read_time:.0f}秒,累计{new_message_count}条消息,累积兴趣{total_interest:.1f},继续等待..."
|
||||||
)
|
)
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
return False,0.0
|
return False, 0.0
|
||||||
|
|||||||
@@ -44,13 +44,13 @@ class HfcContext:
|
|||||||
|
|
||||||
self.energy_value = self.chat_stream.energy_value
|
self.energy_value = self.chat_stream.energy_value
|
||||||
self.sleep_pressure = self.chat_stream.sleep_pressure
|
self.sleep_pressure = self.chat_stream.sleep_pressure
|
||||||
self.was_sleeping = False # 用于检测睡眠状态的切换
|
self.was_sleeping = False # 用于检测睡眠状态的切换
|
||||||
|
|
||||||
self.last_message_time = time.time()
|
self.last_message_time = time.time()
|
||||||
self.last_read_time = time.time() - 10
|
self.last_read_time = time.time() - 10
|
||||||
|
|
||||||
# 从聊天流恢复breaking累积兴趣值
|
# 从聊天流恢复breaking累积兴趣值
|
||||||
self.breaking_accumulated_interest = getattr(self.chat_stream, 'breaking_accumulated_interest', 0.0)
|
self.breaking_accumulated_interest = getattr(self.chat_stream, "breaking_accumulated_interest", 0.0)
|
||||||
|
|
||||||
self.action_manager = ActionManager()
|
self.action_manager = ActionManager()
|
||||||
|
|
||||||
@@ -79,4 +79,4 @@ class HfcContext:
|
|||||||
self.chat_stream.sleep_pressure = self.sleep_pressure
|
self.chat_stream.sleep_pressure = self.sleep_pressure
|
||||||
self.chat_stream.focus_energy = self.focus_energy
|
self.chat_stream.focus_energy = self.focus_energy
|
||||||
self.chat_stream.no_reply_consecutive = self.no_reply_consecutive
|
self.chat_stream.no_reply_consecutive = self.no_reply_consecutive
|
||||||
self.chat_stream.breaking_accumulated_interest = self.breaking_accumulated_interest
|
self.chat_stream.breaking_accumulated_interest = self.breaking_accumulated_interest
|
||||||
|
|||||||
@@ -122,6 +122,7 @@ class CycleDetail:
|
|||||||
self.loop_plan_info = loop_info["loop_plan_info"]
|
self.loop_plan_info = loop_info["loop_plan_info"]
|
||||||
self.loop_action_info = loop_info["loop_action_info"]
|
self.loop_action_info = loop_info["loop_action_info"]
|
||||||
|
|
||||||
|
|
||||||
async def send_typing():
|
async def send_typing():
|
||||||
"""
|
"""
|
||||||
发送打字状态指示
|
发送打字状态指示
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProactiveTriggerEvent:
|
class ProactiveTriggerEvent:
|
||||||
"""
|
"""
|
||||||
主动思考触发事件的数据类
|
主动思考触发事件的数据类
|
||||||
"""
|
"""
|
||||||
|
|
||||||
source: str # 触发源的标识,例如 "silence_monitor", "insomnia_manager"
|
source: str # 触发源的标识,例如 "silence_monitor", "insomnia_manager"
|
||||||
reason: str # 触发的具体原因,例如 "聊天已沉默10分钟", "深夜emo"
|
reason: str # 触发的具体原因,例如 "聊天已沉默10分钟", "深夜emo"
|
||||||
metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # 可选的元数据,用于传递额外信息
|
metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # 可选的元数据,用于传递额外信息
|
||||||
|
|||||||
@@ -37,8 +37,10 @@ class ProactiveThinker:
|
|||||||
Args:
|
Args:
|
||||||
trigger_event: 描述触发上下文的事件对象
|
trigger_event: 描述触发上下文的事件对象
|
||||||
"""
|
"""
|
||||||
logger.info(f"{self.context.log_prefix} 接收到主动思考事件: "
|
logger.info(
|
||||||
f"来源='{trigger_event.source}', 原因='{trigger_event.reason}'")
|
f"{self.context.log_prefix} 接收到主动思考事件: "
|
||||||
|
f"来源='{trigger_event.source}', 原因='{trigger_event.reason}'"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 根据事件类型执行前置操作
|
# 1. 根据事件类型执行前置操作
|
||||||
@@ -63,6 +65,7 @@ class ProactiveThinker:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
|
|
||||||
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
|
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
|
||||||
new_mood = None
|
new_mood = None
|
||||||
|
|
||||||
@@ -76,8 +79,10 @@ class ProactiveThinker:
|
|||||||
if new_mood:
|
if new_mood:
|
||||||
mood_obj.mood_state = new_mood
|
mood_obj.mood_state = new_mood
|
||||||
mood_obj.last_change_time = time.time()
|
mood_obj.last_change_time = time.time()
|
||||||
logger.info(f"{self.context.log_prefix} 因 '{trigger_event.reason}',"
|
logger.info(
|
||||||
f"情绪状态被强制更新为: {mood_obj.mood_state}")
|
f"{self.context.log_prefix} 因 '{trigger_event.reason}',"
|
||||||
|
f"情绪状态被强制更新为: {mood_obj.mood_state}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.context.log_prefix} 设置失眠情绪时出错: {e}")
|
logger.error(f"{self.context.log_prefix} 设置失眠情绪时出错: {e}")
|
||||||
@@ -91,19 +96,17 @@ class ProactiveThinker:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 直接调用 planner 的 PROACTIVE 模式
|
# 直接调用 planner 的 PROACTIVE 模式
|
||||||
actions, target_message = await self.cycle_processor.action_planner.plan(
|
actions, target_message = await self.cycle_processor.action_planner.plan(mode=ChatMode.PROACTIVE)
|
||||||
mode=ChatMode.PROACTIVE
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取第一个规划出的动作作为主要决策
|
# 获取第一个规划出的动作作为主要决策
|
||||||
action_result = actions[0] if actions else {}
|
action_result = actions[0] if actions else {}
|
||||||
|
|
||||||
# 如果决策不是 do_nothing,则执行
|
# 如果决策不是 do_nothing,则执行
|
||||||
if action_result and action_result.get("action_type") != "do_nothing":
|
if action_result and action_result.get("action_type") != "do_nothing":
|
||||||
|
|
||||||
# 在主动思考时,如果 target_message 为 None,则默认选取最新 message 作为 target_message
|
# 在主动思考时,如果 target_message 为 None,则默认选取最新 message 作为 target_message
|
||||||
if target_message is None and self.context.chat_stream and self.context.chat_stream.context:
|
if target_message is None and self.context.chat_stream and self.context.chat_stream.context:
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
|
|
||||||
latest_message = self.context.chat_stream.context.get_last_message()
|
latest_message = self.context.chat_stream.context.get_last_message()
|
||||||
if isinstance(latest_message, MessageRecv):
|
if isinstance(latest_message, MessageRecv):
|
||||||
user_info = latest_message.message_info.user_info
|
user_info = latest_message.message_info.user_info
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ class ResponseHandler:
|
|||||||
await send_api.text_to_stream(
|
await send_api.text_to_stream(
|
||||||
text=data,
|
text=data,
|
||||||
stream_id=self.context.stream_id,
|
stream_id=self.context.stream_id,
|
||||||
reply_to_message = message_data,
|
reply_to_message=message_data,
|
||||||
set_reply=need_reply,
|
set_reply=need_reply,
|
||||||
typing=False,
|
typing=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -724,7 +724,7 @@ class EmojiManager:
|
|||||||
if not emoji.is_deleted and emoji.hash == emoji_hash:
|
if not emoji.is_deleted and emoji.hash == emoji_hash:
|
||||||
return emoji
|
return emoji
|
||||||
return None # 如果循环结束还没找到,则返回 None
|
return None # 如果循环结束还没找到,则返回 None
|
||||||
|
|
||||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
|
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||||
"""根据哈希值获取已注册表情包的描述
|
"""根据哈希值获取已注册表情包的描述
|
||||||
|
|
||||||
@@ -755,7 +755,7 @@ class EmojiManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
|
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||||
"""根据哈希值获取已注册表情包的描述
|
"""根据哈希值获取已注册表情包的描述
|
||||||
|
|
||||||
|
|||||||
@@ -259,7 +259,7 @@ class ChatManager:
|
|||||||
"user_cardname": model_instance.user_cardname or "",
|
"user_cardname": model_instance.user_cardname or "",
|
||||||
}
|
}
|
||||||
group_info_data = None
|
group_info_data = None
|
||||||
if model_instance and getattr(model_instance, 'group_id', None):
|
if model_instance and getattr(model_instance, "group_id", None):
|
||||||
group_info_data = {
|
group_info_data = {
|
||||||
"platform": model_instance.group_platform,
|
"platform": model_instance.group_platform,
|
||||||
"group_id": model_instance.group_id,
|
"group_id": model_instance.group_id,
|
||||||
@@ -405,7 +405,7 @@ class ChatManager:
|
|||||||
"user_cardname": model_instance.user_cardname or "",
|
"user_cardname": model_instance.user_cardname or "",
|
||||||
}
|
}
|
||||||
group_info_data = None
|
group_info_data = None
|
||||||
if model_instance and getattr(model_instance, 'group_id', None):
|
if model_instance and getattr(model_instance, "group_id", None):
|
||||||
group_info_data = {
|
group_info_data = {
|
||||||
"platform": model_instance.group_platform,
|
"platform": model_instance.group_platform,
|
||||||
"group_id": model_instance.group_id,
|
"group_id": model_instance.group_id,
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class MessageRecv(Message):
|
|||||||
self.priority_mode = "interest"
|
self.priority_mode = "interest"
|
||||||
self.priority_info = None
|
self.priority_info = None
|
||||||
self.interest_value: float = 0.0
|
self.interest_value: float = 0.0
|
||||||
|
|
||||||
self.key_words = []
|
self.key_words = []
|
||||||
self.key_words_lite = []
|
self.key_words_lite = []
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class MessageStorage:
|
|||||||
if isinstance(keywords, list):
|
if isinstance(keywords, list):
|
||||||
return orjson.dumps(keywords).decode("utf-8")
|
return orjson.dumps(keywords).decode("utf-8")
|
||||||
return "[]"
|
return "[]"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _deserialize_keywords(keywords_str: str) -> list:
|
def _deserialize_keywords(keywords_str: str) -> list:
|
||||||
"""将JSON字符串反序列化为关键词列表"""
|
"""将JSON字符串反序列化为关键词列表"""
|
||||||
|
|||||||
@@ -161,10 +161,8 @@ class ActionModifier:
|
|||||||
|
|
||||||
available_actions = list(self.action_manager.get_using_actions().keys())
|
available_actions = list(self.action_manager.get_using_actions().keys())
|
||||||
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||||
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||||
|
|||||||
@@ -188,15 +188,12 @@ class ActionPlanner:
|
|||||||
param_text = ""
|
param_text = ""
|
||||||
if action_info.action_parameters:
|
if action_info.action_parameters:
|
||||||
param_text = "\n" + "\n".join(
|
param_text = "\n" + "\n".join(
|
||||||
f' "{p_name}":"{p_desc}"'
|
f' "{p_name}":"{p_desc}"' for p_name, p_desc in action_info.action_parameters.items()
|
||||||
for p_name, p_desc in action_info.action_parameters.items()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require_text = "\n".join(f"- {req}" for req in action_info.action_require)
|
require_text = "\n".join(f"- {req}" for req in action_info.action_require)
|
||||||
|
|
||||||
using_action_prompt = await global_prompt_manager.get_prompt_async(
|
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
||||||
"action_prompt"
|
|
||||||
)
|
|
||||||
action_options_block += using_action_prompt.format(
|
action_options_block += using_action_prompt.format(
|
||||||
action_name=action_name,
|
action_name=action_name,
|
||||||
action_description=action_info.description,
|
action_description=action_info.description,
|
||||||
@@ -205,9 +202,7 @@ class ActionPlanner:
|
|||||||
)
|
)
|
||||||
return action_options_block
|
return action_options_block
|
||||||
|
|
||||||
def find_message_by_id(
|
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||||
self, message_id: str, message_id_list: list
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
# sourcery skip: use-next
|
# sourcery skip: use-next
|
||||||
"""
|
"""
|
||||||
根据message_id从message_id_list中查找对应的原始消息
|
根据message_id从message_id_list中查找对应的原始消息
|
||||||
@@ -245,7 +240,7 @@ class ActionPlanner:
|
|||||||
async def plan(
|
async def plan(
|
||||||
self,
|
self,
|
||||||
mode: ChatMode = ChatMode.FOCUS,
|
mode: ChatMode = ChatMode.FOCUS,
|
||||||
loop_start_time:float = 0.0,
|
loop_start_time: float = 0.0,
|
||||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
@@ -323,11 +318,15 @@ class ActionPlanner:
|
|||||||
# 如果获取的target_message为None,输出warning并重新plan
|
# 如果获取的target_message为None,输出warning并重新plan
|
||||||
if target_message is None:
|
if target_message is None:
|
||||||
self.plan_retry_count += 1
|
self.plan_retry_count += 1
|
||||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
|
logger.warning(
|
||||||
|
f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}"
|
||||||
|
)
|
||||||
|
|
||||||
# 如果连续三次plan均为None,输出error并选取最新消息
|
# 如果连续三次plan均为None,输出error并选取最新消息
|
||||||
if self.plan_retry_count >= self.max_plan_retries:
|
if self.plan_retry_count >= self.max_plan_retries:
|
||||||
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message")
|
logger.error(
|
||||||
|
f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message"
|
||||||
|
)
|
||||||
target_message = self.get_latest_message(message_id_list)
|
target_message = self.get_latest_message(message_id_list)
|
||||||
self.plan_retry_count = 0 # 重置计数器
|
self.plan_retry_count = 0 # 重置计数器
|
||||||
else:
|
else:
|
||||||
@@ -338,8 +337,7 @@ class ActionPlanner:
|
|||||||
self.plan_retry_count = 0
|
self.plan_retry_count = 0
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
|
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
|
||||||
|
|
||||||
|
|
||||||
if action != "no_reply" and action != "reply" and action not in current_available_actions:
|
if action != "no_reply" and action != "reply" and action not in current_available_actions:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'"
|
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_reply'"
|
||||||
@@ -362,36 +360,35 @@ class ActionPlanner:
|
|||||||
is_parallel = False
|
is_parallel = False
|
||||||
if mode == ChatMode.NORMAL and action in current_available_actions:
|
if mode == ChatMode.NORMAL and action in current_available_actions:
|
||||||
is_parallel = current_available_actions[action].parallel_action
|
is_parallel = current_available_actions[action].parallel_action
|
||||||
|
|
||||||
|
|
||||||
action_data["loop_start_time"] = loop_start_time
|
action_data["loop_start_time"] = loop_start_time
|
||||||
|
|
||||||
actions = []
|
actions = []
|
||||||
|
|
||||||
# 1. 添加Planner取得的动作
|
# 1. 添加Planner取得的动作
|
||||||
actions.append({
|
actions.append(
|
||||||
"action_type": action,
|
{
|
||||||
"reasoning": reasoning,
|
"action_type": action,
|
||||||
"action_data": action_data,
|
"reasoning": reasoning,
|
||||||
"action_message": target_message,
|
"action_data": action_data,
|
||||||
"available_actions": available_actions # 添加这个字段
|
|
||||||
})
|
|
||||||
|
|
||||||
if action != "reply" and is_parallel:
|
|
||||||
actions.append({
|
|
||||||
"action_type": "reply",
|
|
||||||
"action_message": target_message,
|
"action_message": target_message,
|
||||||
"available_actions": available_actions
|
"available_actions": available_actions, # 添加这个字段
|
||||||
})
|
}
|
||||||
|
)
|
||||||
return actions,target_message
|
|
||||||
|
if action != "reply" and is_parallel:
|
||||||
|
actions.append(
|
||||||
|
{"action_type": "reply", "action_message": target_message, "available_actions": available_actions}
|
||||||
|
)
|
||||||
|
|
||||||
|
return actions, target_message
|
||||||
|
|
||||||
async def build_planner_prompt(
|
async def build_planner_prompt(
|
||||||
self,
|
self,
|
||||||
is_group_chat: bool, # Now passed as argument
|
is_group_chat: bool, # Now passed as argument
|
||||||
chat_target_info: Optional[dict], # Now passed as argument
|
chat_target_info: Optional[dict], # Now passed as argument
|
||||||
current_available_actions: Dict[str, ActionInfo],
|
current_available_actions: Dict[str, ActionInfo],
|
||||||
refresh_time :bool = False,
|
refresh_time: bool = False,
|
||||||
mode: ChatMode = ChatMode.FOCUS,
|
mode: ChatMode = ChatMode.FOCUS,
|
||||||
) -> tuple[str, list]: # sourcery skip: use-join
|
) -> tuple[str, list]: # sourcery skip: use-join
|
||||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||||
@@ -400,21 +397,15 @@ class ActionPlanner:
|
|||||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
bot_name = global_config.bot.nickname
|
bot_name = global_config.bot.nickname
|
||||||
bot_nickname = (
|
bot_nickname = (
|
||||||
f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||||
if global_config.bot.alias_names
|
|
||||||
else ""
|
|
||||||
)
|
)
|
||||||
bot_core_personality = global_config.personality.personality_core
|
bot_core_personality = global_config.personality.personality_core
|
||||||
identity_block = (
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||||
f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
|
||||||
)
|
|
||||||
|
|
||||||
schedule_block = ""
|
schedule_block = ""
|
||||||
if global_config.schedule.enable:
|
if global_config.schedule.enable:
|
||||||
if current_activity := schedule_manager.get_current_activity():
|
if current_activity := schedule_manager.get_current_activity():
|
||||||
schedule_block = (
|
schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
|
||||||
f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
|
|
||||||
)
|
|
||||||
|
|
||||||
mood_block = ""
|
mood_block = ""
|
||||||
if global_config.mood.enable_mood:
|
if global_config.mood.enable_mood:
|
||||||
@@ -424,13 +415,9 @@ class ActionPlanner:
|
|||||||
# --- 根据模式构建不同的Prompt ---
|
# --- 根据模式构建不同的Prompt ---
|
||||||
if mode == ChatMode.PROACTIVE:
|
if mode == ChatMode.PROACTIVE:
|
||||||
long_term_memory_block = await self._get_long_term_memory_context()
|
long_term_memory_block = await self._get_long_term_memory_context()
|
||||||
action_options_text = await self._build_action_options(
|
action_options_text = await self._build_action_options(current_available_actions, mode)
|
||||||
current_available_actions, mode
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_template = await global_prompt_manager.get_prompt_async(
|
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||||
"proactive_planner_prompt"
|
|
||||||
)
|
|
||||||
prompt = prompt_template.format(
|
prompt = prompt_template.format(
|
||||||
time_block=time_block,
|
time_block=time_block,
|
||||||
identity_block=identity_block,
|
identity_block=identity_block,
|
||||||
@@ -463,12 +450,8 @@ class ActionPlanner:
|
|||||||
limit=5,
|
limit=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
actions_before_now_block = build_readable_actions(
|
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||||
actions=actions_before_now
|
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||||
)
|
|
||||||
actions_before_now_block = (
|
|
||||||
f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if refresh_time:
|
if refresh_time:
|
||||||
self.last_obs_time_mark = time.time()
|
self.last_obs_time_mark = time.time()
|
||||||
@@ -504,30 +487,22 @@ class ActionPlanner:
|
|||||||
}}"""
|
}}"""
|
||||||
|
|
||||||
chat_context_description = "你现在正在一个群聊中"
|
chat_context_description = "你现在正在一个群聊中"
|
||||||
chat_target_name = None
|
chat_target_name = None
|
||||||
if not is_group_chat and chat_target_info:
|
if not is_group_chat and chat_target_info:
|
||||||
chat_target_name = (
|
chat_target_name = (
|
||||||
chat_target_info.get("person_name")
|
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
|
||||||
or chat_target_info.get("user_nickname")
|
|
||||||
or "对方"
|
|
||||||
)
|
)
|
||||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||||
|
|
||||||
action_options_block = await self._build_action_options(
|
action_options_block = await self._build_action_options(current_available_actions, mode)
|
||||||
current_available_actions, mode
|
|
||||||
)
|
|
||||||
|
|
||||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||||
|
|
||||||
custom_prompt_block = ""
|
custom_prompt_block = ""
|
||||||
if global_config.custom_prompt.planner_custom_prompt_content:
|
if global_config.custom_prompt.planner_custom_prompt_content:
|
||||||
custom_prompt_block = (
|
custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content
|
||||||
global_config.custom_prompt.planner_custom_prompt_content
|
|
||||||
)
|
|
||||||
|
|
||||||
planner_prompt_template = await global_prompt_manager.get_prompt_async(
|
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||||
"planner_prompt"
|
|
||||||
)
|
|
||||||
prompt = planner_prompt_template.format(
|
prompt = planner_prompt_template.format(
|
||||||
schedule_block=schedule_block,
|
schedule_block=schedule_block,
|
||||||
mood_block=mood_block,
|
mood_block=mood_block,
|
||||||
@@ -555,9 +530,7 @@ class ActionPlanner:
|
|||||||
"""
|
"""
|
||||||
is_group_chat = True
|
is_group_chat = True
|
||||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||||
logger.debug(
|
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||||
f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}"
|
|
||||||
)
|
|
||||||
|
|
||||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||||
|
|
||||||
@@ -568,13 +541,9 @@ class ActionPlanner:
|
|||||||
current_available_actions = {}
|
current_available_actions = {}
|
||||||
for action_name in current_available_actions_dict:
|
for action_name in current_available_actions_dict:
|
||||||
if action_name in all_registered_actions:
|
if action_name in all_registered_actions:
|
||||||
current_available_actions[action_name] = all_registered_actions[
|
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||||
action_name
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||||
f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 将no_reply作为系统级特殊动作添加到可用动作中
|
# 将no_reply作为系统级特殊动作添加到可用动作中
|
||||||
# no_reply虽然是系统级决策,但需要让规划器认为它是可用的
|
# no_reply虽然是系统级决策,但需要让规划器认为它是可用的
|
||||||
|
|||||||
@@ -706,16 +706,16 @@ class DefaultReplyer:
|
|||||||
# 检查最新五条消息中是否包含bot自己说的消息
|
# 检查最新五条消息中是否包含bot自己说的消息
|
||||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
||||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
||||||
|
|
||||||
# logger.info(f"最新五条消息:{latest_5_messages}")
|
# logger.info(f"最新五条消息:{latest_5_messages}")
|
||||||
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
||||||
|
|
||||||
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
||||||
if not has_bot_message:
|
if not has_bot_message:
|
||||||
core_dialogue_prompt = ""
|
core_dialogue_prompt = ""
|
||||||
else:
|
else:
|
||||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
||||||
|
|
||||||
core_dialogue_prompt_str = build_readable_messages(
|
core_dialogue_prompt_str = build_readable_messages(
|
||||||
core_dialogue_list,
|
core_dialogue_list,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
@@ -819,7 +819,7 @@ class DefaultReplyer:
|
|||||||
mood_prompt = ""
|
mood_prompt = ""
|
||||||
|
|
||||||
if reply_to:
|
if reply_to:
|
||||||
#兼容旧的reply_to
|
# 兼容旧的reply_to
|
||||||
sender, target = self._parse_reply_target(reply_to)
|
sender, target = self._parse_reply_target(reply_to)
|
||||||
else:
|
else:
|
||||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||||
@@ -830,7 +830,7 @@ class DefaultReplyer:
|
|||||||
)
|
)
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
sender = person_name
|
sender = person_name
|
||||||
target = reply_message.get('processed_plain_text')
|
target = reply_message.get("processed_plain_text")
|
||||||
|
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||||
@@ -1024,7 +1024,7 @@ class DefaultReplyer:
|
|||||||
chat_stream = self.chat_stream
|
chat_stream = self.chat_stream
|
||||||
chat_id = chat_stream.stream_id
|
chat_id = chat_stream.stream_id
|
||||||
is_group_chat = bool(chat_stream.group_info)
|
is_group_chat = bool(chat_stream.group_info)
|
||||||
|
|
||||||
if reply_message:
|
if reply_message:
|
||||||
sender = reply_message.get("sender")
|
sender = reply_message.get("sender")
|
||||||
target = reply_message.get("target")
|
target = reply_message.get("target")
|
||||||
@@ -1181,7 +1181,9 @@ class DefaultReplyer:
|
|||||||
else:
|
else:
|
||||||
logger.debug(f"\n{prompt}\n")
|
logger.debug(f"\n{prompt}\n")
|
||||||
|
|
||||||
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(prompt)
|
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
|
||||||
|
prompt
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"replyer生成内容: {content}")
|
logger.debug(f"replyer生成内容: {content}")
|
||||||
return content, reasoning_content, model_name, tool_calls
|
return content, reasoning_content, model_name, tool_calls
|
||||||
|
|||||||
@@ -1250,7 +1250,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
|||||||
# 检查必要信息是否存在 且 不是机器人自己
|
# 检查必要信息是否存在 且 不是机器人自己
|
||||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 添加空值检查,防止 platform 为 None 时出错
|
# 添加空值检查,防止 platform 为 None 时出错
|
||||||
if platform is None:
|
if platform is None:
|
||||||
platform = "unknown"
|
platform = "unknown"
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from src.config.config import global_config
|
|||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
from src.plugin_system.apis import cross_context_api
|
from src.plugin_system.apis import cross_context_api
|
||||||
|
|
||||||
logger = get_logger("prompt_utils")
|
logger = get_logger("prompt_utils")
|
||||||
|
|
||||||
|
|
||||||
@@ -80,29 +81,29 @@ class PromptUtils:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def build_cross_context(
|
async def build_cross_context(
|
||||||
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
|
chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能
|
||||||
"""
|
"""
|
||||||
if not global_config.cross_context.enable:
|
if not global_config.cross_context.enable:
|
||||||
return ""
|
|
||||||
|
|
||||||
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
|
||||||
if not other_chat_raw_ids:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
|
||||||
if not chat_stream:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
if current_prompt_mode == "normal":
|
|
||||||
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
|
||||||
elif current_prompt_mode == "s4u":
|
|
||||||
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||||
|
if not other_chat_raw_ids:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||||
|
if not chat_stream:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if current_prompt_mode == "normal":
|
||||||
|
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
||||||
|
elif current_prompt_mode == "s4u":
|
||||||
|
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_reply_target_id(reply_to: str) -> str:
|
def parse_reply_target_id(reply_to: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ class SmartPromptBuilder:
|
|||||||
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
|
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
|
||||||
params.message_list_before_now_long,
|
params.message_list_before_now_long,
|
||||||
params.target_user_info.get("user_id") if params.target_user_info else "",
|
params.target_user_info.get("user_id") if params.target_user_info else "",
|
||||||
params.sender
|
params.sender,
|
||||||
)
|
)
|
||||||
|
|
||||||
context_data["core_dialogue_prompt"] = core_dialogue
|
context_data["core_dialogue_prompt"] = core_dialogue
|
||||||
@@ -245,16 +245,16 @@ class SmartPromptBuilder:
|
|||||||
# 检查最新五条消息中是否包含bot自己说的消息
|
# 检查最新五条消息中是否包含bot自己说的消息
|
||||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
||||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
||||||
|
|
||||||
# logger.info(f"最新五条消息:{latest_5_messages}")
|
# logger.info(f"最新五条消息:{latest_5_messages}")
|
||||||
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
||||||
|
|
||||||
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
||||||
if not has_bot_message:
|
if not has_bot_message:
|
||||||
core_dialogue_prompt = ""
|
core_dialogue_prompt = ""
|
||||||
else:
|
else:
|
||||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
||||||
|
|
||||||
core_dialogue_prompt_str = build_readable_messages(
|
core_dialogue_prompt_str = build_readable_messages(
|
||||||
core_dialogue_list,
|
core_dialogue_list,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
|
|||||||
@@ -27,16 +27,15 @@ logger = get_logger("chat_image")
|
|||||||
def is_image_message(message: Dict[str, Any]) -> bool:
|
def is_image_message(message: Dict[str, Any]) -> bool:
|
||||||
"""
|
"""
|
||||||
判断消息是否为图片消息
|
判断消息是否为图片消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 消息字典
|
message: 消息字典
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否为图片消息
|
bool: 是否为图片消息
|
||||||
"""
|
"""
|
||||||
return message.get("type") == "image" or (
|
return message.get("type") == "image" or (
|
||||||
isinstance(message.get("content"), dict) and
|
isinstance(message.get("content"), dict) and message["content"].get("type") == "image"
|
||||||
message["content"].get("type") == "image"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -596,7 +595,6 @@ class ImageManager:
|
|||||||
return "", "[图片]"
|
return "", "[图片]"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 创建全局单例
|
# 创建全局单例
|
||||||
image_manager = None
|
image_manager = None
|
||||||
|
|
||||||
|
|||||||
@@ -62,10 +62,12 @@ def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
|
|||||||
"""
|
"""
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
try:
|
try:
|
||||||
plans = session.query(MonthlyPlan).filter(
|
plans = (
|
||||||
MonthlyPlan.target_month == month,
|
session.query(MonthlyPlan)
|
||||||
MonthlyPlan.status == 'active'
|
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||||
).order_by(MonthlyPlan.created_at.desc()).all()
|
.order_by(MonthlyPlan.created_at.desc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
return plans
|
return plans
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}")
|
logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}")
|
||||||
|
|||||||
@@ -81,8 +81,8 @@ def get_key_comment(toml_table, key):
|
|||||||
return item.trivia.comment
|
return item.trivia.comment
|
||||||
if hasattr(toml_table, "keys"):
|
if hasattr(toml_table, "keys"):
|
||||||
for k in toml_table.keys():
|
for k in toml_table.keys():
|
||||||
if isinstance(k, KeyType) and k.key == key: # type: ignore
|
if isinstance(k, KeyType) and k.key == key: # type: ignore
|
||||||
return k.trivia.comment # type: ignore
|
return k.trivia.comment # type: ignore
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -259,7 +259,6 @@ class NormalChatConfig(ValidatedConfigBase):
|
|||||||
"""普通聊天配置类"""
|
"""普通聊天配置类"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionRule(ValidatedConfigBase):
|
class ExpressionRule(ValidatedConfigBase):
|
||||||
"""表达学习规则"""
|
"""表达学习规则"""
|
||||||
|
|
||||||
@@ -653,7 +652,8 @@ class ContextGroup(ValidatedConfigBase):
|
|||||||
|
|
||||||
name: str = Field(..., description="共享组的名称")
|
name: str = Field(..., description="共享组的名称")
|
||||||
chat_ids: List[List[str]] = Field(
|
chat_ids: List[List[str]] = Field(
|
||||||
..., description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]'
|
...,
|
||||||
|
description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
44
src/main.py
44
src/main.py
@@ -28,36 +28,57 @@ from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
|||||||
|
|
||||||
# 导入消息API和traceback模块
|
# 导入消息API和traceback模块
|
||||||
from src.common.message import get_global_api
|
from src.common.message import get_global_api
|
||||||
|
|
||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
|
|
||||||
if not global_config.memory.enable_memory:
|
if not global_config.memory.enable_memory:
|
||||||
import src.chat.memory_system.Hippocampus as hippocampus_module
|
import src.chat.memory_system.Hippocampus as hippocampus_module
|
||||||
|
|
||||||
class MockHippocampusManager:
|
class MockHippocampusManager:
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_hippocampus(self):
|
def get_hippocampus(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def build_memory(self):
|
async def build_memory(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def forget_memory(self, percentage: float = 0.005):
|
async def forget_memory(self, percentage: float = 0.005):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def consolidate_memory(self):
|
async def consolidate_memory(self):
|
||||||
pass
|
pass
|
||||||
async def get_memory_from_text(self, text: str, max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3, fast_retrieval: bool = False) -> list:
|
|
||||||
|
async def get_memory_from_text(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
max_memory_num: int = 3,
|
||||||
|
max_memory_length: int = 2,
|
||||||
|
max_depth: int = 3,
|
||||||
|
fast_retrieval: bool = False,
|
||||||
|
) -> list:
|
||||||
return []
|
return []
|
||||||
async def get_memory_from_topic(self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3) -> list:
|
|
||||||
|
async def get_memory_from_topic(
|
||||||
|
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||||
|
) -> list:
|
||||||
return []
|
return []
|
||||||
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str]]:
|
|
||||||
|
async def get_activate_from_text(
|
||||||
|
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||||
|
) -> tuple[float, list[str]]:
|
||||||
return 0.0, []
|
return 0.0, []
|
||||||
|
|
||||||
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_all_node_names(self) -> list:
|
def get_all_node_names(self) -> list:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
hippocampus_module.hippocampus_manager = MockHippocampusManager()
|
hippocampus_module.hippocampus_manager = MockHippocampusManager()
|
||||||
|
|
||||||
# 插件系统现在使用统一的插件加载器
|
# 插件系统现在使用统一的插件加载器
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -67,7 +88,7 @@ logger = get_logger("main")
|
|||||||
class MainSystem:
|
class MainSystem:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.hippocampus_manager = hippocampus_manager
|
self.hippocampus_manager = hippocampus_manager
|
||||||
|
|
||||||
self.individuality: Individuality = get_individuality()
|
self.individuality: Individuality = get_individuality()
|
||||||
|
|
||||||
# 使用消息API替代直接的FastAPI实例
|
# 使用消息API替代直接的FastAPI实例
|
||||||
@@ -207,7 +228,6 @@ MoFox_Bot(第三方修改版)
|
|||||||
get_emoji_manager().initialize()
|
get_emoji_manager().initialize()
|
||||||
logger.info("表情包管理器初始化成功")
|
logger.info("表情包管理器初始化成功")
|
||||||
|
|
||||||
|
|
||||||
# 启动情绪管理器
|
# 启动情绪管理器
|
||||||
await mood_manager.start()
|
await mood_manager.start()
|
||||||
logger.info("情绪管理器初始化成功")
|
logger.info("情绪管理器初始化成功")
|
||||||
@@ -222,11 +242,11 @@ MoFox_Bot(第三方修改版)
|
|||||||
# 初始化记忆系统
|
# 初始化记忆系统
|
||||||
self.hippocampus_manager.initialize()
|
self.hippocampus_manager.initialize()
|
||||||
logger.info("记忆系统初始化成功")
|
logger.info("记忆系统初始化成功")
|
||||||
|
|
||||||
# 初始化异步记忆管理器
|
# 初始化异步记忆管理器
|
||||||
try:
|
try:
|
||||||
from src.chat.memory_system.async_memory_optimizer import async_memory_manager
|
from src.chat.memory_system.async_memory_optimizer import async_memory_manager
|
||||||
|
|
||||||
await async_memory_manager.initialize()
|
await async_memory_manager.initialize()
|
||||||
logger.info("记忆管理器初始化成功")
|
logger.info("记忆管理器初始化成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -36,27 +36,19 @@ def get_context_groups(chat_id: str) -> Optional[List[List[str]]]:
|
|||||||
# 检查当前聊天的ID和类型是否在组的chat_ids中
|
# 检查当前聊天的ID和类型是否在组的chat_ids中
|
||||||
if [current_type, str(current_chat_raw_id)] in group.chat_ids:
|
if [current_type, str(current_chat_raw_id)] in group.chat_ids:
|
||||||
# 返回组内其他聊天的 [type, id] 列表
|
# 返回组内其他聊天的 [type, id] 列表
|
||||||
return [
|
return [chat_info for chat_info in group.chat_ids if chat_info != [current_type, str(current_chat_raw_id)]]
|
||||||
chat_info
|
|
||||||
for chat_info in group.chat_ids
|
|
||||||
if chat_info != [current_type, str(current_chat_raw_id)]
|
|
||||||
]
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def build_cross_context_normal(
|
async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: List[List[str]]) -> str:
|
||||||
chat_stream: ChatStream, other_chat_infos: List[List[str]]
|
|
||||||
) -> str:
|
|
||||||
"""
|
"""
|
||||||
构建跨群聊/私聊上下文 (Normal模式)
|
构建跨群聊/私聊上下文 (Normal模式)
|
||||||
"""
|
"""
|
||||||
cross_context_messages = []
|
cross_context_messages = []
|
||||||
for chat_type, chat_raw_id in other_chat_infos:
|
for chat_type, chat_raw_id in other_chat_infos:
|
||||||
is_group = chat_type == "group"
|
is_group = chat_type == "group"
|
||||||
stream_id = get_chat_manager().get_stream_id(
|
stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
||||||
chat_stream.platform, chat_raw_id, is_group=is_group
|
|
||||||
)
|
|
||||||
if not stream_id:
|
if not stream_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -68,9 +60,7 @@ async def build_cross_context_normal(
|
|||||||
)
|
)
|
||||||
if messages:
|
if messages:
|
||||||
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
||||||
formatted_messages, _ = build_readable_messages_with_id(
|
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||||
messages, timestamp_mode="relative"
|
|
||||||
)
|
|
||||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
||||||
@@ -97,9 +87,7 @@ async def build_cross_context_s4u(
|
|||||||
if user_id:
|
if user_id:
|
||||||
for chat_type, chat_raw_id in other_chat_infos:
|
for chat_type, chat_raw_id in other_chat_infos:
|
||||||
is_group = chat_type == "group"
|
is_group = chat_type == "group"
|
||||||
stream_id = get_chat_manager().get_stream_id(
|
stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
||||||
chat_stream.platform, chat_raw_id, is_group=is_group
|
|
||||||
)
|
|
||||||
if not stream_id:
|
if not stream_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -114,9 +102,7 @@ async def build_cross_context_s4u(
|
|||||||
if user_messages:
|
if user_messages:
|
||||||
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
||||||
user_name = (
|
user_name = (
|
||||||
target_user_info.get("person_name")
|
target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||||
or target_user_info.get("user_nickname")
|
|
||||||
or user_id
|
|
||||||
)
|
)
|
||||||
formatted_messages, _ = build_readable_messages_with_id(
|
formatted_messages, _ = build_readable_messages_with_id(
|
||||||
user_messages, timestamp_mode="relative"
|
user_messages, timestamp_mode="relative"
|
||||||
@@ -182,9 +168,7 @@ async def get_chat_history_by_group_name(group_name: str) -> str:
|
|||||||
)
|
)
|
||||||
if messages:
|
if messages:
|
||||||
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
||||||
formatted_messages, _ = build_readable_messages_with_id(
|
formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||||
messages, timestamp_mode="relative"
|
|
||||||
)
|
|
||||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
||||||
@@ -193,4 +177,4 @@ async def get_chat_history_by_group_name(group_name: str) -> str:
|
|||||||
if not cross_context_messages:
|
if not cross_context_messages:
|
||||||
return f"无法从互通组 {group_name} 中获取任何聊天记录。"
|
return f"无法从互通组 {group_name} 中获取任何聊天记录。"
|
||||||
|
|
||||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||||
|
|||||||
@@ -107,9 +107,7 @@ async def generate_reply(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取回复器
|
# 获取回复器
|
||||||
replyer = get_replyer(
|
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||||
chat_stream, chat_id, request_type=request_type
|
|
||||||
)
|
|
||||||
if not replyer:
|
if not replyer:
|
||||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||||
return False, [], None
|
return False, [], None
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ logger = get_logger("send_api")
|
|||||||
# 适配器命令响应等待池
|
# 适配器命令响应等待池
|
||||||
_adapter_response_pool: Dict[str, asyncio.Future] = {}
|
_adapter_response_pool: Dict[str, asyncio.Future] = {}
|
||||||
|
|
||||||
|
|
||||||
def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]:
|
def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]:
|
||||||
"""查找要回复的消息
|
"""查找要回复的消息
|
||||||
|
|
||||||
@@ -97,10 +98,11 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
|
|||||||
}
|
}
|
||||||
|
|
||||||
message_recv = MessageRecv(message_dict)
|
message_recv = MessageRecv(message_dict)
|
||||||
|
|
||||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
||||||
return message_recv
|
return message_recv
|
||||||
|
|
||||||
|
|
||||||
def put_adapter_response(request_id: str, response_data: dict) -> None:
|
def put_adapter_response(request_id: str, response_data: dict) -> None:
|
||||||
"""将适配器响应放入响应池"""
|
"""将适配器响应放入响应池"""
|
||||||
if request_id in _adapter_response_pool:
|
if request_id in _adapter_response_pool:
|
||||||
@@ -192,7 +194,7 @@ async def _send_to_target(
|
|||||||
anchor_message.update_chat_stream(target_stream)
|
anchor_message.update_chat_stream(target_stream)
|
||||||
reply_to_platform_id = (
|
reply_to_platform_id = (
|
||||||
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
anchor_message = None
|
anchor_message = None
|
||||||
reply_to_platform_id = None
|
reply_to_platform_id = None
|
||||||
@@ -234,7 +236,6 @@ async def _send_to_target(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 公共API函数 - 预定义类型的发送函数
|
# 公共API函数 - 预定义类型的发送函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -274,7 +275,9 @@ async def text_to_stream(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False) -> bool:
|
async def emoji_to_stream(
|
||||||
|
emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
|
||||||
|
) -> bool:
|
||||||
"""向指定流发送表情包
|
"""向指定流发送表情包
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -285,10 +288,14 @@ async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bo
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply)
|
return await _send_to_target(
|
||||||
|
"emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False) -> bool:
|
async def image_to_stream(
|
||||||
|
image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
|
||||||
|
) -> bool:
|
||||||
"""向指定流发送图片
|
"""向指定流发送图片
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -299,11 +306,17 @@ async def image_to_stream(image_base64: str, stream_id: str, storage_message: bo
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply)
|
return await _send_to_target(
|
||||||
|
"image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def command_to_stream(
|
async def command_to_stream(
|
||||||
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "", set_reply: bool = False
|
command: Union[str, dict],
|
||||||
|
stream_id: str,
|
||||||
|
storage_message: bool = True,
|
||||||
|
display_message: str = "",
|
||||||
|
set_reply: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送命令
|
"""向指定流发送命令
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class EventManager:
|
|||||||
event = BaseEvent(event_name, allowed_subscribers, allowed_triggers)
|
event = BaseEvent(event_name, allowed_subscribers, allowed_triggers)
|
||||||
self._events[event_name] = event
|
self._events[event_name] = event
|
||||||
logger.debug(f"事件 {event_name} 注册成功")
|
logger.debug(f"事件 {event_name} 注册成功")
|
||||||
|
|
||||||
# 检查是否有缓存的订阅需要处理
|
# 检查是否有缓存的订阅需要处理
|
||||||
self._process_pending_subscriptions(event_name)
|
self._process_pending_subscriptions(event_name)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,9 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
|||||||
"enable_reply": ConfigField(type=bool, default=True, description="完成后是否回复"),
|
"enable_reply": ConfigField(type=bool, default=True, description="完成后是否回复"),
|
||||||
"ai_image_number": ConfigField(type=int, default=1, description="AI生成图片数量"),
|
"ai_image_number": ConfigField(type=int, default=1, description="AI生成图片数量"),
|
||||||
"image_number": ConfigField(type=int, default=1, description="本地配图数量(1-9张)"),
|
"image_number": ConfigField(type=int, default=1, description="本地配图数量(1-9张)"),
|
||||||
"image_directory": ConfigField(type=str, default=(Path(__file__).parent / "images").as_posix(), description="图片存储目录")
|
"image_directory": ConfigField(
|
||||||
|
type=str, default=(Path(__file__).parent / "images").as_posix(), description="图片存储目录"
|
||||||
|
),
|
||||||
},
|
},
|
||||||
"read": {
|
"read": {
|
||||||
"permission": ConfigField(type=list, default=[], description="阅读权限QQ号列表"),
|
"permission": ConfigField(type=list, default=[], description="阅读权限QQ号列表"),
|
||||||
@@ -75,7 +77,9 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
|||||||
"forbidden_hours_end": ConfigField(type=int, default=6, description="禁止发送的结束小时(24小时制)"),
|
"forbidden_hours_end": ConfigField(type=int, default=6, description="禁止发送的结束小时(24小时制)"),
|
||||||
},
|
},
|
||||||
"cookie": {
|
"cookie": {
|
||||||
"http_fallback_host": ConfigField(type=str, default="127.0.0.1", description="备用Cookie获取服务的主机地址"),
|
"http_fallback_host": ConfigField(
|
||||||
|
type=str, default="127.0.0.1", description="备用Cookie获取服务的主机地址"
|
||||||
|
),
|
||||||
"http_fallback_port": ConfigField(type=int, default=9999, description="备用Cookie获取服务的端口"),
|
"http_fallback_port": ConfigField(type=int, default=9999, description="备用Cookie获取服务的端口"),
|
||||||
"napcat_token": ConfigField(type=str, default="", description="Napcat服务的认证Token(可选)"),
|
"napcat_token": ConfigField(type=str, default="", description="Napcat服务的认证Token(可选)"),
|
||||||
},
|
},
|
||||||
@@ -95,14 +99,14 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
|||||||
image_service = ImageService(self.get_config)
|
image_service = ImageService(self.get_config)
|
||||||
cookie_service = CookieService(self.get_config)
|
cookie_service = CookieService(self.get_config)
|
||||||
reply_tracker_service = ReplyTrackerService()
|
reply_tracker_service = ReplyTrackerService()
|
||||||
|
|
||||||
# 使用已创建的 reply_tracker_service 实例
|
# 使用已创建的 reply_tracker_service 实例
|
||||||
qzone_service = QZoneService(
|
qzone_service = QZoneService(
|
||||||
self.get_config,
|
self.get_config,
|
||||||
content_service,
|
content_service,
|
||||||
image_service,
|
image_service,
|
||||||
cookie_service,
|
cookie_service,
|
||||||
reply_tracker_service # 传入已创建的实例
|
reply_tracker_service, # 传入已创建的实例
|
||||||
)
|
)
|
||||||
scheduler_service = SchedulerService(self.get_config, qzone_service)
|
scheduler_service = SchedulerService(self.get_config, qzone_service)
|
||||||
monitor_service = MonitorService(self.get_config, qzone_service)
|
monitor_service = MonitorService(self.get_config, qzone_service)
|
||||||
|
|||||||
@@ -272,8 +272,10 @@ class QZoneService:
|
|||||||
# 检查是否已经在持久化记录中标记为已回复
|
# 检查是否已经在持久化记录中标记为已回复
|
||||||
if not self.reply_tracker.has_replied(fid, comment_tid):
|
if not self.reply_tracker.has_replied(fid, comment_tid):
|
||||||
# 记录日志以便追踪
|
# 记录日志以便追踪
|
||||||
logger.debug(f"发现新评论需要回复 - 说说ID: {fid}, 评论ID: {comment_tid}, "
|
logger.debug(
|
||||||
f"评论人: {comment.get('nickname', '')}, 内容: {comment.get('content', '')}")
|
f"发现新评论需要回复 - 说说ID: {fid}, 评论ID: {comment_tid}, "
|
||||||
|
f"评论人: {comment.get('nickname', '')}, 内容: {comment.get('content', '')}"
|
||||||
|
)
|
||||||
comments_to_reply.append(comment)
|
comments_to_reply.append(comment)
|
||||||
|
|
||||||
if not comments_to_reply:
|
if not comments_to_reply:
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class ReplyTrackerService:
|
|||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
logger.error("加载的数据不是字典格式")
|
logger.error("加载的数据不是字典格式")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for feed_id, comments in data.items():
|
for feed_id, comments in data.items():
|
||||||
if not isinstance(feed_id, str):
|
if not isinstance(feed_id, str):
|
||||||
logger.error(f"无效的说说ID格式: {feed_id}")
|
logger.error(f"无效的说说ID格式: {feed_id}")
|
||||||
@@ -70,12 +70,14 @@ class ReplyTrackerService:
|
|||||||
logger.warning("回复记录文件为空,将创建新的记录")
|
logger.warning("回复记录文件为空,将创建新的记录")
|
||||||
self.replied_comments = {}
|
self.replied_comments = {}
|
||||||
return
|
return
|
||||||
|
|
||||||
data = json.loads(file_content)
|
data = json.loads(file_content)
|
||||||
if self._validate_data(data):
|
if self._validate_data(data):
|
||||||
self.replied_comments = data
|
self.replied_comments = data
|
||||||
logger.info(f"已加载 {len(self.replied_comments)} 条说说的回复记录,"
|
logger.info(
|
||||||
f"总计 {sum(len(comments) for comments in self.replied_comments.values())} 条评论")
|
f"已加载 {len(self.replied_comments)} 条说说的回复记录,"
|
||||||
|
f"总计 {sum(len(comments) for comments in self.replied_comments.values())} 条评论"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error("加载的数据格式无效,将创建新的记录")
|
logger.error("加载的数据格式无效,将创建新的记录")
|
||||||
self.replied_comments = {}
|
self.replied_comments = {}
|
||||||
@@ -112,12 +114,12 @@ class ReplyTrackerService:
|
|||||||
self._cleanup_old_records()
|
self._cleanup_old_records()
|
||||||
|
|
||||||
# 创建临时文件
|
# 创建临时文件
|
||||||
temp_file = self.reply_record_file.with_suffix('.tmp')
|
temp_file = self.reply_record_file.with_suffix(".tmp")
|
||||||
|
|
||||||
# 先写入临时文件
|
# 先写入临时文件
|
||||||
with open(temp_file, "w", encoding="utf-8") as f:
|
with open(temp_file, "w", encoding="utf-8") as f:
|
||||||
json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
|
json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
# 如果写入成功,重命名为正式文件
|
# 如果写入成功,重命名为正式文件
|
||||||
if temp_file.stat().st_size > 0: # 确保写入成功
|
if temp_file.stat().st_size > 0: # 确保写入成功
|
||||||
# 在Windows上,如果目标文件已存在,需要先删除它
|
# 在Windows上,如果目标文件已存在,需要先删除它
|
||||||
@@ -128,7 +130,7 @@ class ReplyTrackerService:
|
|||||||
else:
|
else:
|
||||||
logger.error("临时文件写入失败,文件大小为0")
|
logger.error("临时文件写入失败,文件大小为0")
|
||||||
temp_file.unlink() # 删除空的临时文件
|
temp_file.unlink() # 删除空的临时文件
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存回复记录失败: {e}", exc_info=True)
|
logger.error(f"保存回复记录失败: {e}", exc_info=True)
|
||||||
# 尝试删除可能存在的临时文件
|
# 尝试删除可能存在的临时文件
|
||||||
@@ -204,7 +206,7 @@ class ReplyTrackerService:
|
|||||||
|
|
||||||
# 确保将comment_id转换为字符串格式
|
# 确保将comment_id转换为字符串格式
|
||||||
comment_id_str = str(comment_id)
|
comment_id_str = str(comment_id)
|
||||||
|
|
||||||
if feed_id not in self.replied_comments:
|
if feed_id not in self.replied_comments:
|
||||||
self.replied_comments[feed_id] = {}
|
self.replied_comments[feed_id] = {}
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ class MonthlyPlanManager:
|
|||||||
if len(plans) > max_plans:
|
if len(plans) > max_plans:
|
||||||
logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。")
|
logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。")
|
||||||
# 数据库查询结果已按创建时间降序排序(新的在前),直接截取超出上限的部分进行删除
|
# 数据库查询结果已按创建时间降序排序(新的在前),直接截取超出上限的部分进行删除
|
||||||
plans_to_delete = plans[:len(plans)-max_plans]
|
plans_to_delete = plans[: len(plans) - max_plans]
|
||||||
delete_ids = [p.id for p in plans_to_delete]
|
delete_ids = [p.id for p in plans_to_delete]
|
||||||
delete_plans_by_ids(delete_ids)
|
delete_plans_by_ids(delete_ids)
|
||||||
# 重新获取计划列表
|
# 重新获取计划列表
|
||||||
@@ -101,7 +101,7 @@ class MonthlyPlanManager:
|
|||||||
async def _generate_monthly_plans_logic(self, target_month: Optional[str] = None) -> bool:
|
async def _generate_monthly_plans_logic(self, target_month: Optional[str] = None) -> bool:
|
||||||
"""
|
"""
|
||||||
生成指定月份的月度计划的核心逻辑
|
生成指定月份的月度计划的核心逻辑
|
||||||
|
|
||||||
:param target_month: 目标月份,格式为 "YYYY-MM"。如果为 None,则为当前月份。
|
:param target_month: 目标月份,格式为 "YYYY-MM"。如果为 None,则为当前月份。
|
||||||
:return: 是否生成成功
|
:return: 是否生成成功
|
||||||
"""
|
"""
|
||||||
@@ -291,6 +291,8 @@ class MonthlyPlanManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}")
|
logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}")
|
||||||
|
|
||||||
|
|
||||||
class MonthlyPlanGenerationTask(AsyncTask):
|
class MonthlyPlanGenerationTask(AsyncTask):
|
||||||
"""每月初自动生成新月度计划的任务"""
|
"""每月初自动生成新月度计划的任务"""
|
||||||
|
|
||||||
@@ -327,7 +329,7 @@ class MonthlyPlanGenerationTask(AsyncTask):
|
|||||||
current_month = next_month.strftime("%Y-%m")
|
current_month = next_month.strftime("%Y-%m")
|
||||||
logger.info(f" 到达月初,开始生成 {current_month} 的月度计划...")
|
logger.info(f" 到达月初,开始生成 {current_month} 的月度计划...")
|
||||||
await self.monthly_plan_manager._generate_monthly_plans_logic(current_month)
|
await self.monthly_plan_manager._generate_monthly_plans_logic(current_month)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(" 每月月度计划生成任务被取消。")
|
logger.info(" 每月月度计划生成任务被取消。")
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -165,14 +165,16 @@ class ScheduleManager:
|
|||||||
schedule_str = f"已成功加载今天的日程 ({today_str}):\n"
|
schedule_str = f"已成功加载今天的日程 ({today_str}):\n"
|
||||||
if self.today_schedule:
|
if self.today_schedule:
|
||||||
for item in self.today_schedule:
|
for item in self.today_schedule:
|
||||||
schedule_str += f" - {item.get('time_range', '未知时间')}: {item.get('activity', '未知活动')}\n"
|
schedule_str += (
|
||||||
|
f" - {item.get('time_range', '未知时间')}: {item.get('activity', '未知活动')}\n"
|
||||||
|
)
|
||||||
logger.info(schedule_str)
|
logger.info(schedule_str)
|
||||||
return # 成功加载,直接返回
|
return # 成功加载,直接返回
|
||||||
else:
|
else:
|
||||||
logger.warning("数据库中的日程数据格式无效,将重新生成日程")
|
logger.warning("数据库中的日程数据格式无效,将重新生成日程")
|
||||||
else:
|
else:
|
||||||
logger.info(f"数据库中未找到今天的日程 ({today_str}),将调用 LLM 生成。")
|
logger.info(f"数据库中未找到今天的日程 ({today_str}),将调用 LLM 生成。")
|
||||||
|
|
||||||
# 仅在需要时生成
|
# 仅在需要时生成
|
||||||
await self.generate_and_save_schedule()
|
await self.generate_and_save_schedule()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user