ruff

commit 950b086063 (parent e65ab14f94), committed by Windpicker-owo

bot.py: 39 changed lines
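The hunks below appear to be mechanical reformatting by ruff: over-long single-line calls and f-strings are wrapped into parenthesized, one-argument-per-line form with trailing commas, short multi-line call/dict/list literals are collapsed onto one line, and blank lines are normalized after class docstrings, around top-level definitions, and after function-local imports. A minimal sketch of the resulting style (the function and names here are invented for illustration, not taken from the repository):

```python
def summarize_scores(scores: list[float] | None = None) -> dict[str, float]:
    """Toy example of the formatting conventions applied throughout this commit."""
    scores = scores or []

    # A comprehension that fits the line limit is kept on a single line.
    positive = [s for s in scores if s > 0]

    # A call that would exceed the limit is wrapped, one argument per line,
    # with a trailing comma after the last argument.
    return dict(
        count=float(len(positive)),
        total=float(sum(positive)),
    )
```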
bot.py

@@ -33,6 +33,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
 os.chdir(script_dir)
 logger.info("工作目录已设置")
 
+
 class ConfigManager:
 """配置管理器"""
 
@@ -96,6 +97,7 @@ class ConfigManager:
 logger.error(f"加载环境变量失败: {e}")
 return False
 
+
 class EULAManager:
 """EULA管理类"""
 
@@ -134,7 +136,9 @@ class EULAManager:
 return
 
 if attempts % 5 == 0:
-confirm_logger.critical(f"请修改 .env 文件中的 EULA_CONFIRMED=true (尝试 {attempts}/{MAX_EULA_CHECK_ATTEMPTS})")
+confirm_logger.critical(
+f"请修改 .env 文件中的 EULA_CONFIRMED=true (尝试 {attempts}/{MAX_EULA_CHECK_ATTEMPTS})"
+)
 
 except KeyboardInterrupt:
 confirm_logger.info("用户取消,程序退出")
@@ -148,16 +152,14 @@ class EULAManager:
 confirm_logger.error("EULA确认超时,程序退出")
 sys.exit(1)
 
 
 class TaskManager:
 """任务管理器"""
 
 @staticmethod
 async def cancel_pending_tasks(loop, timeout=SHUTDOWN_TIMEOUT):
 """取消所有待处理的任务"""
-remaining_tasks = [
-t for t in asyncio.all_tasks(loop)
-if t is not asyncio.current_task(loop) and not t.done()
-]
+remaining_tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task(loop) and not t.done()]
 
 if not remaining_tasks:
 logger.info("没有待取消的任务")
@@ -171,10 +173,7 @@ class TaskManager:
 
 # 等待任务完成
 try:
-results = await asyncio.wait_for(
-asyncio.gather(*remaining_tasks, return_exceptions=True),
-timeout=timeout
-)
+results = await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=timeout)
 
 # 检查任务结果
 for i, result in enumerate(results):
@@ -195,6 +194,7 @@ class TaskManager:
 """停止所有异步任务"""
 try:
 from src.manager.async_task_manager import async_task_manager
+
 await async_task_manager.stop_and_wait_all_tasks()
 return True
 except ImportError:
@@ -204,6 +204,7 @@ class TaskManager:
 logger.error(f"停止异步任务失败: {e}")
 return False
 
+
 class ShutdownManager:
 """关闭管理器"""
 
@@ -236,6 +237,7 @@ class ShutdownManager:
 logger.error(f"麦麦关闭失败: {e}", exc_info=True)
 return False
 
+
 @asynccontextmanager
 async def create_event_loop_context():
 """创建事件循环的上下文管理器"""
@@ -260,6 +262,7 @@ async def create_event_loop_context():
 except Exception as e:
 logger.error(f"关闭事件循环失败: {e}")
 
+
 class DatabaseManager:
 """数据库连接管理器"""
 
@@ -278,7 +281,9 @@ class DatabaseManager:
 # 使用线程执行器运行潜在的阻塞操作
 await asyncio.to_thread(initialize_sql_database, global_config.database)
 elapsed_time = time.time() - start_time
-logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库,耗时: {elapsed_time:.2f}秒")
+logger.info(
+f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库,耗时: {elapsed_time:.2f}秒"
+)
 
 return self
 except Exception as e:
@@ -291,6 +296,7 @@ class DatabaseManager:
 logger.error(f"数据库操作发生异常: {exc_val}")
 return False
 
+
 class ConfigurationValidator:
 """配置验证器"""
 
@@ -328,6 +334,7 @@ class ConfigurationValidator:
 logger.error(f"配置验证失败: {e}")
 return False
 
+
 class EasterEgg:
 """彩蛋功能"""
 
@@ -347,6 +354,7 @@ class EasterEgg:
 rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
 logger.info(rainbow_text)
 
+
 class MaiBotMain:
 """麦麦机器人主程序类"""
 
@@ -375,6 +383,7 @@ class MaiBotMain:
 try:
 start_time = time.time()
 from src.common.database.sqlalchemy_models import initialize_database as init_db
+
 await init_db()
 elapsed_time = time.time() - start_time
 logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}秒")
@@ -385,6 +394,7 @@ class MaiBotMain:
 def create_main_system(self):
 """创建MainSystem实例"""
 from src.main import MainSystem
+
 self.main_system = MainSystem()
 return self.main_system
 
@@ -411,11 +421,13 @@ class MaiBotMain:
 
 # 初始化知识库
 from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
+
 initialize_lpmm_knowledge()
 
 # 显示彩蛋
 EasterEgg.show()
 
+
 async def wait_for_user_input():
 """等待用户输入(异步方式)"""
 try:
@@ -432,6 +444,7 @@ async def wait_for_user_input():
 logger.error(f"等待用户输入时发生错误: {e}")
 return False
 
+
 async def main_async():
 """主异步函数"""
 exit_code = 0
@@ -455,10 +468,7 @@ async def main_async():
 user_input_done = asyncio.create_task(wait_for_user_input())
 
 # 使用wait等待任意一个任务完成
-done, pending = await asyncio.wait(
-[main_task, user_input_done],
-return_when=asyncio.FIRST_COMPLETED
-)
+done, pending = await asyncio.wait([main_task, user_input_done], return_when=asyncio.FIRST_COMPLETED)
 
 # 如果用户输入任务完成(用户按了Ctrl+C),取消主任务
 if user_input_done in done and main_task not in done:
@@ -482,6 +492,7 @@ async def main_async():
 
 return exit_code
 
+
 if __name__ == "__main__":
 exit_code = 0
 try:
@@ -12,10 +12,13 @@ logger = get_logger("HTTP消息API")
 
 router = APIRouter()
 
+
 @router.get("/messages/recent")
 async def get_message_stats(
 days: int = Query(1, ge=1, description="指定查询过去多少天的数据"),
-message_type: Literal["all", "sent", "received"] = Query("all", description="筛选消息类型: 'sent' (BOT发送的), 'received' (BOT接收的), or 'all' (全部)")
+message_type: Literal["all", "sent", "received"] = Query(
+"all", description="筛选消息类型: 'sent' (BOT发送的), 'received' (BOT接收的), or 'all' (全部)"
+),
 ):
 """
 获取BOT在指定天数内的消息统计数据。
@@ -45,7 +48,7 @@ async def get_message_stats(
 "message_type": message_type,
 "sent_count": sent_count,
 "received_count": received_count,
-"total_count": len(messages)
+"total_count": len(messages),
 }
 
 except Exception as e:
@@ -76,10 +79,7 @@ async def get_message_stats_by_chat(
 user_id = msg.get("user_id")
 
 if chat_id not in stats:
-stats[chat_id] = {
-"total_stats": {"total": 0},
-"user_stats": {}
-}
+stats[chat_id] = {"total_stats": {"total": 0}, "user_stats": {}}
 
 stats[chat_id]["total_stats"]["total"] += 1
 
@@ -116,10 +116,7 @@ async def get_message_stats_by_chat(
 for user_id, count in data["user_stats"].items():
 person_id = person_api.get_person_id("qq", user_id)
 nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
-formatted_data["user_stats"][user_id] = {
-"nickname": nickname,
-"count": count
-}
+formatted_data["user_stats"][user_id] = {"nickname": nickname, "count": count}
 
 formatted_stats[chat_id] = formatted_data
 return formatted_stats
@@ -129,6 +126,7 @@ async def get_message_stats_by_chat(
 except Exception as e:
 raise HTTPException(status_code=500, detail=str(e))
 
+
 @router.get("/messages/bot_stats_by_chat")
 async def get_bot_message_stats_by_chat(
 days: int = Query(1, ge=1, description="指定查询过去多少天的数据"),
@@ -165,10 +163,7 @@ async def get_bot_message_stats_by_chat(
 elif stream.user_info and stream.user_info.user_nickname:
 chat_name = stream.user_info.user_nickname
 
-formatted_stats[chat_id] = {
-"chat_name": chat_name,
-"count": count
-}
+formatted_stats[chat_id] = {"chat_name": chat_name, "count": count}
 return formatted_stats
 
 return stats
@@ -313,7 +313,9 @@ class EnergyManager:
 
 # 确保 score 是 float 类型
 if not isinstance(score, int | float):
-logger.warning(f"计算器 {calculator.__class__.__name__} 返回了非数值类型: {type(score)},跳过此组件")
+logger.warning(
+f"计算器 {calculator.__class__.__name__} 返回了非数值类型: {type(score)},跳过此组件"
+)
 continue
 
 component_scores[calculator.__class__.__name__] = float(score)
@@ -429,7 +429,9 @@ class BotInterestManager:
 except Exception as e:
 logger.error(f"❌ 计算相似度分数失败: {e}")
 
-async def calculate_interest_match(self, message_text: str, keywords: list[str] | None = None) -> InterestMatchResult:
+async def calculate_interest_match(
+self, message_text: str, keywords: list[str] | None = None
+) -> InterestMatchResult:
 """计算消息与机器人兴趣的匹配度"""
 if not self.current_interests or not self._initialized:
 raise RuntimeError("❌ 兴趣标签系统未初始化")
@@ -79,7 +79,9 @@ class InterestManager:
 
 # 如果已有组件在运行,先清理并替换
 if self._current_calculator:
-logger.info(f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}")
+logger.info(
+f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}"
+)
 await self._current_calculator.cleanup()
 else:
 logger.info(f"注册新的兴趣值计算组件: {calculator.component_name}")
@@ -114,7 +116,7 @@ class InterestManager:
 success=False,
 message_id=getattr(message, "message_id", ""),
 interest_value=0.3,
-error_message="没有可用的兴趣值计算组件"
+error_message="没有可用的兴趣值计算组件",
 )
 
 # 使用 create_task 异步执行计算
@@ -133,7 +135,7 @@ class InterestManager:
 interest_value=0.5, # 固定默认兴趣值
 should_reply=False,
 should_act=False,
-error_message=f"计算超时({timeout}s),使用默认值"
+error_message=f"计算超时({timeout}s),使用默认值",
 )
 except Exception as e:
 # 发生异常,返回默认结果
@@ -142,7 +144,7 @@ class InterestManager:
 success=False,
 message_id=getattr(message, "message_id", ""),
 interest_value=0.3,
-error_message=f"计算异常: {e!s}"
+error_message=f"计算异常: {e!s}",
 )
 
 async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
@@ -171,7 +173,7 @@ class InterestManager:
 message_id=getattr(message, "message_id", ""),
 interest_value=0.0,
 error_message=f"计算异常: {e!s}",
-calculation_time=time.time() - start_time
+calculation_time=time.time() - start_time,
 )
 
 async def _calculation_worker(self):
@@ -179,10 +181,7 @@ class InterestManager:
 while not self._shutdown_event.is_set():
 try:
 # 等待计算任务或关闭信号
-await asyncio.wait_for(
-self._calculation_queue.get(),
-timeout=1.0
-)
+await asyncio.wait_for(self._calculation_queue.get(), timeout=1.0)
 
 # 处理计算任务
 # 这里可以实现批量处理逻辑
@@ -210,7 +209,7 @@ class InterestManager:
 "failed_calculations": self._failed_calculations,
 "success_rate": success_rate,
 "last_calculation_time": self._last_calculation_time,
-"current_calculator": self._current_calculator.component_name if self._current_calculator else None
+"current_calculator": self._current_calculator.component_name if self._current_calculator else None,
 }
 }
 
@@ -29,6 +29,7 @@ logger = get_logger(__name__)
 @dataclass
 class HippocampusSampleConfig:
 """海马体采样配置"""
+
 # 双峰分布参数
 recent_mean_hours: float = 12.0 # 近期分布均值(小时)
 recent_std_hours: float = 8.0 # 近期分布标准差(小时)
@@ -84,12 +85,10 @@ class HippocampusSampler:
 try:
 # 初始化LLM模型
 from src.config.config import model_config
+
 task_config = getattr(model_config.model_task_config, "utils", None)
 if task_config:
-self.memory_builder_model = LLMRequest(
-model_set=task_config,
-request_type="memory.hippocampus_build"
-)
+self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
 asyncio.create_task(self.start_background_sampling())
 logger.info("✅ 海马体采样器初始化成功")
 else:
@@ -107,14 +106,10 @@ class HippocampusSampler:
 
 # 生成两个正态分布的小时偏移
 recent_offsets = np.random.normal(
-loc=self.config.recent_mean_hours,
-scale=self.config.recent_std_hours,
-size=recent_samples
+loc=self.config.recent_mean_hours, scale=self.config.recent_std_hours, size=recent_samples
 )
 distant_offsets = np.random.normal(
-loc=self.config.distant_mean_hours,
-scale=self.config.distant_std_hours,
-size=distant_samples
+loc=self.config.distant_mean_hours, scale=self.config.distant_std_hours, size=distant_samples
 )
 
 # 合并两个分布的偏移
@@ -122,10 +117,7 @@ class HippocampusSampler:
 
 # 转换为时间戳(使用绝对值确保时间点在过去)
 base_time = datetime.now()
-timestamps = [
-base_time - timedelta(hours=abs(offset))
-for offset in all_offsets
-]
+timestamps = [base_time - timedelta(hours=abs(offset)) for offset in all_offsets]
 
 # 按时间排序(从最早到最近)
 return sorted(timestamps)
@@ -171,7 +163,8 @@ class HippocampusSampler:
 if messages and len(messages) >= 2: # 至少需要2条消息
 # 过滤掉已经记忆过的消息
 filtered_messages = [
-msg for msg in messages
+msg
+for msg in messages
 if msg.get("memorized_times", 0) < 2 # 最多记忆2次
 ]
 
@@ -229,7 +222,7 @@ class HippocampusSampler:
 conversation_text=input_text,
 context=context,
 timestamp=time.time(),
-bypass_interval=True # 海马体采样器绕过构建间隔限制
+bypass_interval=True, # 海马体采样器绕过构建间隔限制
 )
 
 if memories:
@@ -367,7 +360,7 @@ class HippocampusSampler:
 max_concurrent = min(5, len(time_samples)) # 提高并发数到5
 
 for i in range(0, len(time_samples), max_concurrent):
-batch = time_samples[i:i + max_concurrent]
+batch = time_samples[i : i + max_concurrent]
 tasks = []
 
 # 创建并发收集任务
@@ -392,7 +385,9 @@ class HippocampusSampler:
 
 return collected_messages
 
-async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
+async def _fuse_and_deduplicate_messages(
+self, collected_messages: list[list[dict[str, Any]]]
+) -> list[list[dict[str, Any]]]:
 """融合和去重消息样本"""
 if not collected_messages:
 return []
@@ -416,7 +411,7 @@ class HippocampusSampler:
 chat_id = message.get("chat_id", "")
 
 # 简单哈希:内容前50字符 + 时间戳(精确到分钟) + 聊天ID
-hash_key = f"{content[:50]}_{int(timestamp//60)}_{chat_id}"
+hash_key = f"{content[:50]}_{int(timestamp // 60)}_{chat_id}"
 
 if hash_key not in seen_hashes and len(content.strip()) > 10:
 seen_hashes.add(hash_key)
@@ -448,7 +443,9 @@ class HippocampusSampler:
 # 返回原始消息组作为备选
 return collected_messages[:5] # 限制返回数量
 
-def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
+def _merge_adjacent_messages(
+self, messages: list[dict[str, Any]], time_gap: int = 1800
+) -> list[list[dict[str, Any]]]:
 """合并时间间隔内的消息"""
 if not messages:
 return []
@@ -479,7 +476,9 @@ class HippocampusSampler:
 
 return result_groups
 
-async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]:
+async def _build_batch_memory(
+self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]
+) -> dict[str, Any]:
 """批量构建记忆"""
 if not fused_messages:
 return {"memory_count": 0, "memories": []}
@@ -513,10 +512,7 @@ class HippocampusSampler:
 
 # 一次性构建记忆
 memories = await self.memory_system.build_memory_from_conversation(
-conversation_text=batch_input_text,
-context=batch_context,
-timestamp=time.time(),
-bypass_interval=True
+conversation_text=batch_input_text, context=batch_context, timestamp=time.time(), bypass_interval=True
 )
 
 if memories:
@@ -545,11 +541,7 @@ class HippocampusSampler:
 if len(self.last_sample_results) > 10:
 self.last_sample_results.pop(0)
 
-return {
-"memory_count": total_memory_count,
-"memories": total_memories,
-"result": result
-}
+return {"memory_count": total_memory_count, "memories": total_memories, "result": result}
 
 except Exception as e:
 logger.error(f"批量构建记忆失败: {e}")
@@ -601,11 +593,7 @@ class HippocampusSampler:
 except Exception as e:
 logger.debug(f"单独构建失败: {e}")
 
-return {
-"memory_count": total_count,
-"memories": total_memories,
-"fallback_mode": True
-}
+return {"memory_count": total_count, "memories": total_memories, "fallback_mode": True}
 
 async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
 """处理单个时间戳采样(保留作为备选方法)"""
@@ -696,7 +684,9 @@ class HippocampusSampler:
 "performance_metrics": {
 "avg_messages_per_sample": f"{recent_avg_messages:.1f}",
 "avg_memories_per_sample": f"{recent_avg_memory_count:.1f}",
-"fusion_efficiency": f"{(recent_avg_messages/max(recent_avg_memory_count, 1)):.1f}x" if recent_avg_messages > 0 else "N/A"
+"fusion_efficiency": f"{(recent_avg_messages / max(recent_avg_memory_count, 1)):.1f}x"
+if recent_avg_messages > 0
+else "N/A",
 },
 "config": {
 "sample_interval": self.config.sample_interval,
@@ -15,6 +15,7 @@
 
 返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。
 """
+
 from __future__ import annotations
 
 import time
@@ -24,9 +24,12 @@ from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
 # 记忆采样模式枚举
 class MemorySamplingMode(Enum):
 """记忆采样模式"""
+
 HIPPOCAMPUS = "hippocampus" # 海马体模式:定时任务采样
 IMMEDIATE = "immediate" # 即时模式:回复后立即采样
 ALL = "all" # 所有模式:同时使用海马体和即时采样
+
+
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
@@ -165,7 +168,6 @@ class MemorySystem:
 async def initialize(self):
 """异步初始化记忆系统"""
 try:
-
 # 初始化LLM模型
 fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
 
@@ -264,6 +266,7 @@ class MemorySystem:
 if global_config.memory.enable_hippocampus_sampling:
 try:
 from .hippocampus_sampler import initialize_hippocampus_sampler
+
 self.hippocampus_sampler = await initialize_hippocampus_sampler(self)
 logger.info("✅ 海马体采样器初始化成功")
 except Exception as e:
@@ -321,7 +324,11 @@ class MemorySystem:
 return []
 
 async def build_memory_from_conversation(
-self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None, bypass_interval: bool = False
+self,
+conversation_text: str,
+context: dict[str, Any],
+timestamp: float | None = None,
+bypass_interval: bool = False,
 ) -> list[MemoryChunk]:
 """从对话中构建记忆
 
@@ -560,7 +567,6 @@ class MemorySystem:
 sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision")
 current_mode = MemorySamplingMode(sampling_mode)
 
-
 context["__sampling_mode"] = current_mode.value
 logger.debug(f"使用记忆采样模式: {current_mode.value}")
 
@@ -17,6 +17,7 @@ logger = get_logger("adaptive_stream_manager")
 
 class StreamPriority(Enum):
 """流优先级"""
+
 LOW = 1
 NORMAL = 2
 HIGH = 3
@@ -26,6 +27,7 @@ class StreamPriority(Enum):
 @dataclass
 class SystemMetrics:
 """系统指标"""
+
 cpu_usage: float = 0.0
 memory_usage: float = 0.0
 active_coroutines: int = 0
@@ -36,6 +38,7 @@ class SystemMetrics:
 @dataclass
 class StreamMetrics:
 """流指标"""
+
 stream_id: str
 priority: StreamPriority
 message_rate: float = 0.0 # 消息速率(消息/分钟)
@@ -139,10 +142,7 @@ class AdaptiveStreamManager:
 logger.info("自适应流管理器已停止")
 
 async def acquire_stream_slot(
-self,
-stream_id: str,
-priority: StreamPriority = StreamPriority.NORMAL,
-force: bool = False
+self, stream_id: str, priority: StreamPriority = StreamPriority.NORMAL, force: bool = False
 ) -> bool:
 """
 获取流处理槽位
@@ -165,10 +165,7 @@ class AdaptiveStreamManager:
 
 # 更新流指标
 if stream_id not in self.stream_metrics:
-self.stream_metrics[stream_id] = StreamMetrics(
-stream_id=stream_id,
-priority=priority
-)
+self.stream_metrics[stream_id] = StreamMetrics(stream_id=stream_id, priority=priority)
 self.stream_metrics[stream_id].last_activity = current_time
 
 # 检查是否已经活跃
@@ -271,8 +268,10 @@ class AdaptiveStreamManager:
 
 # 如果最近有活跃且响应时间较长,可能需要强制分发
 current_time = time.time()
-if (current_time - metrics.last_activity < 300 and # 5分钟内有活动
-metrics.response_time > 5.0): # 响应时间超过5秒
+if (
+current_time - metrics.last_activity < 300 # 5分钟内有活动
+and metrics.response_time > 5.0
+): # 响应时间超过5秒
 return True
 
 return False
@@ -324,26 +323,20 @@ class AdaptiveStreamManager:
 memory_usage=memory_usage,
 active_coroutines=active_coroutines,
 event_loop_lag=event_loop_lag,
-timestamp=time.time()
+timestamp=time.time(),
 )
 
 self.system_metrics.append(metrics)
 
 # 保持指标窗口大小
 cutoff_time = time.time() - self.metrics_window
-self.system_metrics = [
-m for m in self.system_metrics
-if m.timestamp > cutoff_time
-]
+self.system_metrics = [m for m in self.system_metrics if m.timestamp > cutoff_time]
 
 # 更新统计信息
 self.stats["avg_concurrent_streams"] = (
 self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1
 )
-self.stats["peak_concurrent_streams"] = max(
-self.stats["peak_concurrent_streams"],
-len(self.active_streams)
-)
+self.stats["peak_concurrent_streams"] = max(self.stats["peak_concurrent_streams"], len(self.active_streams))
 
 except Exception as e:
 logger.error(f"收集系统指标失败: {e}")
@@ -445,14 +438,16 @@ class AdaptiveStreamManager:
 def get_stats(self) -> dict:
 """获取统计信息"""
 stats = self.stats.copy()
-stats.update({
+stats.update(
+{
 "current_limit": self.current_limit,
 "active_streams": len(self.active_streams),
 "pending_streams": len(self.pending_streams),
 "is_running": self.is_running,
 "system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
 "system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
-})
+}
+)
 
 # 计算接受率
 if stats["total_requests"] > 0:
@@ -20,6 +20,7 @@ logger = get_logger("batch_database_writer")
 @dataclass
 class StreamUpdatePayload:
 """流更新数据结构"""
+
 stream_id: str
 update_data: dict[str, Any]
 priority: int = 0 # 优先级,数字越大优先级越高
@@ -95,12 +96,7 @@ class BatchDatabaseWriter:
 
 logger.info("批量数据库写入器已停止")
 
-async def schedule_stream_update(
-self,
-stream_id: str,
-update_data: dict[str, Any],
-priority: int = 0
-) -> bool:
+async def schedule_stream_update(self, stream_id: str, update_data: dict[str, Any], priority: int = 0) -> bool:
 """
 调度流更新
 
@@ -119,11 +115,7 @@ class BatchDatabaseWriter:
 return True
 
 # 创建更新载荷
-payload = StreamUpdatePayload(
-stream_id=stream_id,
-update_data=update_data,
-priority=priority
-)
+payload = StreamUpdatePayload(stream_id=stream_id, update_data=update_data, priority=priority)
 
 # 非阻塞方式加入队列
 try:
@@ -178,10 +170,7 @@ class BatchDatabaseWriter:
 if remaining_time == 0:
 break
 
-payload = await asyncio.wait_for(
-self.write_queue.get(),
-timeout=remaining_time
-)
+payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time)
 batch.append(payload)
 
 except asyncio.TimeoutError:
@@ -203,7 +192,10 @@ class BatchDatabaseWriter:
 # 合并同一流ID的更新(保留最新的)
 merged_updates = {}
 for payload in batch:
-if payload.stream_id not in merged_updates or payload.timestamp > merged_updates[payload.stream_id].timestamp:
+if (
+payload.stream_id not in merged_updates
+or payload.timestamp > merged_updates[payload.stream_id].timestamp
+):
 merged_updates[payload.stream_id] = payload
 
 # 批量写入
@@ -211,9 +203,7 @@ class BatchDatabaseWriter:
 
 # 更新统计
 self.stats["batch_writes"] += 1
-self.stats["avg_batch_size"] = (
-self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1
-) # 滑动平均
+self.stats["avg_batch_size"] = self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1 # 滑动平均
 self.stats["last_flush_time"] = start_time
 
 logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")
@@ -238,31 +228,22 @@ class BatchDatabaseWriter:
 # 根据数据库类型选择不同的插入/更新策略
 if global_config.database.database_type == "sqlite":
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-stmt = sqlite_insert(ChatStreams).values(
-stream_id=stream_id, **update_data
-)
-stmt = stmt.on_conflict_do_update(
-index_elements=["stream_id"],
-set_=update_data
-)
+
+stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
 elif global_config.database.database_type == "mysql":
 from sqlalchemy.dialects.mysql import insert as mysql_insert
-stmt = mysql_insert(ChatStreams).values(
-stream_id=stream_id, **update_data
-)
+
+stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
 stmt = stmt.on_duplicate_key_update(
 **{key: value for key, value in update_data.items() if key != "stream_id"}
 )
 else:
 # 默认使用SQLite语法
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-stmt = sqlite_insert(ChatStreams).values(
-stream_id=stream_id, **update_data
-)
-stmt = stmt.on_conflict_do_update(
-index_elements=["stream_id"],
-set_=update_data
-)
+
+stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
 
 await session.execute(stmt)
 
@@ -273,30 +254,21 @@ class BatchDatabaseWriter:
 async with get_db_session() as session:
 if global_config.database.database_type == "sqlite":
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-stmt = sqlite_insert(ChatStreams).values(
-stream_id=stream_id, **update_data
-)
-stmt = stmt.on_conflict_do_update(
-index_elements=["stream_id"],
-set_=update_data
-)
+
+stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
 elif global_config.database.database_type == "mysql":
 from sqlalchemy.dialects.mysql import insert as mysql_insert
-stmt = mysql_insert(ChatStreams).values(
-stream_id=stream_id, **update_data
-)
+
+stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
 stmt = stmt.on_duplicate_key_update(
 **{key: value for key, value in update_data.items() if key != "stream_id"}
 )
 else:
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-stmt = sqlite_insert(ChatStreams).values(
-stream_id=stream_id, **update_data
-)
-stmt = stmt.on_conflict_do_update(
-index_elements=["stream_id"],
-set_=update_data
-)
+
+stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
 
 await session.execute(stmt)
 await session.commit()
@@ -273,8 +273,10 @@ class SingleStreamContextManager:
 message.should_reply = result.should_reply
 message.should_act = result.should_act
 
-logger.debug(f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
-f"should_reply: {result.should_reply}, should_act: {result.should_act}")
+logger.debug(
+f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
+f"should_reply: {result.should_reply}, should_act: {result.should_act}"
+)
 return result.interest_value
 else:
 logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
@@ -79,7 +79,7 @@ class StreamLoopManager:
 logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...")
 await asyncio.gather(
 *[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks],
-return_exceptions=True
+return_exceptions=True,
 )
 
 # 取消所有活跃的 chatter 处理任务
@@ -115,6 +115,7 @@ class StreamLoopManager:
 # 使用自适应流管理器获取槽位
 try:
 from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
+
 adaptive_manager = get_adaptive_stream_manager()
 
 if adaptive_manager.is_running:
@@ -123,9 +124,7 @@ class StreamLoopManager:
 
 # 获取处理槽位
 slot_acquired = await adaptive_manager.acquire_stream_slot(
-stream_id=stream_id,
-priority=priority,
-force=force
+stream_id=stream_id, priority=priority, force=force
 )
 
 if slot_acquired:
@@ -140,10 +139,7 @@ class StreamLoopManager:
 
 # 创建流循环任务
 try:
-loop_task = asyncio.create_task(
-self._stream_loop_worker(stream_id),
-name=f"stream_loop_{stream_id}"
-)
+loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
 self.stream_loops[stream_id] = loop_task
 # 更新统计信息
 self.stats["active_streams"] += 1
@@ -156,6 +152,7 @@ class StreamLoopManager:
 logger.error(f"启动流循环任务失败 {stream_id}: {e}")
 # 释放槽位
 from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
+
 adaptive_manager = get_adaptive_stream_manager()
 adaptive_manager.release_stream_slot(stream_id)
 
@@ -179,8 +176,8 @@ class StreamLoopManager:
 
 except Exception:
 from src.chat.message_manager.adaptive_stream_manager import StreamPriority
-return StreamPriority.NORMAL
 
+return StreamPriority.NORMAL
 
 async def stop_stream_loop(self, stream_id: str) -> bool:
 """停止指定流的循环任务
@@ -244,11 +241,12 @@ class StreamLoopManager:
 # 3. 更新自适应管理器指标
 try:
 from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
+
 adaptive_manager = get_adaptive_stream_manager()
 adaptive_manager.update_stream_metrics(
 stream_id,
 message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算
-last_activity=time.time()
+last_activity=time.time(),
 )
 except Exception as e:
 logger.debug(f"更新流指标失败: {e}")
@@ -300,6 +298,7 @@ class StreamLoopManager:
 # 释放自适应管理器的槽位
 try:
 from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
+
 adaptive_manager = get_adaptive_stream_manager()
 adaptive_manager.release_stream_slot(stream_id)
 logger.debug(f"释放自适应流处理槽位: {stream_id}")
@@ -553,12 +552,12 @@ class StreamLoopManager:
 existing_task.cancel()
 # 创建异步任务来等待取消完成,并添加异常处理
 cancel_task = asyncio.create_task(
-self._wait_for_task_cancel(stream_id, existing_task),
-name=f"cancel_existing_loop_{stream_id}"
+self._wait_for_task_cancel(stream_id, existing_task), name=f"cancel_existing_loop_{stream_id}"
 )
 # 为取消任务添加异常处理,避免孤儿任务
 cancel_task.add_done_callback(
-lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception()
+lambda task: logger.debug(f"取消任务完成: {stream_id}")
+if not task.exception()
 else logger.error(f"取消任务异常: {stream_id} - {task.exception()}")
 )
 # 从字典中移除
@@ -582,10 +581,7 @@ class StreamLoopManager:
 logger.info(f"流 {stream_id} 当前未读消息数: {unread_count}")
 
 # 创建新的流循环任务
-new_task = asyncio.create_task(
-self._stream_loop(stream_id),
-name=f"force_stream_loop_{stream_id}"
-)
+new_task = asyncio.create_task(self._stream_loop(stream_id), name=f"force_stream_loop_{stream_id}")
 self.stream_loops[stream_id] = new_task
 self.stats["total_loops"] += 1
 
@@ -59,6 +59,7 @@ class MessageManager:
 # 启动批量数据库写入器
 try:
 from src.chat.message_manager.batch_database_writer import init_batch_writer
+
 await init_batch_writer()
 except Exception as e:
 logger.error(f"启动批量数据库写入器失败: {e}")
@@ -66,6 +67,7 @@ class MessageManager:
 # 启动流缓存管理器
 try:
 from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
+
 await init_stream_cache_manager()
 except Exception as e:
 logger.error(f"启动流缓存管理器失败: {e}")
@@ -73,6 +75,7 @@ class MessageManager:
 # 启动自适应流管理器
 try:
 from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
+
 await init_adaptive_stream_manager()
 logger.info("🎯 自适应流管理器已启动")
 except Exception as e:
@@ -97,6 +100,7 @@ class MessageManager:
 # 停止批量数据库写入器
 try:
 from src.chat.message_manager.batch_database_writer import shutdown_batch_writer
+
 await shutdown_batch_writer()
 logger.info("📦 批量数据库写入器已停止")
 except Exception as e:
@@ -105,6 +109,7 @@ class MessageManager:
 # 停止流缓存管理器
 try:
 from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
+
 await shutdown_stream_cache_manager()
 logger.info("🗄️ 流缓存管理器已停止")
 except Exception as e:
@@ -113,6 +118,7 @@ class MessageManager:
 # 停止自适应流管理器
 try:
 from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
+
 await shutdown_adaptive_stream_manager()
 logger.info("🎯 自适应流管理器已停止")
 except Exception as e:
@@ -19,6 +19,7 @@ logger = get_logger("stream_cache_manager")
 @dataclass
 class StreamCacheStats:
 """缓存统计信息"""
+
 hot_cache_size: int = 0
 warm_storage_size: int = 0
 cold_storage_size: int = 0
@@ -134,11 +135,7 @@ class TieredStreamCache:
 # 4. 缓存未命中,创建新流
 self.stats.cache_misses += 1
 stream = create_optimized_chat_stream(
-stream_id=stream_id,
-platform=platform,
-user_info=user_info,
-group_info=group_info,
-data=data
+stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
 )
 logger.debug(f"缓存未命中,创建新流: {stream_id}")
 
@@ -294,9 +291,9 @@ class TieredStreamCache:
 
 # 估算内存使用(粗略估计)
 self.stats.total_memory_usage = (
-len(self.hot_cache) * 1024 + # 每个热流约1KB
-len(self.warm_storage) * 512 + # 每个温流约512B
-len(self.cold_storage) * 256 # 每个冷流约256B
+len(self.hot_cache) * 1024 # 每个热流约1KB
++ len(self.warm_storage) * 512 # 每个温流约512B
++ len(self.cold_storage) * 256 # 每个冷流约256B
 )
 
 if sum(cleanup_stats.values()) > 0:
@@ -561,7 +561,11 @@ class ChatBot:
 
 # 将兴趣度结果同步回原始消息,便于后续流程使用
 message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0))
-setattr(message, "should_reply", getattr(db_message, "should_reply", getattr(message, "should_reply", False)))
+setattr(
+message,
+"should_reply",
+getattr(db_message, "should_reply", getattr(message, "should_reply", False)),
+)
 setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False)))
 
 # 存储消息到数据库,只进行一次写入
@@ -298,8 +298,10 @@ class ChatStream:
 db_message.should_reply = result.should_reply
 db_message.should_act = result.should_act
 
-logger.debug(f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
-f"should_reply: {result.should_reply}, should_act: {result.should_act}")
+logger.debug(
+f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
+f"should_reply: {result.should_reply}, should_act: {result.should_act}"
+)
 else:
 logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
 # 使用默认值
@@ -521,18 +523,17 @@ class ChatManager:
 # 优先使用缓存管理器(优化版本)
 try:
 from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
+
 cache_manager = get_stream_cache_manager()
 
 if cache_manager.is_running:
 optimized_stream = await cache_manager.get_or_create_stream(
-stream_id=stream_id,
-platform=platform,
-user_info=user_info,
-group_info=group_info
+stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
 )
 
 # 设置消息上下文
 from .message import MessageRecv
+
 if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
 optimized_stream.set_context(self.last_messages[stream_id])
 
@@ -715,7 +716,7 @@ class ChatManager:
 success = await batch_writer.schedule_stream_update(
 stream_id=stream_data_dict["stream_id"],
 update_data=ChatManager._prepare_stream_data(stream_data_dict),
-priority=1 # 流更新的优先级
+priority=1, # 流更新的优先级
 )
 if success:
 stream.saved = True
@@ -738,7 +739,7 @@ class ChatManager:
 result = await batch_update(
 model_class=ChatStreams,
 conditions={"stream_id": stream_data_dict["stream_id"]},
-data=ChatManager._prepare_stream_data(stream_data_dict)
+data=ChatManager._prepare_stream_data(stream_data_dict),
 )
 if result and result > 0:
 stream.saved = True
@@ -881,7 +882,7 @@ def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
 stream_id=optimized_stream.stream_id,
 platform=optimized_stream.platform,
 user_info=optimized_stream._get_effective_user_info(),
-group_info=optimized_stream._get_effective_group_info()
+group_info=optimized_stream._get_effective_group_info(),
 )
 
 # 复制状态
@@ -909,7 +910,7 @@ def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
 stream_id=optimized_stream.stream_id,
 platform=optimized_stream.platform,
 user_info=optimized_stream._get_effective_user_info(),
-group_info=optimized_stream._get_effective_group_info()
+group_info=optimized_stream._get_effective_group_info(),
 )
 
 
@@ -80,10 +80,7 @@ class OptimizedChatStream:
):
# 共享的只读数据
self._shared_context = SharedContext(
-    stream_id=stream_id,
-    platform=platform,
-    user_info=user_info,
-    group_info=group_info
+    stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
)

# 本地修改数据

@@ -269,14 +266,13 @@ class OptimizedChatStream:
self._stream_context = StreamContext(
    stream_id=self.stream_id,
    chat_type=ChatType.GROUP if self.group_info else ChatType.PRIVATE,
-    chat_mode=ChatMode.NORMAL
+    chat_mode=ChatMode.NORMAL,
)

# 创建单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
-self._context_manager = SingleStreamContextManager(
-    stream_id=self.stream_id, context=self._stream_context
-)
+
+self._context_manager = SingleStreamContextManager(stream_id=self.stream_id, context=self._stream_context)

@property
def stream_context(self):

@@ -331,9 +327,11 @@ class OptimizedChatStream:
# 恢复stream_context信息
if "stream_context_chat_type" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
+
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
+
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])

# 恢复interruption_count信息

@@ -352,6 +350,7 @@ class OptimizedChatStream:
if isinstance(actions, str):
try:
import json
+
actions = json.loads(actions)
except json.JSONDecodeError:
logger.warning(f"无法解析actions JSON字符串: {actions}")

@@ -458,7 +457,7 @@ class OptimizedChatStream:
    stream_id=self.stream_id,
    platform=self.platform,
    user_info=self._get_effective_user_info(),
-    group_info=self._get_effective_group_info()
+    group_info=self._get_effective_group_info(),
)

# 复制本地修改(但不触发写时复制)
@@ -482,9 +481,5 @@ def create_optimized_chat_stream(
) -> OptimizedChatStream:
"""创建优化版聊天流实例"""
return OptimizedChatStream(
-    stream_id=stream_id,
-    platform=platform,
-    user_info=user_info,
-    group_info=group_info,
-    data=data
+    stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
)

@@ -196,10 +196,11 @@ class ChatterActionManager:
    thinking_id=thinking_id or "",
    action_done=True,
    action_build_into_prompt=False,
-    action_prompt_display=reason
+    action_prompt_display=reason,
)
else:
-asyncio.create_task(database_api.store_action_info(
+asyncio.create_task(
+    database_api.store_action_info(
        chat_stream=chat_stream,
        action_build_into_prompt=False,
        action_prompt_display=reason,

@@ -207,7 +208,8 @@ class ChatterActionManager:
        thinking_id=thinking_id,
        action_data={"reason": reason},
        action_name="no_reply",
-))
+    )
+)

# 自动清空所有未读消息
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply"))

@@ -228,7 +230,9 @@ class ChatterActionManager:

# 记录执行的动作到目标消息
if success:
-asyncio.create_task(self._record_action_to_message(chat_stream, action_name, target_message, action_data))
+asyncio.create_task(
+    self._record_action_to_message(chat_stream, action_name, target_message, action_data)
+)
# 自动清空所有未读消息
if clear_unread_messages:
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name))

@@ -496,7 +500,7 @@ class ChatterActionManager:
    thinking_id=thinking_id or "",
    action_done=True,
    action_build_into_prompt=False,
-    action_prompt_display=action_prompt_display
+    action_prompt_display=action_prompt_display,
)
else:
await database_api.store_action_info(

@@ -618,9 +622,15 @@ class ChatterActionManager:
self._pending_actions = []  # 清空队列
logger.debug("已禁用批量存储模式")

-def add_action_to_batch(self, action_name: str, action_data: dict, thinking_id: str = "",
-    action_done: bool = True, action_build_into_prompt: bool = False,
-    action_prompt_display: str = ""):
+def add_action_to_batch(
+    self,
+    action_name: str,
+    action_data: dict,
+    thinking_id: str = "",
+    action_done: bool = True,
+    action_build_into_prompt: bool = False,
+    action_prompt_display: str = "",
+):
"""添加动作到批量存储列表"""
if not self._batch_storage_enabled:
return False

@@ -632,7 +642,7 @@ class ChatterActionManager:
    "action_done": action_done,
    "action_build_into_prompt": action_build_into_prompt,
    "action_prompt_display": action_prompt_display,
-    "timestamp": time.time()
+    "timestamp": time.time(),
}
self._pending_actions.append(action_record)
logger.debug(f"已添加动作到批量存储列表: {action_name} (当前待处理: {len(self._pending_actions)} 个)")

@@ -658,7 +668,7 @@ class ChatterActionManager:
    action_done=action_data.get("action_done", True),
    action_build_into_prompt=action_data.get("action_build_into_prompt", False),
    action_prompt_display=action_data.get("action_prompt_display", ""),
-    thinking_id=action_data.get("thinking_id", "")
+    thinking_id=action_data.get("thinking_id", ""),
)
if result:
stored_count += 1
@@ -1275,12 +1275,32 @@ class DefaultReplyer:

# 并行执行六个构建任务
tasks = {
-    "expression_habits": asyncio.create_task(self._time_and_run_task(self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits")),
-    "relation_info": asyncio.create_task(self._time_and_run_task(self.build_relation_info(sender, target), "relation_info")),
-    "memory_block": asyncio.create_task(self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block")),
-    "tool_info": asyncio.create_task(self._time_and_run_task(self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info")),
-    "prompt_info": asyncio.create_task(self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info")),
-    "cross_context": asyncio.create_task(self._time_and_run_task(Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info), "cross_context")),
+    "expression_habits": asyncio.create_task(
+        self._time_and_run_task(
+            self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
+        )
+    ),
+    "relation_info": asyncio.create_task(
+        self._time_and_run_task(self.build_relation_info(sender, target), "relation_info")
+    ),
+    "memory_block": asyncio.create_task(
+        self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block")
+    ),
+    "tool_info": asyncio.create_task(
+        self._time_and_run_task(
+            self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool),
+            "tool_info",
+        )
+    ),
+    "prompt_info": asyncio.create_task(
+        self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info")
+    ),
+    "cross_context": asyncio.create_task(
+        self._time_and_run_task(
+            Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info),
+            "cross_context",
+        )
+    ),
}

# 设置超时

@@ -1606,13 +1626,8 @@ class DefaultReplyer:
chat_target_name = (
    self.chat_target_info.get("person_name") or self.chat_target_info.get("user_nickname") or "对方"
)
-await global_prompt_manager.format_prompt(
-    "chat_target_private1", sender_name=chat_target_name
-)
-await global_prompt_manager.format_prompt(
-    "chat_target_private2", sender_name=chat_target_name
-)
+await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
+await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
-

# 使用新的统一Prompt系统 - Expressor模式,创建PromptParameters
prompt_parameters = PromptParameters(
@@ -121,13 +121,14 @@ class VideoAnalyzer:
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
+
prompt = self.batch_analysis_prompt.format(
    personality_core=self.personality_core, personality_side=self.personality_side
)
if question:
prompt += f"\n用户关注: {question}"
desc = [
-    (f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧")
+    (f"第{i + 1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i + 1}帧")
    for i, (_b, ts) in enumerate(frames)
]
prompt += "\n帧列表: " + ", ".join(desc)

@@ -151,16 +152,16 @@ class VideoAnalyzer:
async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
results: list[str] = []
for i, (b64, ts) in enumerate(frames):
-prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
+prompt = f"分析第{i + 1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
if question:
prompt += f"\n关注: {question}"
try:
text, _ = await self.video_llm.generate_response_for_image(
    prompt=prompt, image_base64=b64, image_format="jpeg"
)
-results.append(f"第{i+1}帧: {text}")
+results.append(f"第{i + 1}帧: {text}")
except Exception as e:  # pragma: no cover
-results.append(f"第{i+1}帧: 失败 {e}")
+results.append(f"第{i + 1}帧: 失败 {e}")
if i < len(frames) - 1:
await asyncio.sleep(self.frame_analysis_delay)
summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results)

@@ -182,7 +183,9 @@ class VideoAnalyzer:
mode = self.analysis_mode
if mode == "auto":
mode = "batch" if len(frames) <= 20 else "sequential"
-text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question))
+text = await (
+    self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question)
+)
return True, text

async def analyze_video_from_bytes(
@@ -220,7 +220,9 @@ class DatabaseMessages(BaseDataModel):
    "chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}

-def update_message_info(self, interest_value: float | None = None, actions: list | None = None, should_reply: bool | None = None):
+def update_message_info(
+    self, interest_value: float | None = None, actions: list | None = None, should_reply: bool | None = None
+):
"""
更新消息信息

@@ -53,8 +53,6 @@ class StreamContext(BaseDataModel):
priority_mode: str | None = None
priority_info: dict | None = None

-
-
def add_action_to_message(self, message_id: str, action: str):
"""
向指定消息添加执行的动作

@@ -75,9 +73,6 @@ class StreamContext(BaseDataModel):
message.add_action(action)
break

-
-
-
def mark_message_as_read(self, message_id: str):
"""标记消息为已读"""
for msg in self.unread_messages:
@@ -78,7 +78,7 @@ class ConnectionPoolManager:
    "total_expired": 0,
    "active_connections": 0,
    "pool_hits": 0,
-    "pool_misses": 0
+    "pool_misses": 0,
}

# 后台清理任务

@@ -156,7 +156,9 @@ class ConnectionPoolManager:
if connection_info:
connection_info.mark_released()

-async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> ConnectionInfo | None:
+async def _get_reusable_connection(
+    self, session_factory: async_sessionmaker[AsyncSession]
+) -> ConnectionInfo | None:
"""获取可复用的连接"""
async with self._lock:
# 清理过期连接

@@ -164,9 +166,7 @@ class ConnectionPoolManager:

# 查找可复用的连接
for connection_info in list(self._connections):
-if (not connection_info.in_use and
-    not connection_info.is_expired(self.max_lifetime, self.max_idle)):
-
+if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
# 验证连接是否仍然有效
try:
# 执行一个简单的查询来验证连接

@@ -191,8 +191,7 @@ class ConnectionPoolManager:
expired_connections = []

for connection_info in list(self._connections):
-if (connection_info.is_expired(self.max_lifetime, self.max_idle) and
-    not connection_info.in_use):
+if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use:
expired_connections.append(connection_info)

for connection_info in expired_connections:

@@ -238,7 +237,8 @@ class ConnectionPoolManager:
    "max_pool_size": self.max_pool_size,
    "pool_efficiency": (
        self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"])
-    ) * 100
+    )
+    * 100,
}

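The pool_efficiency entry reformatted above is the connection-pool hit ratio expressed as a percentage, with max(1, ...) guarding the division before any connection has been requested. A small standalone sketch of that figure (names assumed from the diff):

def pool_efficiency(pool_hits: int, pool_misses: int) -> float:
    # Hit ratio as a percentage; max(1, ...) avoids dividing by zero on an untouched pool.
    return pool_hits / max(1, pool_hits + pool_misses) * 100


print(pool_efficiency(0, 0))    # 0.0
print(pool_efficiency(75, 25))  # 75.0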
@@ -24,6 +24,7 @@ T = TypeVar("T")
@dataclass
class BatchOperation:
"""批量操作基础类"""
+
operation_type: str  # 'select', 'insert', 'update', 'delete'
model_class: Any
conditions: dict[str, Any]

@@ -40,6 +41,7 @@ class BatchOperation:
@dataclass
class BatchResult:
"""批量操作结果"""
+
success: bool
data: Any = None
error: str | None = None

@@ -48,10 +50,12 @@ class BatchResult:
class DatabaseBatchScheduler:
"""数据库批量调度器"""

-def __init__(self,
+def __init__(
+    self,
    batch_size: int = 50,
    max_wait_time: float = 0.1,  # 100ms
-    max_queue_size: int = 1000):
+    max_queue_size: int = 1000,
+):
self.batch_size = batch_size
self.max_wait_time = max_wait_time
self.max_queue_size = max_queue_size

@@ -65,12 +69,7 @@ class DatabaseBatchScheduler:
self._lock = asyncio.Lock()

# 统计信息
-self.stats = {
-    "total_operations": 0,
-    "batched_operations": 0,
-    "cache_hits": 0,
-    "execution_time": 0.0
-}
+self.stats = {"total_operations": 0, "batched_operations": 0, "cache_hits": 0, "execution_time": 0.0}

# 简单的结果缓存(用于频繁的查询)
self._result_cache: dict[str, tuple[Any, float]] = {}

@@ -105,11 +104,7 @@ class DatabaseBatchScheduler:
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str:
"""生成缓存键"""
# 简单的缓存键生成,实际可以根据需要优化
-key_parts = [
-    operation_type,
-    model_class.__name__,
-    str(sorted(conditions.items()))
-]
+key_parts = [operation_type, model_class.__name__, str(sorted(conditions.items()))]
return "|".join(key_parts)

def _get_from_cache(self, cache_key: str) -> Any | None:
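The cache key collapsed onto one line above is just the operation type, the model class name and the sorted condition pairs joined with "|", so identical select queries map onto the same cache entry. A standalone sketch of that scheme (the stand-in model class below is hypothetical):

def generate_cache_key(operation_type: str, model_class: type, conditions: dict) -> str:
    # Sorting the items makes the key independent of dict insertion order.
    return "|".join([operation_type, model_class.__name__, str(sorted(conditions.items()))])


class ChatStreams:  # stand-in for the ORM model referenced elsewhere in the diff
    pass


print(generate_cache_key("select", ChatStreams, {"stream_id": "abc"}))
# -> select|ChatStreams|[('stream_id', 'abc')]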
@@ -132,11 +127,7 @@ class DatabaseBatchScheduler:
"""添加操作到队列"""
# 检查是否可以立即返回缓存结果
if operation.operation_type == "select":
-cache_key = self._generate_cache_key(
-    operation.operation_type,
-    operation.model_class,
-    operation.conditions
-)
+cache_key = self._generate_cache_key(operation.operation_type, operation.model_class, operation.conditions)
cached_result = self._get_from_cache(cache_key)
if cached_result is not None:
if operation.callback:

@@ -180,10 +171,7 @@ class DatabaseBatchScheduler:
return

# 复制队列内容,避免长时间占用锁
-queues_copy = {
-    key: deque(operations)
-    for key, operations in self.operation_queues.items()
-}
+queues_copy = {key: deque(operations) for key, operations in self.operation_queues.items()}
# 清空原队列
for queue in self.operation_queues.values():
queue.clear()

@@ -240,9 +228,7 @@ class DatabaseBatchScheduler:
# 缓存查询结果
if operation.operation_type == "select":
cache_key = self._generate_cache_key(
-    operation.operation_type,
-    operation.model_class,
-    operation.conditions
+    operation.operation_type, operation.model_class, operation.conditions
)
self._set_cache(cache_key, result)

@@ -287,12 +273,9 @@ class DatabaseBatchScheduler:
else:
# 需要根据条件过滤结果
op_result = [
-    item for item in data
-    if all(
-        getattr(item, k) == v
-        for k, v in op.conditions.items()
-        if hasattr(item, k)
-    )
+    item
+    for item in data
+    if all(getattr(item, k) == v for k, v in op.conditions.items() if hasattr(item, k))
]
results.append(op_result)

@@ -429,7 +412,7 @@ class DatabaseBatchScheduler:
    **self.stats,
    "cache_size": len(self._result_cache),
    "queue_sizes": {k: len(v) for k, v in self.operation_queues.items()},
-    "is_running": self._is_running
+    "is_running": self._is_running,
}

@@ -452,43 +435,25 @@ async def get_batch_session():
# 便捷函数
async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any:
"""批量查询"""
-operation = BatchOperation(
-    operation_type="select",
-    model_class=model_class,
-    conditions=conditions
-)
+operation = BatchOperation(operation_type="select", model_class=model_class, conditions=conditions)
return await db_batch_scheduler.add_operation(operation)


async def batch_insert(model_class: Any, data: dict[str, Any]) -> int:
"""批量插入"""
-operation = BatchOperation(
-    operation_type="insert",
-    model_class=model_class,
-    conditions={},
-    data=data
-)
+operation = BatchOperation(operation_type="insert", model_class=model_class, conditions={}, data=data)
return await db_batch_scheduler.add_operation(operation)


async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int:
"""批量更新"""
-operation = BatchOperation(
-    operation_type="update",
-    model_class=model_class,
-    conditions=conditions,
-    data=data
-)
+operation = BatchOperation(operation_type="update", model_class=model_class, conditions=conditions, data=data)
return await db_batch_scheduler.add_operation(operation)


async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int:
"""批量删除"""
-operation = BatchOperation(
-    operation_type="delete",
-    model_class=model_class,
-    conditions=conditions
-)
+operation = BatchOperation(operation_type="delete", model_class=model_class, conditions=conditions)
return await db_batch_scheduler.add_operation(operation)

@@ -304,7 +304,6 @@ def load_log_config():  # sourcery skip: use-contextlib-suppress
    "library_log_levels": {"aiohttp": "WARNING"},
}

-
# 误加的即刻线程启动已移除;真正的线程在 start_log_cleanup_task 中按午夜调度

try:

@@ -37,7 +37,9 @@ class DatabaseConfig(ValidatedConfigBase):
connection_timeout: int = Field(default=10, ge=1, description="连接超时时间")

# 批量动作记录存储配置
-batch_action_storage_enabled: bool = Field(default=True, description="是否启用批量保存动作记录(开启后将多个动作一次性写入数据库,提升性能)")
+batch_action_storage_enabled: bool = Field(
+    default=True, description="是否启用批量保存动作记录(开启后将多个动作一次性写入数据库,提升性能)"
+)


class BotConfig(ValidatedConfigBase):

@@ -355,7 +357,7 @@ class MemoryConfig(ValidatedConfigBase):
# 双峰分布配置 [近期均值, 近期标准差, 近期权重, 远期均值, 远期标准差, 远期权重]
hippocampus_distribution_config: list[float] = Field(
    default=[12.0, 8.0, 0.7, 48.0, 24.0, 0.3],
-    description="海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]"
+    description="海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]",
)

# 自适应采样配置

@@ -690,7 +692,6 @@ class AffinityFlowConfig(ValidatedConfigBase):
base_relationship_score: float = Field(default=0.5, description="基础人物关系分")

-

class ProactiveThinkingConfig(ValidatedConfigBase):
"""主动思考(主动发起对话)功能配置"""

@@ -50,14 +50,16 @@ def _convert_messages_to_mcp(messages: list[Message]) -> list[dict[str, Any]]:
for item in message.content:
if isinstance(item, tuple):
# 图片内容
-content_parts.append({
+content_parts.append(
+    {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": f"image/{item[0].lower()}",
            "data": item[1],
        },
-})
+    }
+)
elif isinstance(item, str):
# 文本内容
content_parts.append({"type": "text", "text": item})

@@ -138,9 +140,7 @@ async def _parse_sse_stream(
async with session.post(url, json=payload, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
-raise RespNotOkException(
-    response.status, f"MCP SSE请求失败: {error_text}"
-)
+raise RespNotOkException(response.status, f"MCP SSE请求失败: {error_text}")

# 解析SSE流
async for line in response.content:

@@ -258,10 +258,7 @@ async def _parse_sse_stream(
response.reasoning_content = reasoning_buffer.getvalue()

if tool_calls_buffer:
-response.tool_calls = [
-    ToolCall(call_id, func_name, args)
-    for call_id, func_name, args in tool_calls_buffer
-]
+response.tool_calls = [ToolCall(call_id, func_name, args) for call_id, func_name, args in tool_calls_buffer]

# 关闭缓冲区
content_buffer.close()

@@ -351,9 +348,7 @@ class MCPSSEClient(BaseClient):
url = f"{self.api_provider.base_url}/v1/messages"

try:
-response, usage_record = await _parse_sse_stream(
-    session, url, payload, headers, interrupt_flag
-)
+response, usage_record = await _parse_sse_stream(session, url, payload, headers, interrupt_flag)
except Exception as e:
logger.error(f"MCP SSE请求失败: {e}")
raise

@@ -414,9 +414,7 @@ class OpenaiClient(BaseClient):

# 创建新的 AsyncOpenAI 实例
logger.debug(
-    f"创建新的 AsyncOpenAI 客户端实例 "
-    f"(base_url={self.api_provider.base_url}, "
-    f"config_hash={self._config_hash})"
+    f"创建新的 AsyncOpenAI 客户端实例 (base_url={self.api_provider.base_url}, config_hash={self._config_hash})"
)

client = AsyncOpenAI(
@@ -280,7 +280,9 @@ class _PromptProcessor:
这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
"""

-async def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
+async def prepare_prompt(
+    self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str
+) -> str:
"""
为请求准备最终的提示词。

31
src/main.py
@@ -88,6 +88,7 @@ class MainSystem:

def _setup_signal_handlers(self) -> None:
"""设置信号处理器"""
+
def signal_handler(signum, frame):
if self._shutting_down:
logger.warning("系统已经在关闭过程中,忽略重复信号")

@@ -132,6 +133,7 @@ class MainSystem:
try:
from src.plugin_system.apis.component_manage_api import get_components_info_by_type
from src.plugin_system.base.component_types import ComponentType
+
interest_calculators = get_components_info_by_type(ComponentType.INTEREST_CALCULATOR)
logger.info(f"通过组件注册表发现 {len(interest_calculators)} 个兴趣计算器组件")
except Exception as e:

@@ -143,6 +145,7 @@ class MainSystem:

# 初始化兴趣度管理器
from src.chat.interest_system.interest_manager import get_interest_manager
+
interest_manager = get_interest_manager()
await interest_manager.initialize()

@@ -159,7 +162,10 @@ class MainSystem:

try:
from src.plugin_system.core.component_registry import component_registry
-component_class = component_registry.get_component_class(calc_name, ComponentType.INTEREST_CALCULATOR)
+
+component_class = component_registry.get_component_class(
+    calc_name, ComponentType.INTEREST_CALCULATOR
+)

if not component_class:
logger.warning(f"无法找到 {calc_name} 的组件类")

@@ -208,6 +214,7 @@ class MainSystem:
# 停止数据库服务
try:
from src.common.database.database import stop_database
+
cleanup_tasks.append(("数据库服务", stop_database()))
except Exception as e:
logger.error(f"准备停止数据库服务时出错: {e}")

@@ -215,6 +222,7 @@ class MainSystem:
# 停止消息管理器
try:
from src.chat.message_manager import message_manager
+
cleanup_tasks.append(("消息管理器", message_manager.stop()))
except Exception as e:
logger.error(f"准备停止消息管理器时出错: {e}")

@@ -222,6 +230,7 @@ class MainSystem:
# 停止消息重组器
try:
from src.utils.message_chunker import reassembler
+
cleanup_tasks.append(("消息重组器", reassembler.stop_cleanup_task()))
except Exception as e:
logger.error(f"准备停止消息重组器时出错: {e}")

@@ -236,15 +245,18 @@ class MainSystem:
# 触发停止事件
try:
from src.plugin_system.core.event_manager import event_manager
+
-cleanup_tasks.append(("插件系统停止事件",
-    event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")))
+cleanup_tasks.append(
+    ("插件系统停止事件", event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
+)
except Exception as e:
logger.error(f"准备触发停止事件时出错: {e}")

# 停止表情管理器
try:
-cleanup_tasks.append(("表情管理器",
-    asyncio.get_event_loop().run_in_executor(None, get_emoji_manager().shutdown)))
+cleanup_tasks.append(
+    ("表情管理器", asyncio.get_event_loop().run_in_executor(None, get_emoji_manager().shutdown))
+)
except Exception as e:
logger.error(f"准备停止表情管理器时出错: {e}")

@@ -275,7 +287,7 @@ class MainSystem:
try:
results = await asyncio.wait_for(
    asyncio.gather(*tasks, return_exceptions=True),
-    timeout=30.0  # 30秒超时
+    timeout=30.0,  # 30秒超时
)

# 记录结果

@@ -389,6 +401,7 @@ MoFox_Bot(第三方修改版)
# 注册API路由
try:
from src.api.message_router import router as message_router
+
self.server.register_router(message_router, prefix="/api")
logger.info("API路由注册成功")
except Exception as e:

@@ -405,6 +418,7 @@ MoFox_Bot(第三方修改版)
mcp_config = global_config.get("mcp_servers", [])
if mcp_config:
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
+
await mcp_tool_provider.initialize(mcp_config)
logger.info("MCP工具提供器初始化成功")
except Exception as e:

@@ -445,6 +459,7 @@ MoFox_Bot(第三方修改版)
# 初始化LPMM知识库
try:
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
+
initialize_lpmm_knowledge()
logger.info("LPMM知识库初始化成功")
except Exception as e:

@@ -456,6 +471,7 @@ MoFox_Bot(第三方修改版)
# 启动消息重组器
try:
from src.utils.message_chunker import reassembler
+
await reassembler.start_cleanup_task()
logger.info("消息重组器已启动")
except Exception as e:

@@ -464,6 +480,7 @@ MoFox_Bot(第三方修改版)
# 启动消息管理器
try:
from src.chat.message_manager import message_manager
+
await message_manager.start()
logger.info("消息管理器已启动")
except Exception as e:

@@ -504,6 +521,7 @@ MoFox_Bot(第三方修改版)

def _safe_init(self, component_name: str, init_func) -> callable:
"""安全初始化组件,捕获异常"""
+
async def wrapper():
try:
result = init_func()

@@ -514,6 +532,7 @@ MoFox_Bot(第三方修改版)
except Exception as e:
logger.error(f"{component_name}初始化失败: {e}")
return False
+
return wrapper

async def schedule_tasks(self) -> None:
@@ -59,6 +59,7 @@ class ChatMood:
"""异步初始化方法"""
if not self._initialized:
from src.chat.message_receive.chat_stream import get_chat_manager
+
chat_manager = get_chat_manager()
self.chat_stream = await chat_manager.get_stream(self.chat_id)

@@ -69,6 +69,7 @@ class RelationshipBuilder:
if not self._log_prefix_initialized:
try:
from src.chat.message_receive.chat_stream import get_chat_manager
+
chat_name = await get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{chat_name}]"
except Exception:

@@ -85,6 +85,7 @@ class RelationshipFetcher:
"""异步初始化log_prefix"""
if not self._log_prefix_initialized:
from src.chat.message_receive.chat_stream import get_chat_manager
+
name = await get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{name}] 实时信息"
self._log_prefix_initialized = True

@@ -57,6 +57,9 @@ async def get_replyer(
raise ValueError("chat_stream 和 chat_id 不可均为空")
try:
logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
+# 动态导入避免循环依赖
+from src.chat.replyer.replyer_manager import replyer_manager
+
return await replyer_manager.get_replyer(
    chat_stream=chat_stream,
    chat_id=chat_id,

@@ -39,6 +39,7 @@ def get_llm_available_tool_definitions():
# 添加MCP工具
try:
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
+
mcp_tools = mcp_tool_provider.get_mcp_tool_definitions()
tool_definitions.extend(mcp_tools)
if mcp_tools:

@@ -86,7 +86,9 @@ class HandlerResultsCollection:


class BaseEvent:
-def __init__(self, name: str, allowed_subscribers: list[str] | None = None, allowed_triggers: list[str] | None = None):
+def __init__(
+    self, name: str, allowed_subscribers: list[str] | None = None, allowed_triggers: list[str] | None = None
+):
self.name = name
self.enabled = True
self.allowed_subscribers = allowed_subscribers  # 记录事件处理器名
@@ -28,7 +28,7 @@ class InterestCalculationResult:
    should_reply: bool = False,
    should_act: bool = False,
    error_message: str | None = None,
-    calculation_time: float = 0.0
+    calculation_time: float = 0.0,
):
self.success = success
self.message_id = message_id

@@ -51,17 +51,19 @@ class InterestCalculationResult:
    "should_act": self.should_act,
    "error_message": self.error_message,
    "calculation_time": self.calculation_time,
-    "timestamp": self.timestamp
+    "timestamp": self.timestamp,
}

def __repr__(self) -> str:
-return (f"InterestCalculationResult("
+return (
+    f"InterestCalculationResult("
    f"success={self.success}, "
    f"message_id={self.message_id}, "
    f"interest_value={self.interest_value:.3f}, "
    f"should_take_action={self.should_take_action}, "
    f"should_reply={self.should_reply}, "
-    f"should_act={self.should_act})")
+    f"should_act={self.should_act})"
+)


class BaseInterestCalculator(ABC):

@@ -144,7 +146,7 @@ class BaseInterestCalculator(ABC):
    "failed_calculations": self._failed_calculations,
    "success_rate": 1.0 - (self._failed_calculations / max(1, self._total_calculations)),
    "average_calculation_time": self._average_calculation_time,
-    "last_calculation_time": self._last_calculation_time
+    "last_calculation_time": self._last_calculation_time,
}

def _update_statistics(self, result: InterestCalculationResult):

@@ -159,8 +161,7 @@ class BaseInterestCalculator(ABC):
else:
alpha = 0.1  # 指数移动平均的平滑因子
self._average_calculation_time = (
-    alpha * result.calculation_time +
-    (1 - alpha) * self._average_calculation_time
+    alpha * result.calculation_time + (1 - alpha) * self._average_calculation_time
)

self._last_calculation_time = result.timestamp
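The running average joined onto one line above is an exponential moving average with smoothing factor alpha = 0.1: each new calculation time nudges the stored average by 10%. A minimal sketch of the update rule (standalone, illustrative names):

def update_average(previous_average: float, sample: float, alpha: float = 0.1) -> float:
    # Exponential moving average: new sample weighted by alpha, history by (1 - alpha).
    return alpha * sample + (1 - alpha) * previous_average


print(round(update_average(0.50, 0.80), 2))  # 0.1 * 0.80 + 0.9 * 0.50 = 0.53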
@@ -172,7 +173,7 @@ class BaseInterestCalculator(ABC):
    success=False,
    message_id=getattr(message, "message_id", ""),
    interest_value=0.0,
-    error_message="组件未启用"
+    error_message="组件未启用",
)

start_time = time.time()

@@ -187,7 +188,7 @@ class BaseInterestCalculator(ABC):
    message_id=getattr(message, "message_id", ""),
    interest_value=0.0,
    error_message=f"计算执行失败: {e!s}",
-    calculation_time=time.time() - start_time
+    calculation_time=time.time() - start_time,
)
self._update_statistics(result)
return result
@@ -214,7 +215,9 @@ class BaseInterestCalculator(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"{self.__class__.__name__}("
|
return (
|
||||||
|
f"{self.__class__.__name__}("
|
||||||
f"name={self.component_name}, "
|
f"name={self.component_name}, "
|
||||||
f"version={self.component_version}, "
|
f"version={self.component_version}, "
|
||||||
f"enabled={self._enabled})")
|
f"enabled={self._enabled})"
|
||||||
|
)
|
||||||
|
|||||||
@@ -60,7 +60,9 @@ class BasePlugin(PluginBase):
|
|||||||
if hasattr(component_class, "get_interest_calculator_info"):
|
if hasattr(component_class, "get_interest_calculator_info"):
|
||||||
return component_class.get_interest_calculator_info()
|
return component_class.get_interest_calculator_info()
|
||||||
else:
|
else:
|
||||||
logger.warning(f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法")
|
logger.warning(
|
||||||
|
f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
elif component_type == ComponentType.PLUS_COMMAND:
|
elif component_type == ComponentType.PLUS_COMMAND:
|
||||||
@@ -96,6 +98,7 @@ class BasePlugin(PluginBase):
|
|||||||
对应类型的ComponentInfo对象
|
对应类型的ComponentInfo对象
|
||||||
"""
|
"""
|
||||||
return cls._get_component_info_from_class(component_class, component_type)
|
return cls._get_component_info_from_class(component_class, component_type)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_plugin_components(
|
def get_plugin_components(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ class PluginMetadata:
|
|||||||
"""
|
"""
|
||||||
插件元数据,用于存储插件的开发者信息和用户帮助信息。
|
插件元数据,用于存储插件的开发者信息和用户帮助信息。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str # 插件名称 (供用户查看)
|
name: str # 插件名称 (供用户查看)
|
||||||
description: str # 插件功能描述
|
description: str # 插件功能描述
|
||||||
usage: str # 插件使用方法
|
usage: str # 插件使用方法
|
||||||
|
|||||||
@@ -319,7 +319,9 @@ class ComponentRegistry:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _register_interest_calculator_component(
|
def _register_interest_calculator_component(
|
||||||
self, interest_calculator_info: "InterestCalculatorInfo", interest_calculator_class: type["BaseInterestCalculator"]
|
self,
|
||||||
|
interest_calculator_info: "InterestCalculatorInfo",
|
||||||
|
interest_calculator_class: type["BaseInterestCalculator"],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""注册InterestCalculator组件到特定注册表"""
|
"""注册InterestCalculator组件到特定注册表"""
|
||||||
calculator_name = interest_calculator_info.name
|
calculator_name = interest_calculator_info.name
|
||||||
@@ -327,7 +329,9 @@ class ComponentRegistry:
|
|||||||
if not calculator_name:
|
if not calculator_name:
|
||||||
logger.error(f"InterestCalculator组件 {interest_calculator_class.__name__} 必须指定名称")
|
logger.error(f"InterestCalculator组件 {interest_calculator_class.__name__} 必须指定名称")
|
||||||
return False
|
return False
|
||||||
if not isinstance(interest_calculator_info, InterestCalculatorInfo) or not issubclass(interest_calculator_class, BaseInterestCalculator):
|
if not isinstance(interest_calculator_info, InterestCalculatorInfo) or not issubclass(
|
||||||
|
interest_calculator_class, BaseInterestCalculator
|
||||||
|
):
|
||||||
logger.error(f"注册失败: {calculator_name} 不是有效的InterestCalculator")
|
logger.error(f"注册失败: {calculator_name} 不是有效的InterestCalculator")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ class ToolExecutor:
|
|||||||
"""异步初始化log_prefix和chat_stream"""
|
"""异步初始化log_prefix和chat_stream"""
|
||||||
if not self._log_prefix_initialized:
|
if not self._log_prefix_initialized:
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
|
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
|
||||||
stream_name = await get_chat_manager().get_stream_name(self.chat_id)
|
stream_name = await get_chat_manager().get_stream_name(self.chat_id)
|
||||||
self.log_prefix = f"[{stream_name or self.chat_id}]"
|
self.log_prefix = f"[{stream_name or self.chat_id}]"
|
||||||
@@ -283,6 +284,7 @@ class ToolExecutor:
             # 检查是否是MCP工具
             try:
                 from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
+
                 if function_name in mcp_tool_provider.mcp_tools:
                     logger.info(f"{self.log_prefix}执行MCP工具: {function_name}")
                     result = await mcp_tool_provider.call_mcp_tool(function_name, function_args)
@@ -8,8 +8,5 @@ __plugin_meta__ = PluginMetadata(
     author="MoFox",
     keywords=["chatter", "affinity", "conversation"],
     categories=["Chat", "AI"],
-    extra={
-        "is_built_in": True
-    }
+    extra={"is_built_in": True},
 )
-
@@ -149,7 +149,6 @@ class AffinityChatter(BaseChatter):
         """
         return self.planner.get_mood_stats()

-
     def reset_stats(self):
         """重置统计信息"""
         self.stats = {
@@ -111,9 +111,11 @@ class AffinityInterestCalculator(BaseInterestCalculator):
                 + mentioned_score * self.score_weights["mentioned"]
             )

-            logger.debug(f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + "
-                f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
-                f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {total_score:.3f}")
+            logger.debug(
+                f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + "
+                f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
+                f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {total_score:.3f}"
+            )

             # 5. 考虑连续不回复的概率提升
             adjusted_score = self._apply_no_reply_boost(total_score)
@@ -135,8 +137,10 @@ class AffinityInterestCalculator(BaseInterestCalculator):

             calculation_time = time.time() - start_time

-            logger.debug(f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} "
-                f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})")
+            logger.debug(
+                f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} "
+                f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})"
+            )

             return InterestCalculationResult(
                 success=True,
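The logger.debug re-wrapping in the two hunks above only breaks the call at its parentheses; the logged text is unchanged, because the adjacent f-string literals were already joined by implicit string concatenation. A minimal, self-contained sketch of the same pattern (illustrative only; the logger name and values are invented, not taken from this commit):

import logging

logger = logging.getLogger("demo")
message_id, score = "m-1", 0.873

# Wrapped at the call parentheses; the two f-string pieces are concatenated
# implicitly, so the emitted message is identical to the single-line form.
logger.debug(
    f"calculation finished for message {message_id}: {score:.3f} "
    f"(match=0.90, relation=0.40, mention=0.10)"
)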
@@ -145,16 +149,13 @@ class AffinityInterestCalculator(BaseInterestCalculator):
                 should_take_action=should_take_action,
                 should_reply=should_reply,
                 should_act=should_take_action,
-                calculation_time=calculation_time
+                calculation_time=calculation_time,
             )

         except Exception as e:
             logger.error(f"Affinity兴趣值计算失败: {e}", exc_info=True)
             return InterestCalculationResult(
-                success=False,
-                message_id=getattr(message, "message_id", ""),
-                interest_value=0.0,
-                error_message=str(e)
+                success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e)
             )

     async def _calculate_interest_match_score(self, content: str, keywords: list[str] | None = None) -> float:
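Both changes in the hunk above are consistent with the Black-style magic trailing comma convention that ruff's formatter applies: a call that stays split across lines gets a trailing comma added after its last argument, while a call whose arguments carry no trailing comma and fit within the configured line length is collapsed onto one line. A small runnable sketch of the rule (illustrative only; build_result is an invented stand-in, not an API from this repository):

def build_result(**kwargs: object) -> dict[str, object]:
    # Stand-in constructor used only to illustrate the formatting behaviour.
    return dict(kwargs)


# A trailing comma after the last argument keeps the call exploded,
# one argument per line, and the formatter preserves that layout.
ok = build_result(
    success=True,
    interest_value=0.87,
    calculation_time=0.012,
)

# No trailing comma and the call fits on one line, so it is collapsed.
err = build_result(success=False, interest_value=0.0, error_message="boom")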
@@ -405,7 +405,6 @@ class ChatterPlanExecutor:
         # 移除执行时间列表以避免返回过大数据
         stats.pop("execution_times", None)

-
         return stats

     def reset_stats(self):
@@ -434,12 +433,12 @@ class ChatterPlanExecutor:
                 for i, time_val in enumerate(recent_times)
             ]

-
     async def _flush_action_manager_batch_storage(self, plan: Plan):
         """使用 action_manager 的批量存储功能存储所有待处理的动作"""
         try:
             # 通过 chat_id 获取真实的 chat_stream 对象
             from src.plugin_system.apis.chat_api import get_chat_manager
+
             chat_manager = get_chat_manager()
             chat_stream = await chat_manager.get_stream(plan.chat_id)

@@ -455,4 +454,3 @@ class ChatterPlanExecutor:
             logger.error(f"批量存储动作记录时发生错误: {e}")
             # 确保在出错时也禁用批量存储模式
             self.action_manager.disable_batch_storage()
-
@@ -64,7 +64,6 @@ class ChatterPlanFilter:

         llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)

-
         if llm_content:
             if global_config.debug.show_prompt:
                 logger.info(f"LLM规划器原始响应:{llm_content}")
@@ -132,7 +132,6 @@ class ChatterActionPlanner:
                     if message_should_act:
                         aggregate_should_act = True

-
                 except Exception as e:
                     logger.warning(f"处理消息 {message.message_id} 失败: {e}")
                     message.interest_value = 0.0
@@ -39,6 +39,7 @@ class AffinityChatterPlugin(BasePlugin):
         try:
             # 延迟导入 AffinityChatter
             from .affinity_chatter import AffinityChatter
+
             components.append((AffinityChatter.get_chatter_info(), AffinityChatter))
         except Exception as e:
             logger.error(f"加载 AffinityChatter 时出错: {e}")
@@ -46,9 +47,9 @@ class AffinityChatterPlugin(BasePlugin):
         try:
             # 延迟导入 AffinityInterestCalculator
             from .affinity_interest_calculator import AffinityInterestCalculator
+
             components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator))
         except Exception as e:
             logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")

         return components
-
@@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata(
     extra={
         "is_built_in": True,
         "plugin_type": "action_provider",
-    }
+    },
 )
@@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata(
     extra={
         "is_built_in": False,
         "plugin_type": "social",
-    }
+    },
 )
@@ -12,5 +12,5 @@ __plugin_meta__ = PluginMetadata(
     categories=["protocol"],
     extra={
         "is_built_in": False,
-    }
+    },
 )
@@ -13,8 +13,8 @@ def create_router(plugin_config: dict):
     """创建路由器实例"""
     global router
     platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq")
-    host = os.getenv("HOST","127.0.0.1")
-    port = os.getenv("PORT","8000")
+    host = os.getenv("HOST", "127.0.0.1")
+    port = os.getenv("PORT", "8000")
     logger.debug(f"初始化MaiBot连接,使用地址:{host}:{port}")
     route_config = RouteConfig(
         route_config={
@@ -12,5 +12,5 @@ __plugin_meta__ = PluginMetadata(
     extra={
         "is_built_in": True,
         "plugin_type": "permission",
-    }
+    },
 )
@@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata(
     extra={
         "is_built_in": True,
         "plugin_type": "plugin_management",
-    }
+    },
 )
@@ -8,10 +8,7 @@ __plugin_meta__ = PluginMetadata(
     author="MoFox-Studio",
     license="GPL-v3.0-or-later",
     repository_url="https://github.com/MoFox-Studio",
-    keywords=["主动思考","自己发消息"],
+    keywords=["主动思考", "自己发消息"],
     categories=["Chat", "Integration"],
-    extra={
-        "is_built_in": True,
-        "plugin_type": "functional"
-    }
+    extra={"is_built_in": True, "plugin_type": "functional"},
 )
@@ -63,7 +63,9 @@ class ColdStartTask(AsyncTask):
                 logger.info(f"【冷启动】发现全新用户 {chat_id},准备发起第一次问候。")
             elif stream.last_active_time < self.bot_start_time:
                 should_wake_up = True
-                logger.info(f"【冷启动】发现沉睡的聊天流 {chat_id} (最后活跃于 {datetime.fromtimestamp(stream.last_active_time)}),准备唤醒。")
+                logger.info(
+                    f"【冷启动】发现沉睡的聊天流 {chat_id} (最后活跃于 {datetime.fromtimestamp(stream.last_active_time)}),准备唤醒。"
+                )

             if should_wake_up:
                 person_id = person_api.get_person_id(platform, user_id)
@@ -166,7 +168,9 @@ class ProactiveThinkingTask(AsyncTask):
                     continue

                 # 检查冷却时间
-                recent_messages = await message_api.get_recent_messages(chat_id=stream.stream_id, limit=1,limit_mode="latest")
+                recent_messages = await message_api.get_recent_messages(
+                    chat_id=stream.stream_id, limit=1, limit_mode="latest"
+                )
                 last_message_time = recent_messages[0]["time"] if recent_messages else stream.create_time
                 time_since_last_active = time.time() - last_message_time
                 if time_since_last_active > next_interval:
@@ -143,14 +143,16 @@ class ProactiveThinkerExecutor:
             else "今天没有日程安排。"
         )

-        recent_messages = await message_api.get_recent_messages(stream.stream_id,limit=50,limit_mode="latest",hours=12)
+        recent_messages = await message_api.get_recent_messages(
+            stream.stream_id, limit=50, limit_mode="latest", hours=12
+        )
         recent_chat_history = (
             await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无"
         )

         action_history_list = await get_actions_by_timestamp_with_chat(
             chat_id=stream.stream_id,
-            timestamp_start=time.time() - 3600 * 24, #过去24小时
+            timestamp_start=time.time() - 3600 * 24,  # 过去24小时
             timestamp_end=time.time(),
             limit=7,
         )
@@ -195,9 +197,7 @@ class ProactiveThinkerExecutor:
             person_id = person_api.get_person_id(user_info.platform, int(user_info.user_id))
             person_info_manager = get_person_info_manager()
             person_info = await person_info_manager.get_values(person_id, ["user_id", "platform", "person_name"])
-            cross_context_block = await Prompt.build_cross_context(
-                stream.stream_id, "s4u", person_info
-            )
+            cross_context_block = await Prompt.build_cross_context(stream.stream_id, "s4u", person_info)

             # 获取关系信息
             short_impression = await person_info_manager.get_value(person_id, "short_impression") or "无"
@@ -10,8 +10,5 @@ __plugin_meta__ = PluginMetadata(
     repository_url="https://github.com/MoFox-Studio",
     keywords=["emoji", "reaction", "like", "表情", "回应", "点赞"],
     categories=["Chat", "Integration"],
-    extra={
-        "is_built_in": "true",
-        "plugin_type": "functional"
-    }
+    extra={"is_built_in": "true", "plugin_type": "functional"},
 )
@@ -548,7 +548,7 @@ class SetEmojiLikePlugin(BasePlugin):
     config_section_descriptions = {"plugin": "插件基本信息", "components": "插件组件"}

     # 配置Schema定义
-    config_schema: ClassVar[dict ]= {
+    config_schema: ClassVar[dict] = {
         "plugin": {
             "name": ConfigField(type=str, default="set_emoji_like", description="插件名称"),
             "version": ConfigField(type=str, default="1.0.0", description="插件版本"),
@@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata(
     extra={
         "is_built_in": True,
         "plugin_type": "audio_processor",
-    }
+    },
 )
@@ -12,5 +12,5 @@ __plugin_meta__ = PluginMetadata(
     categories=["Tools"],
     extra={
         "is_built_in": True,
-    }
+    },
 )
@@ -43,7 +43,9 @@ class SearXNGSearchEngine(BaseSearchEngine):

         api_keys = config_api.get_global_config("web_search.searxng_api_keys", None)
         if isinstance(api_keys, list):
-            self.api_keys: list[str | None] = [k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys]
+            self.api_keys: list[str | None] = [
+                k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys
+            ]
         else:
             self.api_keys = []

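For readers skimming the hunk above, the re-wrapped comprehension normalizes the configured key list: non-empty strings are stripped, and anything else becomes None, presumably so that each SearXNG instance can be paired with a key by position. A standalone sketch of that normalization (illustrative only, not the plugin's actual configuration handling):

def normalize_api_keys(raw: object) -> list[str | None]:
    # Non-empty strings are stripped; empty strings and non-strings map to None.
    if isinstance(raw, list):
        return [k.strip() if isinstance(k, str) and k.strip() else None for k in raw]
    return []


print(normalize_api_keys([" secret ", "", None, 42]))  # ['secret', None, None, None]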
@@ -51,9 +53,7 @@ class SearXNGSearchEngine(BaseSearchEngine):
         if self.api_keys and len(self.api_keys) < len(self.instances):
             self.api_keys.extend([None] * (len(self.instances) - len(self.api_keys)))

-        logger.debug(
-            f"SearXNG 引擎配置: instances={self.instances}, api_keys={'yes' if any(self.api_keys) else 'no'}"
-        )
+        logger.debug(f"SearXNG 引擎配置: instances={self.instances}, api_keys={'yes' if any(self.api_keys) else 'no'}")

     def is_available(self) -> bool:
         return bool(self.instances)
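Every hunk in this commit is formatting-only, so none of the rewrites should change runtime behaviour. One quick way to spot-check that for a given file is to compare the AST of the source before and after formatting; this is an illustrative check under that assumption, not part of this repository's tooling:

import ast


def same_ast(before: str, after: str) -> bool:
    # Two sources that differ only in formatting parse to identical ASTs.
    return ast.dump(ast.parse(before)) == ast.dump(ast.parse(after))


old_src = 'host = os.getenv("HOST","127.0.0.1")'
new_src = 'host = os.getenv("HOST", "127.0.0.1")'
print(same_ast(old_src, new_src))  # True: only whitespace changed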