ruff
This commit is contained in:
39
bot.py
39
bot.py
@@ -33,6 +33,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
os.chdir(script_dir)
|
||||
logger.info("工作目录已设置")
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理器"""
|
||||
|
||||
@@ -96,6 +97,7 @@ class ConfigManager:
|
||||
logger.error(f"加载环境变量失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class EULAManager:
|
||||
"""EULA管理类"""
|
||||
|
||||
@@ -134,7 +136,9 @@ class EULAManager:
|
||||
return
|
||||
|
||||
if attempts % 5 == 0:
|
||||
confirm_logger.critical(f"请修改 .env 文件中的 EULA_CONFIRMED=true (尝试 {attempts}/{MAX_EULA_CHECK_ATTEMPTS})")
|
||||
confirm_logger.critical(
|
||||
f"请修改 .env 文件中的 EULA_CONFIRMED=true (尝试 {attempts}/{MAX_EULA_CHECK_ATTEMPTS})"
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
confirm_logger.info("用户取消,程序退出")
|
||||
@@ -148,16 +152,14 @@ class EULAManager:
|
||||
confirm_logger.error("EULA确认超时,程序退出")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""任务管理器"""
|
||||
|
||||
@staticmethod
|
||||
async def cancel_pending_tasks(loop, timeout=SHUTDOWN_TIMEOUT):
|
||||
"""取消所有待处理的任务"""
|
||||
remaining_tasks = [
|
||||
t for t in asyncio.all_tasks(loop)
|
||||
if t is not asyncio.current_task(loop) and not t.done()
|
||||
]
|
||||
remaining_tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task(loop) and not t.done()]
|
||||
|
||||
if not remaining_tasks:
|
||||
logger.info("没有待取消的任务")
|
||||
@@ -171,10 +173,7 @@ class TaskManager:
|
||||
|
||||
# 等待任务完成
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*remaining_tasks, return_exceptions=True),
|
||||
timeout=timeout
|
||||
)
|
||||
results = await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=timeout)
|
||||
|
||||
# 检查任务结果
|
||||
for i, result in enumerate(results):
|
||||
@@ -195,6 +194,7 @@ class TaskManager:
|
||||
"""停止所有异步任务"""
|
||||
try:
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
|
||||
await async_task_manager.stop_and_wait_all_tasks()
|
||||
return True
|
||||
except ImportError:
|
||||
@@ -204,6 +204,7 @@ class TaskManager:
|
||||
logger.error(f"停止异步任务失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class ShutdownManager:
|
||||
"""关闭管理器"""
|
||||
|
||||
@@ -236,6 +237,7 @@ class ShutdownManager:
|
||||
logger.error(f"麦麦关闭失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_event_loop_context():
|
||||
"""创建事件循环的上下文管理器"""
|
||||
@@ -260,6 +262,7 @@ async def create_event_loop_context():
|
||||
except Exception as e:
|
||||
logger.error(f"关闭事件循环失败: {e}")
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""数据库连接管理器"""
|
||||
|
||||
@@ -278,7 +281,9 @@ class DatabaseManager:
|
||||
# 使用线程执行器运行潜在的阻塞操作
|
||||
await asyncio.to_thread(initialize_sql_database, global_config.database)
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库,耗时: {elapsed_time:.2f}秒")
|
||||
logger.info(
|
||||
f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库,耗时: {elapsed_time:.2f}秒"
|
||||
)
|
||||
|
||||
return self
|
||||
except Exception as e:
|
||||
@@ -291,6 +296,7 @@ class DatabaseManager:
|
||||
logger.error(f"数据库操作发生异常: {exc_val}")
|
||||
return False
|
||||
|
||||
|
||||
class ConfigurationValidator:
|
||||
"""配置验证器"""
|
||||
|
||||
@@ -328,6 +334,7 @@ class ConfigurationValidator:
|
||||
logger.error(f"配置验证失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class EasterEgg:
|
||||
"""彩蛋功能"""
|
||||
|
||||
@@ -347,6 +354,7 @@ class EasterEgg:
|
||||
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
|
||||
logger.info(rainbow_text)
|
||||
|
||||
|
||||
class MaiBotMain:
|
||||
"""麦麦机器人主程序类"""
|
||||
|
||||
@@ -375,6 +383,7 @@ class MaiBotMain:
|
||||
try:
|
||||
start_time = time.time()
|
||||
from src.common.database.sqlalchemy_models import initialize_database as init_db
|
||||
|
||||
await init_db()
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}秒")
|
||||
@@ -385,6 +394,7 @@ class MaiBotMain:
|
||||
def create_main_system(self):
|
||||
"""创建MainSystem实例"""
|
||||
from src.main import MainSystem
|
||||
|
||||
self.main_system = MainSystem()
|
||||
return self.main_system
|
||||
|
||||
@@ -411,11 +421,13 @@ class MaiBotMain:
|
||||
|
||||
# 初始化知识库
|
||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
||||
|
||||
initialize_lpmm_knowledge()
|
||||
|
||||
# 显示彩蛋
|
||||
EasterEgg.show()
|
||||
|
||||
|
||||
async def wait_for_user_input():
|
||||
"""等待用户输入(异步方式)"""
|
||||
try:
|
||||
@@ -432,6 +444,7 @@ async def wait_for_user_input():
|
||||
logger.error(f"等待用户输入时发生错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def main_async():
|
||||
"""主异步函数"""
|
||||
exit_code = 0
|
||||
@@ -455,10 +468,7 @@ async def main_async():
|
||||
user_input_done = asyncio.create_task(wait_for_user_input())
|
||||
|
||||
# 使用wait等待任意一个任务完成
|
||||
done, pending = await asyncio.wait(
|
||||
[main_task, user_input_done],
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
done, pending = await asyncio.wait([main_task, user_input_done], return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
# 如果用户输入任务完成(用户按了Ctrl+C),取消主任务
|
||||
if user_input_done in done and main_task not in done:
|
||||
@@ -482,6 +492,7 @@ async def main_async():
|
||||
|
||||
return exit_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = 0
|
||||
try:
|
||||
|
||||
@@ -12,10 +12,13 @@ logger = get_logger("HTTP消息API")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/messages/recent")
|
||||
async def get_message_stats(
|
||||
days: int = Query(1, ge=1, description="指定查询过去多少天的数据"),
|
||||
message_type: Literal["all", "sent", "received"] = Query("all", description="筛选消息类型: 'sent' (BOT发送的), 'received' (BOT接收的), or 'all' (全部)")
|
||||
message_type: Literal["all", "sent", "received"] = Query(
|
||||
"all", description="筛选消息类型: 'sent' (BOT发送的), 'received' (BOT接收的), or 'all' (全部)"
|
||||
),
|
||||
):
|
||||
"""
|
||||
获取BOT在指定天数内的消息统计数据。
|
||||
@@ -45,7 +48,7 @@ async def get_message_stats(
|
||||
"message_type": message_type,
|
||||
"sent_count": sent_count,
|
||||
"received_count": received_count,
|
||||
"total_count": len(messages)
|
||||
"total_count": len(messages),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -76,10 +79,7 @@ async def get_message_stats_by_chat(
|
||||
user_id = msg.get("user_id")
|
||||
|
||||
if chat_id not in stats:
|
||||
stats[chat_id] = {
|
||||
"total_stats": {"total": 0},
|
||||
"user_stats": {}
|
||||
}
|
||||
stats[chat_id] = {"total_stats": {"total": 0}, "user_stats": {}}
|
||||
|
||||
stats[chat_id]["total_stats"]["total"] += 1
|
||||
|
||||
@@ -116,10 +116,7 @@ async def get_message_stats_by_chat(
|
||||
for user_id, count in data["user_stats"].items():
|
||||
person_id = person_api.get_person_id("qq", user_id)
|
||||
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
|
||||
formatted_data["user_stats"][user_id] = {
|
||||
"nickname": nickname,
|
||||
"count": count
|
||||
}
|
||||
formatted_data["user_stats"][user_id] = {"nickname": nickname, "count": count}
|
||||
|
||||
formatted_stats[chat_id] = formatted_data
|
||||
return formatted_stats
|
||||
@@ -129,6 +126,7 @@ async def get_message_stats_by_chat(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/messages/bot_stats_by_chat")
|
||||
async def get_bot_message_stats_by_chat(
|
||||
days: int = Query(1, ge=1, description="指定查询过去多少天的数据"),
|
||||
@@ -165,10 +163,7 @@ async def get_bot_message_stats_by_chat(
|
||||
elif stream.user_info and stream.user_info.user_nickname:
|
||||
chat_name = stream.user_info.user_nickname
|
||||
|
||||
formatted_stats[chat_id] = {
|
||||
"chat_name": chat_name,
|
||||
"count": count
|
||||
}
|
||||
formatted_stats[chat_id] = {"chat_name": chat_name, "count": count}
|
||||
return formatted_stats
|
||||
|
||||
return stats
|
||||
|
||||
@@ -313,7 +313,9 @@ class EnergyManager:
|
||||
|
||||
# 确保 score 是 float 类型
|
||||
if not isinstance(score, int | float):
|
||||
logger.warning(f"计算器 {calculator.__class__.__name__} 返回了非数值类型: {type(score)},跳过此组件")
|
||||
logger.warning(
|
||||
f"计算器 {calculator.__class__.__name__} 返回了非数值类型: {type(score)},跳过此组件"
|
||||
)
|
||||
continue
|
||||
|
||||
component_scores[calculator.__class__.__name__] = float(score)
|
||||
|
||||
@@ -527,7 +527,7 @@ class ExpressionLearnerManager:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _auto_migrate_json_to_db():
|
||||
|
||||
@@ -429,7 +429,9 @@ class BotInterestManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 计算相似度分数失败: {e}")
|
||||
|
||||
async def calculate_interest_match(self, message_text: str, keywords: list[str] | None = None) -> InterestMatchResult:
|
||||
async def calculate_interest_match(
|
||||
self, message_text: str, keywords: list[str] | None = None
|
||||
) -> InterestMatchResult:
|
||||
"""计算消息与机器人兴趣的匹配度"""
|
||||
if not self.current_interests or not self._initialized:
|
||||
raise RuntimeError("❌ 兴趣标签系统未初始化")
|
||||
|
||||
@@ -79,7 +79,9 @@ class InterestManager:
|
||||
|
||||
# 如果已有组件在运行,先清理并替换
|
||||
if self._current_calculator:
|
||||
logger.info(f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}")
|
||||
logger.info(
|
||||
f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}"
|
||||
)
|
||||
await self._current_calculator.cleanup()
|
||||
else:
|
||||
logger.info(f"注册新的兴趣值计算组件: {calculator.component_name}")
|
||||
@@ -114,7 +116,7 @@ class InterestManager:
|
||||
success=False,
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
interest_value=0.3,
|
||||
error_message="没有可用的兴趣值计算组件"
|
||||
error_message="没有可用的兴趣值计算组件",
|
||||
)
|
||||
|
||||
# 使用 create_task 异步执行计算
|
||||
@@ -133,7 +135,7 @@ class InterestManager:
|
||||
interest_value=0.5, # 固定默认兴趣值
|
||||
should_reply=False,
|
||||
should_act=False,
|
||||
error_message=f"计算超时({timeout}s),使用默认值"
|
||||
error_message=f"计算超时({timeout}s),使用默认值",
|
||||
)
|
||||
except Exception as e:
|
||||
# 发生异常,返回默认结果
|
||||
@@ -142,7 +144,7 @@ class InterestManager:
|
||||
success=False,
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
interest_value=0.3,
|
||||
error_message=f"计算异常: {e!s}"
|
||||
error_message=f"计算异常: {e!s}",
|
||||
)
|
||||
|
||||
async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
|
||||
@@ -171,7 +173,7 @@ class InterestManager:
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
interest_value=0.0,
|
||||
error_message=f"计算异常: {e!s}",
|
||||
calculation_time=time.time() - start_time
|
||||
calculation_time=time.time() - start_time,
|
||||
)
|
||||
|
||||
async def _calculation_worker(self):
|
||||
@@ -179,10 +181,7 @@ class InterestManager:
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
# 等待计算任务或关闭信号
|
||||
await asyncio.wait_for(
|
||||
self._calculation_queue.get(),
|
||||
timeout=1.0
|
||||
)
|
||||
await asyncio.wait_for(self._calculation_queue.get(), timeout=1.0)
|
||||
|
||||
# 处理计算任务
|
||||
# 这里可以实现批量处理逻辑
|
||||
@@ -210,7 +209,7 @@ class InterestManager:
|
||||
"failed_calculations": self._failed_calculations,
|
||||
"success_rate": success_rate,
|
||||
"last_calculation_time": self._last_calculation_time,
|
||||
"current_calculator": self._current_calculator.component_name if self._current_calculator else None
|
||||
"current_calculator": self._current_calculator.component_name if self._current_calculator else None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -125,19 +125,19 @@ class OpenIE:
|
||||
def extract_entity_dict(self):
|
||||
"""提取实体列表"""
|
||||
ner_output_dict = {
|
||||
doc_item["idx"]: doc_item["extracted_entities"]
|
||||
for doc_item in self.docs
|
||||
if len(doc_item["extracted_entities"]) > 0
|
||||
}
|
||||
doc_item["idx"]: doc_item["extracted_entities"]
|
||||
for doc_item in self.docs
|
||||
if len(doc_item["extracted_entities"]) > 0
|
||||
}
|
||||
return ner_output_dict
|
||||
|
||||
def extract_triple_dict(self):
|
||||
"""提取三元组列表"""
|
||||
triple_output_dict = {
|
||||
doc_item["idx"]: doc_item["extracted_triples"]
|
||||
for doc_item in self.docs
|
||||
if len(doc_item["extracted_triples"]) > 0
|
||||
}
|
||||
doc_item["idx"]: doc_item["extracted_triples"]
|
||||
for doc_item in self.docs
|
||||
if len(doc_item["extracted_triples"]) > 0
|
||||
}
|
||||
return triple_output_dict
|
||||
|
||||
def extract_raw_paragraph_dict(self):
|
||||
|
||||
@@ -19,10 +19,10 @@ def dyn_select_top_k(
|
||||
for score_item in sorted_score:
|
||||
normalized_score.append(
|
||||
(
|
||||
score_item[0],
|
||||
score_item[1],
|
||||
(score_item[1] - min_score) / (max_score - min_score),
|
||||
)
|
||||
score_item[0],
|
||||
score_item[1],
|
||||
(score_item[1] - min_score) / (max_score - min_score),
|
||||
)
|
||||
)
|
||||
|
||||
# 寻找跳变点:score变化最大的位置
|
||||
|
||||
@@ -29,20 +29,21 @@ logger = get_logger(__name__)
|
||||
@dataclass
|
||||
class HippocampusSampleConfig:
|
||||
"""海马体采样配置"""
|
||||
|
||||
# 双峰分布参数
|
||||
recent_mean_hours: float = 12.0 # 近期分布均值(小时)
|
||||
recent_std_hours: float = 8.0 # 近期分布标准差(小时)
|
||||
recent_weight: float = 0.7 # 近期分布权重
|
||||
recent_std_hours: float = 8.0 # 近期分布标准差(小时)
|
||||
recent_weight: float = 0.7 # 近期分布权重
|
||||
|
||||
distant_mean_hours: float = 48.0 # 远期分布均值(小时)
|
||||
distant_std_hours: float = 24.0 # 远期分布标准差(小时)
|
||||
distant_weight: float = 0.3 # 远期分布权重
|
||||
distant_std_hours: float = 24.0 # 远期分布标准差(小时)
|
||||
distant_weight: float = 0.3 # 远期分布权重
|
||||
|
||||
# 采样参数
|
||||
total_samples: int = 50 # 总采样数
|
||||
sample_interval: int = 1800 # 采样间隔(秒)
|
||||
max_sample_length: int = 30 # 每次采样的最大消息数量
|
||||
batch_size: int = 5 # 批处理大小
|
||||
total_samples: int = 50 # 总采样数
|
||||
sample_interval: int = 1800 # 采样间隔(秒)
|
||||
max_sample_length: int = 30 # 每次采样的最大消息数量
|
||||
batch_size: int = 5 # 批处理大小
|
||||
|
||||
@classmethod
|
||||
def from_global_config(cls) -> "HippocampusSampleConfig":
|
||||
@@ -84,12 +85,10 @@ class HippocampusSampler:
|
||||
try:
|
||||
# 初始化LLM模型
|
||||
from src.config.config import model_config
|
||||
|
||||
task_config = getattr(model_config.model_task_config, "utils", None)
|
||||
if task_config:
|
||||
self.memory_builder_model = LLMRequest(
|
||||
model_set=task_config,
|
||||
request_type="memory.hippocampus_build"
|
||||
)
|
||||
self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
|
||||
asyncio.create_task(self.start_background_sampling())
|
||||
logger.info("✅ 海马体采样器初始化成功")
|
||||
else:
|
||||
@@ -107,14 +106,10 @@ class HippocampusSampler:
|
||||
|
||||
# 生成两个正态分布的小时偏移
|
||||
recent_offsets = np.random.normal(
|
||||
loc=self.config.recent_mean_hours,
|
||||
scale=self.config.recent_std_hours,
|
||||
size=recent_samples
|
||||
loc=self.config.recent_mean_hours, scale=self.config.recent_std_hours, size=recent_samples
|
||||
)
|
||||
distant_offsets = np.random.normal(
|
||||
loc=self.config.distant_mean_hours,
|
||||
scale=self.config.distant_std_hours,
|
||||
size=distant_samples
|
||||
loc=self.config.distant_mean_hours, scale=self.config.distant_std_hours, size=distant_samples
|
||||
)
|
||||
|
||||
# 合并两个分布的偏移
|
||||
@@ -122,10 +117,7 @@ class HippocampusSampler:
|
||||
|
||||
# 转换为时间戳(使用绝对值确保时间点在过去)
|
||||
base_time = datetime.now()
|
||||
timestamps = [
|
||||
base_time - timedelta(hours=abs(offset))
|
||||
for offset in all_offsets
|
||||
]
|
||||
timestamps = [base_time - timedelta(hours=abs(offset)) for offset in all_offsets]
|
||||
|
||||
# 按时间排序(从最早到最近)
|
||||
return sorted(timestamps)
|
||||
@@ -171,7 +163,8 @@ class HippocampusSampler:
|
||||
if messages and len(messages) >= 2: # 至少需要2条消息
|
||||
# 过滤掉已经记忆过的消息
|
||||
filtered_messages = [
|
||||
msg for msg in messages
|
||||
msg
|
||||
for msg in messages
|
||||
if msg.get("memorized_times", 0) < 2 # 最多记忆2次
|
||||
]
|
||||
|
||||
@@ -229,7 +222,7 @@ class HippocampusSampler:
|
||||
conversation_text=input_text,
|
||||
context=context,
|
||||
timestamp=time.time(),
|
||||
bypass_interval=True # 海马体采样器绕过构建间隔限制
|
||||
bypass_interval=True, # 海马体采样器绕过构建间隔限制
|
||||
)
|
||||
|
||||
if memories:
|
||||
@@ -367,7 +360,7 @@ class HippocampusSampler:
|
||||
max_concurrent = min(5, len(time_samples)) # 提高并发数到5
|
||||
|
||||
for i in range(0, len(time_samples), max_concurrent):
|
||||
batch = time_samples[i:i + max_concurrent]
|
||||
batch = time_samples[i : i + max_concurrent]
|
||||
tasks = []
|
||||
|
||||
# 创建并发收集任务
|
||||
@@ -392,7 +385,9 @@ class HippocampusSampler:
|
||||
|
||||
return collected_messages
|
||||
|
||||
async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
|
||||
async def _fuse_and_deduplicate_messages(
|
||||
self, collected_messages: list[list[dict[str, Any]]]
|
||||
) -> list[list[dict[str, Any]]]:
|
||||
"""融合和去重消息样本"""
|
||||
if not collected_messages:
|
||||
return []
|
||||
@@ -416,7 +411,7 @@ class HippocampusSampler:
|
||||
chat_id = message.get("chat_id", "")
|
||||
|
||||
# 简单哈希:内容前50字符 + 时间戳(精确到分钟) + 聊天ID
|
||||
hash_key = f"{content[:50]}_{int(timestamp//60)}_{chat_id}"
|
||||
hash_key = f"{content[:50]}_{int(timestamp // 60)}_{chat_id}"
|
||||
|
||||
if hash_key not in seen_hashes and len(content.strip()) > 10:
|
||||
seen_hashes.add(hash_key)
|
||||
@@ -448,7 +443,9 @@ class HippocampusSampler:
|
||||
# 返回原始消息组作为备选
|
||||
return collected_messages[:5] # 限制返回数量
|
||||
|
||||
def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
|
||||
def _merge_adjacent_messages(
|
||||
self, messages: list[dict[str, Any]], time_gap: int = 1800
|
||||
) -> list[list[dict[str, Any]]]:
|
||||
"""合并时间间隔内的消息"""
|
||||
if not messages:
|
||||
return []
|
||||
@@ -479,7 +476,9 @@ class HippocampusSampler:
|
||||
|
||||
return result_groups
|
||||
|
||||
async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]:
|
||||
async def _build_batch_memory(
|
||||
self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]
|
||||
) -> dict[str, Any]:
|
||||
"""批量构建记忆"""
|
||||
if not fused_messages:
|
||||
return {"memory_count": 0, "memories": []}
|
||||
@@ -513,10 +512,7 @@ class HippocampusSampler:
|
||||
|
||||
# 一次性构建记忆
|
||||
memories = await self.memory_system.build_memory_from_conversation(
|
||||
conversation_text=batch_input_text,
|
||||
context=batch_context,
|
||||
timestamp=time.time(),
|
||||
bypass_interval=True
|
||||
conversation_text=batch_input_text, context=batch_context, timestamp=time.time(), bypass_interval=True
|
||||
)
|
||||
|
||||
if memories:
|
||||
@@ -545,11 +541,7 @@ class HippocampusSampler:
|
||||
if len(self.last_sample_results) > 10:
|
||||
self.last_sample_results.pop(0)
|
||||
|
||||
return {
|
||||
"memory_count": total_memory_count,
|
||||
"memories": total_memories,
|
||||
"result": result
|
||||
}
|
||||
return {"memory_count": total_memory_count, "memories": total_memories, "result": result}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量构建记忆失败: {e}")
|
||||
@@ -601,11 +593,7 @@ class HippocampusSampler:
|
||||
except Exception as e:
|
||||
logger.debug(f"单独构建失败: {e}")
|
||||
|
||||
return {
|
||||
"memory_count": total_count,
|
||||
"memories": total_memories,
|
||||
"fallback_mode": True
|
||||
}
|
||||
return {"memory_count": total_count, "memories": total_memories, "fallback_mode": True}
|
||||
|
||||
async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
|
||||
"""处理单个时间戳采样(保留作为备选方法)"""
|
||||
@@ -696,7 +684,9 @@ class HippocampusSampler:
|
||||
"performance_metrics": {
|
||||
"avg_messages_per_sample": f"{recent_avg_messages:.1f}",
|
||||
"avg_memories_per_sample": f"{recent_avg_memory_count:.1f}",
|
||||
"fusion_efficiency": f"{(recent_avg_messages/max(recent_avg_memory_count, 1)):.1f}x" if recent_avg_messages > 0 else "N/A"
|
||||
"fusion_efficiency": f"{(recent_avg_messages / max(recent_avg_memory_count, 1)):.1f}x"
|
||||
if recent_avg_messages > 0
|
||||
else "N/A",
|
||||
},
|
||||
"config": {
|
||||
"sample_interval": self.config.sample_interval,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
@@ -24,9 +24,12 @@ from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||
# 记忆采样模式枚举
|
||||
class MemorySamplingMode(Enum):
|
||||
"""记忆采样模式"""
|
||||
|
||||
HIPPOCAMPUS = "hippocampus" # 海马体模式:定时任务采样
|
||||
IMMEDIATE = "immediate" # 即时模式:回复后立即采样
|
||||
ALL = "all" # 所有模式:同时使用海马体和即时采样
|
||||
IMMEDIATE = "immediate" # 即时模式:回复后立即采样
|
||||
ALL = "all" # 所有模式:同时使用海马体和即时采样
|
||||
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -165,7 +168,6 @@ class MemorySystem:
|
||||
async def initialize(self):
|
||||
"""异步初始化记忆系统"""
|
||||
try:
|
||||
|
||||
# 初始化LLM模型
|
||||
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
|
||||
|
||||
@@ -264,6 +266,7 @@ class MemorySystem:
|
||||
if global_config.memory.enable_hippocampus_sampling:
|
||||
try:
|
||||
from .hippocampus_sampler import initialize_hippocampus_sampler
|
||||
|
||||
self.hippocampus_sampler = await initialize_hippocampus_sampler(self)
|
||||
logger.info("✅ 海马体采样器初始化成功")
|
||||
except Exception as e:
|
||||
@@ -321,7 +324,11 @@ class MemorySystem:
|
||||
return []
|
||||
|
||||
async def build_memory_from_conversation(
|
||||
self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None, bypass_interval: bool = False
|
||||
self,
|
||||
conversation_text: str,
|
||||
context: dict[str, Any],
|
||||
timestamp: float | None = None,
|
||||
bypass_interval: bool = False,
|
||||
) -> list[MemoryChunk]:
|
||||
"""从对话中构建记忆
|
||||
|
||||
@@ -560,7 +567,6 @@ class MemorySystem:
|
||||
sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision")
|
||||
current_mode = MemorySamplingMode(sampling_mode)
|
||||
|
||||
|
||||
context["__sampling_mode"] = current_mode.value
|
||||
logger.debug(f"使用记忆采样模式: {current_mode.value}")
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ logger = get_logger("adaptive_stream_manager")
|
||||
|
||||
class StreamPriority(Enum):
|
||||
"""流优先级"""
|
||||
|
||||
LOW = 1
|
||||
NORMAL = 2
|
||||
HIGH = 3
|
||||
@@ -26,6 +27,7 @@ class StreamPriority(Enum):
|
||||
@dataclass
|
||||
class SystemMetrics:
|
||||
"""系统指标"""
|
||||
|
||||
cpu_usage: float = 0.0
|
||||
memory_usage: float = 0.0
|
||||
active_coroutines: int = 0
|
||||
@@ -36,6 +38,7 @@ class SystemMetrics:
|
||||
@dataclass
|
||||
class StreamMetrics:
|
||||
"""流指标"""
|
||||
|
||||
stream_id: str
|
||||
priority: StreamPriority
|
||||
message_rate: float = 0.0 # 消息速率(消息/分钟)
|
||||
@@ -56,7 +59,7 @@ class AdaptiveStreamManager:
|
||||
metrics_window: float = 60.0, # 指标窗口时间
|
||||
adjustment_interval: float = 30.0, # 调整间隔
|
||||
cpu_threshold_high: float = 0.8, # CPU高负载阈值
|
||||
cpu_threshold_low: float = 0.3, # CPU低负载阈值
|
||||
cpu_threshold_low: float = 0.3, # CPU低负载阈值
|
||||
memory_threshold_high: float = 0.85, # 内存高负载阈值
|
||||
):
|
||||
self.base_concurrent_limit = base_concurrent_limit
|
||||
@@ -139,10 +142,7 @@ class AdaptiveStreamManager:
|
||||
logger.info("自适应流管理器已停止")
|
||||
|
||||
async def acquire_stream_slot(
|
||||
self,
|
||||
stream_id: str,
|
||||
priority: StreamPriority = StreamPriority.NORMAL,
|
||||
force: bool = False
|
||||
self, stream_id: str, priority: StreamPriority = StreamPriority.NORMAL, force: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
获取流处理槽位
|
||||
@@ -165,10 +165,7 @@ class AdaptiveStreamManager:
|
||||
|
||||
# 更新流指标
|
||||
if stream_id not in self.stream_metrics:
|
||||
self.stream_metrics[stream_id] = StreamMetrics(
|
||||
stream_id=stream_id,
|
||||
priority=priority
|
||||
)
|
||||
self.stream_metrics[stream_id] = StreamMetrics(stream_id=stream_id, priority=priority)
|
||||
self.stream_metrics[stream_id].last_activity = current_time
|
||||
|
||||
# 检查是否已经活跃
|
||||
@@ -271,8 +268,10 @@ class AdaptiveStreamManager:
|
||||
|
||||
# 如果最近有活跃且响应时间较长,可能需要强制分发
|
||||
current_time = time.time()
|
||||
if (current_time - metrics.last_activity < 300 and # 5分钟内有活动
|
||||
metrics.response_time > 5.0): # 响应时间超过5秒
|
||||
if (
|
||||
current_time - metrics.last_activity < 300 # 5分钟内有活动
|
||||
and metrics.response_time > 5.0
|
||||
): # 响应时间超过5秒
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -324,26 +323,20 @@ class AdaptiveStreamManager:
|
||||
memory_usage=memory_usage,
|
||||
active_coroutines=active_coroutines,
|
||||
event_loop_lag=event_loop_lag,
|
||||
timestamp=time.time()
|
||||
timestamp=time.time(),
|
||||
)
|
||||
|
||||
self.system_metrics.append(metrics)
|
||||
|
||||
# 保持指标窗口大小
|
||||
cutoff_time = time.time() - self.metrics_window
|
||||
self.system_metrics = [
|
||||
m for m in self.system_metrics
|
||||
if m.timestamp > cutoff_time
|
||||
]
|
||||
self.system_metrics = [m for m in self.system_metrics if m.timestamp > cutoff_time]
|
||||
|
||||
# 更新统计信息
|
||||
self.stats["avg_concurrent_streams"] = (
|
||||
self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1
|
||||
)
|
||||
self.stats["peak_concurrent_streams"] = max(
|
||||
self.stats["peak_concurrent_streams"],
|
||||
len(self.active_streams)
|
||||
)
|
||||
self.stats["peak_concurrent_streams"] = max(self.stats["peak_concurrent_streams"], len(self.active_streams))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"收集系统指标失败: {e}")
|
||||
@@ -445,14 +438,16 @@ class AdaptiveStreamManager:
|
||||
def get_stats(self) -> dict:
|
||||
"""获取统计信息"""
|
||||
stats = self.stats.copy()
|
||||
stats.update({
|
||||
"current_limit": self.current_limit,
|
||||
"active_streams": len(self.active_streams),
|
||||
"pending_streams": len(self.pending_streams),
|
||||
"is_running": self.is_running,
|
||||
"system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
|
||||
"system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
|
||||
})
|
||||
stats.update(
|
||||
{
|
||||
"current_limit": self.current_limit,
|
||||
"active_streams": len(self.active_streams),
|
||||
"pending_streams": len(self.pending_streams),
|
||||
"is_running": self.is_running,
|
||||
"system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
|
||||
"system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
|
||||
}
|
||||
)
|
||||
|
||||
# 计算接受率
|
||||
if stats["total_requests"] > 0:
|
||||
|
||||
@@ -20,6 +20,7 @@ logger = get_logger("batch_database_writer")
|
||||
@dataclass
|
||||
class StreamUpdatePayload:
|
||||
"""流更新数据结构"""
|
||||
|
||||
stream_id: str
|
||||
update_data: dict[str, Any]
|
||||
priority: int = 0 # 优先级,数字越大优先级越高
|
||||
@@ -95,12 +96,7 @@ class BatchDatabaseWriter:
|
||||
|
||||
logger.info("批量数据库写入器已停止")
|
||||
|
||||
async def schedule_stream_update(
|
||||
self,
|
||||
stream_id: str,
|
||||
update_data: dict[str, Any],
|
||||
priority: int = 0
|
||||
) -> bool:
|
||||
async def schedule_stream_update(self, stream_id: str, update_data: dict[str, Any], priority: int = 0) -> bool:
|
||||
"""
|
||||
调度流更新
|
||||
|
||||
@@ -119,11 +115,7 @@ class BatchDatabaseWriter:
|
||||
return True
|
||||
|
||||
# 创建更新载荷
|
||||
payload = StreamUpdatePayload(
|
||||
stream_id=stream_id,
|
||||
update_data=update_data,
|
||||
priority=priority
|
||||
)
|
||||
payload = StreamUpdatePayload(stream_id=stream_id, update_data=update_data, priority=priority)
|
||||
|
||||
# 非阻塞方式加入队列
|
||||
try:
|
||||
@@ -178,10 +170,7 @@ class BatchDatabaseWriter:
|
||||
if remaining_time == 0:
|
||||
break
|
||||
|
||||
payload = await asyncio.wait_for(
|
||||
self.write_queue.get(),
|
||||
timeout=remaining_time
|
||||
)
|
||||
payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time)
|
||||
batch.append(payload)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
@@ -203,7 +192,10 @@ class BatchDatabaseWriter:
|
||||
# 合并同一流ID的更新(保留最新的)
|
||||
merged_updates = {}
|
||||
for payload in batch:
|
||||
if payload.stream_id not in merged_updates or payload.timestamp > merged_updates[payload.stream_id].timestamp:
|
||||
if (
|
||||
payload.stream_id not in merged_updates
|
||||
or payload.timestamp > merged_updates[payload.stream_id].timestamp
|
||||
):
|
||||
merged_updates[payload.stream_id] = payload
|
||||
|
||||
# 批量写入
|
||||
@@ -211,9 +203,7 @@ class BatchDatabaseWriter:
|
||||
|
||||
# 更新统计
|
||||
self.stats["batch_writes"] += 1
|
||||
self.stats["avg_batch_size"] = (
|
||||
self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1
|
||||
) # 滑动平均
|
||||
self.stats["avg_batch_size"] = self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1 # 滑动平均
|
||||
self.stats["last_flush_time"] = start_time
|
||||
|
||||
logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")
|
||||
@@ -238,31 +228,22 @@ class BatchDatabaseWriter:
|
||||
# 根据数据库类型选择不同的插入/更新策略
|
||||
if global_config.database.database_type == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
stmt = sqlite_insert(ChatStreams).values(
|
||||
stream_id=stream_id, **update_data
|
||||
)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["stream_id"],
|
||||
set_=update_data
|
||||
)
|
||||
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
|
||||
elif global_config.database.database_type == "mysql":
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
stmt = mysql_insert(ChatStreams).values(
|
||||
stream_id=stream_id, **update_data
|
||||
)
|
||||
|
||||
stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
||||
stmt = stmt.on_duplicate_key_update(
|
||||
**{key: value for key, value in update_data.items() if key != "stream_id"}
|
||||
)
|
||||
else:
|
||||
# 默认使用SQLite语法
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
stmt = sqlite_insert(ChatStreams).values(
|
||||
stream_id=stream_id, **update_data
|
||||
)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["stream_id"],
|
||||
set_=update_data
|
||||
)
|
||||
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
|
||||
|
||||
await session.execute(stmt)
|
||||
|
||||
@@ -273,30 +254,21 @@ class BatchDatabaseWriter:
|
||||
async with get_db_session() as session:
|
||||
if global_config.database.database_type == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
stmt = sqlite_insert(ChatStreams).values(
|
||||
stream_id=stream_id, **update_data
|
||||
)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["stream_id"],
|
||||
set_=update_data
|
||||
)
|
||||
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
|
||||
elif global_config.database.database_type == "mysql":
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
stmt = mysql_insert(ChatStreams).values(
|
||||
stream_id=stream_id, **update_data
|
||||
)
|
||||
|
||||
stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
||||
stmt = stmt.on_duplicate_key_update(
|
||||
**{key: value for key, value in update_data.items() if key != "stream_id"}
|
||||
)
|
||||
else:
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
stmt = sqlite_insert(ChatStreams).values(
|
||||
stream_id=stream_id, **update_data
|
||||
)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["stream_id"],
|
||||
set_=update_data
|
||||
)
|
||||
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
|
||||
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@@ -273,8 +273,10 @@ class SingleStreamContextManager:
|
||||
message.should_reply = result.should_reply
|
||||
message.should_act = result.should_act
|
||||
|
||||
logger.debug(f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
|
||||
f"should_reply: {result.should_reply}, should_act: {result.should_act}")
|
||||
logger.debug(
|
||||
f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
|
||||
f"should_reply: {result.should_reply}, should_act: {result.should_act}"
|
||||
)
|
||||
return result.interest_value
|
||||
else:
|
||||
logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
|
||||
|
||||
@@ -79,7 +79,7 @@ class StreamLoopManager:
|
||||
logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...")
|
||||
await asyncio.gather(
|
||||
*[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks],
|
||||
return_exceptions=True
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# 取消所有活跃的 chatter 处理任务
|
||||
@@ -115,6 +115,7 @@ class StreamLoopManager:
|
||||
# 使用自适应流管理器获取槽位
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
|
||||
if adaptive_manager.is_running:
|
||||
@@ -123,9 +124,7 @@ class StreamLoopManager:
|
||||
|
||||
# 获取处理槽位
|
||||
slot_acquired = await adaptive_manager.acquire_stream_slot(
|
||||
stream_id=stream_id,
|
||||
priority=priority,
|
||||
force=force
|
||||
stream_id=stream_id, priority=priority, force=force
|
||||
)
|
||||
|
||||
if slot_acquired:
|
||||
@@ -140,10 +139,7 @@ class StreamLoopManager:
|
||||
|
||||
# 创建流循环任务
|
||||
try:
|
||||
loop_task = asyncio.create_task(
|
||||
self._stream_loop_worker(stream_id),
|
||||
name=f"stream_loop_{stream_id}"
|
||||
)
|
||||
loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
|
||||
self.stream_loops[stream_id] = loop_task
|
||||
# 更新统计信息
|
||||
self.stats["active_streams"] += 1
|
||||
@@ -156,6 +152,7 @@ class StreamLoopManager:
|
||||
logger.error(f"启动流循环任务失败 {stream_id}: {e}")
|
||||
# 释放槽位
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.release_stream_slot(stream_id)
|
||||
|
||||
@@ -179,8 +176,8 @@ class StreamLoopManager:
|
||||
|
||||
except Exception:
|
||||
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
|
||||
return StreamPriority.NORMAL
|
||||
|
||||
return StreamPriority.NORMAL
|
||||
|
||||
async def stop_stream_loop(self, stream_id: str) -> bool:
|
||||
"""停止指定流的循环任务
|
||||
@@ -244,11 +241,12 @@ class StreamLoopManager:
|
||||
# 3. 更新自适应管理器指标
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.update_stream_metrics(
|
||||
stream_id,
|
||||
message_rate=unread_count / 5.0 if unread_count > 0 else 0.0, # 简化计算
|
||||
last_activity=time.time()
|
||||
last_activity=time.time(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"更新流指标失败: {e}")
|
||||
@@ -300,6 +298,7 @@ class StreamLoopManager:
|
||||
# 释放自适应管理器的槽位
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||
|
||||
adaptive_manager = get_adaptive_stream_manager()
|
||||
adaptive_manager.release_stream_slot(stream_id)
|
||||
logger.debug(f"释放自适应流处理槽位: {stream_id}")
|
||||
@@ -553,12 +552,12 @@ class StreamLoopManager:
|
||||
existing_task.cancel()
|
||||
# 创建异步任务来等待取消完成,并添加异常处理
|
||||
cancel_task = asyncio.create_task(
|
||||
self._wait_for_task_cancel(stream_id, existing_task),
|
||||
name=f"cancel_existing_loop_{stream_id}"
|
||||
self._wait_for_task_cancel(stream_id, existing_task), name=f"cancel_existing_loop_{stream_id}"
|
||||
)
|
||||
# 为取消任务添加异常处理,避免孤儿任务
|
||||
cancel_task.add_done_callback(
|
||||
lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception()
|
||||
lambda task: logger.debug(f"取消任务完成: {stream_id}")
|
||||
if not task.exception()
|
||||
else logger.error(f"取消任务异常: {stream_id} - {task.exception()}")
|
||||
)
|
||||
# 从字典中移除
|
||||
@@ -582,10 +581,7 @@ class StreamLoopManager:
|
||||
logger.info(f"流 {stream_id} 当前未读消息数: {unread_count}")
|
||||
|
||||
# 创建新的流循环任务
|
||||
new_task = asyncio.create_task(
|
||||
self._stream_loop(stream_id),
|
||||
name=f"force_stream_loop_{stream_id}"
|
||||
)
|
||||
new_task = asyncio.create_task(self._stream_loop(stream_id), name=f"force_stream_loop_{stream_id}")
|
||||
self.stream_loops[stream_id] = new_task
|
||||
self.stats["total_loops"] += 1
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ class MessageManager:
|
||||
# 启动批量数据库写入器
|
||||
try:
|
||||
from src.chat.message_manager.batch_database_writer import init_batch_writer
|
||||
|
||||
await init_batch_writer()
|
||||
except Exception as e:
|
||||
logger.error(f"启动批量数据库写入器失败: {e}")
|
||||
@@ -66,6 +67,7 @@ class MessageManager:
|
||||
# 启动流缓存管理器
|
||||
try:
|
||||
from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
|
||||
|
||||
await init_stream_cache_manager()
|
||||
except Exception as e:
|
||||
logger.error(f"启动流缓存管理器失败: {e}")
|
||||
@@ -73,6 +75,7 @@ class MessageManager:
|
||||
# 启动自适应流管理器
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
|
||||
|
||||
await init_adaptive_stream_manager()
|
||||
logger.info("🎯 自适应流管理器已启动")
|
||||
except Exception as e:
|
||||
@@ -97,6 +100,7 @@ class MessageManager:
|
||||
# 停止批量数据库写入器
|
||||
try:
|
||||
from src.chat.message_manager.batch_database_writer import shutdown_batch_writer
|
||||
|
||||
await shutdown_batch_writer()
|
||||
logger.info("📦 批量数据库写入器已停止")
|
||||
except Exception as e:
|
||||
@@ -105,6 +109,7 @@ class MessageManager:
|
||||
# 停止流缓存管理器
|
||||
try:
|
||||
from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
|
||||
|
||||
await shutdown_stream_cache_manager()
|
||||
logger.info("🗄️ 流缓存管理器已停止")
|
||||
except Exception as e:
|
||||
@@ -113,6 +118,7 @@ class MessageManager:
|
||||
# 停止自适应流管理器
|
||||
try:
|
||||
from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
|
||||
|
||||
await shutdown_adaptive_stream_manager()
|
||||
logger.info("🎯 自适应流管理器已停止")
|
||||
except Exception as e:
|
||||
|
||||
@@ -19,6 +19,7 @@ logger = get_logger("stream_cache_manager")
|
||||
@dataclass
|
||||
class StreamCacheStats:
|
||||
"""缓存统计信息"""
|
||||
|
||||
hot_cache_size: int = 0
|
||||
warm_storage_size: int = 0
|
||||
cold_storage_size: int = 0
|
||||
@@ -38,9 +39,9 @@ class TieredStreamCache:
|
||||
max_warm_size: int = 500,
|
||||
max_cold_size: int = 2000,
|
||||
cleanup_interval: float = 300.0, # 5分钟清理一次
|
||||
hot_timeout: float = 1800.0, # 30分钟未访问降级到warm
|
||||
warm_timeout: float = 7200.0, # 2小时未访问降级到cold
|
||||
cold_timeout: float = 86400.0, # 24小时未访问删除
|
||||
hot_timeout: float = 1800.0, # 30分钟未访问降级到warm
|
||||
warm_timeout: float = 7200.0, # 2小时未访问降级到cold
|
||||
cold_timeout: float = 86400.0, # 24小时未访问删除
|
||||
):
|
||||
self.max_hot_size = max_hot_size
|
||||
self.max_warm_size = max_warm_size
|
||||
@@ -52,8 +53,8 @@ class TieredStreamCache:
|
||||
|
||||
# 三层缓存存储
|
||||
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU)
|
||||
self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
|
||||
self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
|
||||
self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
|
||||
self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
|
||||
|
||||
# 统计信息
|
||||
self.stats = StreamCacheStats()
|
||||
@@ -134,11 +135,7 @@ class TieredStreamCache:
|
||||
# 4. 缓存未命中,创建新流
|
||||
self.stats.cache_misses += 1
|
||||
stream = create_optimized_chat_stream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
data=data
|
||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
|
||||
)
|
||||
logger.debug(f"缓存未命中,创建新流: {stream_id}")
|
||||
|
||||
@@ -294,9 +291,9 @@ class TieredStreamCache:
|
||||
|
||||
# 估算内存使用(粗略估计)
|
||||
self.stats.total_memory_usage = (
|
||||
len(self.hot_cache) * 1024 + # 每个热流约1KB
|
||||
len(self.warm_storage) * 512 + # 每个温流约512B
|
||||
len(self.cold_storage) * 256 # 每个冷流约256B
|
||||
len(self.hot_cache) * 1024 # 每个热流约1KB
|
||||
+ len(self.warm_storage) * 512 # 每个温流约512B
|
||||
+ len(self.cold_storage) * 256 # 每个冷流约256B
|
||||
)
|
||||
|
||||
if sum(cleanup_stats.values()) > 0:
|
||||
|
||||
@@ -557,7 +557,11 @@ class ChatBot:
|
||||
|
||||
# 将兴趣度结果同步回原始消息,便于后续流程使用
|
||||
message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0))
|
||||
setattr(message, "should_reply", getattr(db_message, "should_reply", getattr(message, "should_reply", False)))
|
||||
setattr(
|
||||
message,
|
||||
"should_reply",
|
||||
getattr(db_message, "should_reply", getattr(message, "should_reply", False)),
|
||||
)
|
||||
setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False)))
|
||||
|
||||
# 存储消息到数据库,只进行一次写入
|
||||
|
||||
@@ -298,8 +298,10 @@ class ChatStream:
|
||||
db_message.should_reply = result.should_reply
|
||||
db_message.should_act = result.should_act
|
||||
|
||||
logger.debug(f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
|
||||
f"should_reply: {result.should_reply}, should_act: {result.should_act}")
|
||||
logger.debug(
|
||||
f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
|
||||
f"should_reply: {result.should_reply}, should_act: {result.should_act}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
|
||||
# 使用默认值
|
||||
@@ -521,18 +523,17 @@ class ChatManager:
|
||||
# 优先使用缓存管理器(优化版本)
|
||||
try:
|
||||
from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
|
||||
|
||||
cache_manager = get_stream_cache_manager()
|
||||
|
||||
if cache_manager.is_running:
|
||||
optimized_stream = await cache_manager.get_or_create_stream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info
|
||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
|
||||
)
|
||||
|
||||
# 设置消息上下文
|
||||
from .message import MessageRecv
|
||||
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
||||
optimized_stream.set_context(self.last_messages[stream_id])
|
||||
|
||||
@@ -715,7 +716,7 @@ class ChatManager:
|
||||
success = await batch_writer.schedule_stream_update(
|
||||
stream_id=stream_data_dict["stream_id"],
|
||||
update_data=ChatManager._prepare_stream_data(stream_data_dict),
|
||||
priority=1 # 流更新的优先级
|
||||
priority=1, # 流更新的优先级
|
||||
)
|
||||
if success:
|
||||
stream.saved = True
|
||||
@@ -738,7 +739,7 @@ class ChatManager:
|
||||
result = await batch_update(
|
||||
model_class=ChatStreams,
|
||||
conditions={"stream_id": stream_data_dict["stream_id"]},
|
||||
data=ChatManager._prepare_stream_data(stream_data_dict)
|
||||
data=ChatManager._prepare_stream_data(stream_data_dict),
|
||||
)
|
||||
if result and result > 0:
|
||||
stream.saved = True
|
||||
@@ -874,43 +875,43 @@ chat_manager = None
|
||||
|
||||
|
||||
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
|
||||
"""将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
|
||||
try:
|
||||
# 创建原始ChatStream实例
|
||||
original_stream = ChatStream(
|
||||
stream_id=optimized_stream.stream_id,
|
||||
platform=optimized_stream.platform,
|
||||
user_info=optimized_stream._get_effective_user_info(),
|
||||
group_info=optimized_stream._get_effective_group_info()
|
||||
)
|
||||
"""将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
|
||||
try:
|
||||
# 创建原始ChatStream实例
|
||||
original_stream = ChatStream(
|
||||
stream_id=optimized_stream.stream_id,
|
||||
platform=optimized_stream.platform,
|
||||
user_info=optimized_stream._get_effective_user_info(),
|
||||
group_info=optimized_stream._get_effective_group_info(),
|
||||
)
|
||||
|
||||
# 复制状态
|
||||
original_stream.create_time = optimized_stream.create_time
|
||||
original_stream.last_active_time = optimized_stream.last_active_time
|
||||
original_stream.sleep_pressure = optimized_stream.sleep_pressure
|
||||
original_stream.base_interest_energy = optimized_stream.base_interest_energy
|
||||
original_stream._focus_energy = optimized_stream._focus_energy
|
||||
original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
|
||||
original_stream.saved = optimized_stream.saved
|
||||
# 复制状态
|
||||
original_stream.create_time = optimized_stream.create_time
|
||||
original_stream.last_active_time = optimized_stream.last_active_time
|
||||
original_stream.sleep_pressure = optimized_stream.sleep_pressure
|
||||
original_stream.base_interest_energy = optimized_stream.base_interest_energy
|
||||
original_stream._focus_energy = optimized_stream._focus_energy
|
||||
original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
|
||||
original_stream.saved = optimized_stream.saved
|
||||
|
||||
# 复制上下文信息(如果存在)
|
||||
if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
|
||||
original_stream.stream_context = optimized_stream._stream_context
|
||||
# 复制上下文信息(如果存在)
|
||||
if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
|
||||
original_stream.stream_context = optimized_stream._stream_context
|
||||
|
||||
if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
|
||||
original_stream.context_manager = optimized_stream._context_manager
|
||||
if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
|
||||
original_stream.context_manager = optimized_stream._context_manager
|
||||
|
||||
return original_stream
|
||||
return original_stream
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换OptimizedChatStream失败: {e}")
|
||||
# 如果转换失败,创建一个新的原始流
|
||||
return ChatStream(
|
||||
stream_id=optimized_stream.stream_id,
|
||||
platform=optimized_stream.platform,
|
||||
user_info=optimized_stream._get_effective_user_info(),
|
||||
group_info=optimized_stream._get_effective_group_info()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"转换OptimizedChatStream失败: {e}")
|
||||
# 如果转换失败,创建一个新的原始流
|
||||
return ChatStream(
|
||||
stream_id=optimized_stream.stream_id,
|
||||
platform=optimized_stream.platform,
|
||||
user_info=optimized_stream._get_effective_user_info(),
|
||||
group_info=optimized_stream._get_effective_group_info(),
|
||||
)
|
||||
|
||||
|
||||
def get_chat_manager():
|
||||
|
||||
@@ -80,10 +80,7 @@ class OptimizedChatStream:
|
||||
):
|
||||
# 共享的只读数据
|
||||
self._shared_context = SharedContext(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info
|
||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
|
||||
)
|
||||
|
||||
# 本地修改数据
|
||||
@@ -269,14 +266,13 @@ class OptimizedChatStream:
|
||||
self._stream_context = StreamContext(
|
||||
stream_id=self.stream_id,
|
||||
chat_type=ChatType.GROUP if self.group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.NORMAL
|
||||
chat_mode=ChatMode.NORMAL,
|
||||
)
|
||||
|
||||
# 创建单流上下文管理器
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
self._context_manager = SingleStreamContextManager(
|
||||
stream_id=self.stream_id, context=self._stream_context
|
||||
)
|
||||
|
||||
self._context_manager = SingleStreamContextManager(stream_id=self.stream_id, context=self._stream_context)
|
||||
|
||||
@property
|
||||
def stream_context(self):
|
||||
@@ -331,9 +327,11 @@ class OptimizedChatStream:
|
||||
# 恢复stream_context信息
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
# 恢复interruption_count信息
|
||||
@@ -352,6 +350,7 @@ class OptimizedChatStream:
|
||||
if isinstance(actions, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
actions = json.loads(actions)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析actions JSON字符串: {actions}")
|
||||
@@ -458,7 +457,7 @@ class OptimizedChatStream:
|
||||
stream_id=self.stream_id,
|
||||
platform=self.platform,
|
||||
user_info=self._get_effective_user_info(),
|
||||
group_info=self._get_effective_group_info()
|
||||
group_info=self._get_effective_group_info(),
|
||||
)
|
||||
|
||||
# 复制本地修改(但不触发写时复制)
|
||||
@@ -482,9 +481,5 @@ def create_optimized_chat_stream(
|
||||
) -> OptimizedChatStream:
|
||||
"""创建优化版聊天流实例"""
|
||||
return OptimizedChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
data=data
|
||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
|
||||
)
|
||||
|
||||
@@ -196,18 +196,20 @@ class ChatterActionManager:
|
||||
thinking_id=thinking_id or "",
|
||||
action_done=True,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason
|
||||
action_prompt_display=reason,
|
||||
)
|
||||
else:
|
||||
asyncio.create_task(database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_reply",
|
||||
))
|
||||
asyncio.create_task(
|
||||
database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reason,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reason": reason},
|
||||
action_name="no_reply",
|
||||
)
|
||||
)
|
||||
|
||||
# 自动清空所有未读消息
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply"))
|
||||
@@ -228,7 +230,9 @@ class ChatterActionManager:
|
||||
|
||||
# 记录执行的动作到目标消息
|
||||
if success:
|
||||
asyncio.create_task(self._record_action_to_message(chat_stream, action_name, target_message, action_data))
|
||||
asyncio.create_task(
|
||||
self._record_action_to_message(chat_stream, action_name, target_message, action_data)
|
||||
)
|
||||
# 自动清空所有未读消息
|
||||
if clear_unread_messages:
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name))
|
||||
@@ -496,7 +500,7 @@ class ChatterActionManager:
|
||||
thinking_id=thinking_id or "",
|
||||
action_done=True,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=action_prompt_display
|
||||
action_prompt_display=action_prompt_display,
|
||||
)
|
||||
else:
|
||||
await database_api.store_action_info(
|
||||
@@ -618,9 +622,15 @@ class ChatterActionManager:
|
||||
self._pending_actions = [] # 清空队列
|
||||
logger.debug("已禁用批量存储模式")
|
||||
|
||||
def add_action_to_batch(self, action_name: str, action_data: dict, thinking_id: str = "",
|
||||
action_done: bool = True, action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = ""):
|
||||
def add_action_to_batch(
|
||||
self,
|
||||
action_name: str,
|
||||
action_data: dict,
|
||||
thinking_id: str = "",
|
||||
action_done: bool = True,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
):
|
||||
"""添加动作到批量存储列表"""
|
||||
if not self._batch_storage_enabled:
|
||||
return False
|
||||
@@ -632,7 +642,7 @@ class ChatterActionManager:
|
||||
"action_done": action_done,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
"timestamp": time.time()
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
self._pending_actions.append(action_record)
|
||||
logger.debug(f"已添加动作到批量存储列表: {action_name} (当前待处理: {len(self._pending_actions)} 个)")
|
||||
@@ -658,7 +668,7 @@ class ChatterActionManager:
|
||||
action_done=action_data.get("action_done", True),
|
||||
action_build_into_prompt=action_data.get("action_build_into_prompt", False),
|
||||
action_prompt_display=action_data.get("action_prompt_display", ""),
|
||||
thinking_id=action_data.get("thinking_id", "")
|
||||
thinking_id=action_data.get("thinking_id", ""),
|
||||
)
|
||||
if result:
|
||||
stored_count += 1
|
||||
|
||||
@@ -589,7 +589,7 @@ class DefaultReplyer:
|
||||
# 获取记忆系统实例
|
||||
memory_system = get_memory_system()
|
||||
|
||||
# 使用统一记忆系统检索相关记忆
|
||||
# 使用统一记忆系统检索相关记忆
|
||||
enhanced_memories = await memory_system.retrieve_relevant_memories(
|
||||
query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10
|
||||
)
|
||||
@@ -1208,12 +1208,32 @@ class DefaultReplyer:
|
||||
|
||||
# 并行执行六个构建任务
|
||||
tasks = {
|
||||
"expression_habits": asyncio.create_task(self._time_and_run_task(self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits")),
|
||||
"relation_info": asyncio.create_task(self._time_and_run_task(self.build_relation_info(sender, target), "relation_info")),
|
||||
"memory_block": asyncio.create_task(self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block")),
|
||||
"tool_info": asyncio.create_task(self._time_and_run_task(self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info")),
|
||||
"prompt_info": asyncio.create_task(self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info")),
|
||||
"cross_context": asyncio.create_task(self._time_and_run_task(Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info), "cross_context")),
|
||||
"expression_habits": asyncio.create_task(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||
)
|
||||
),
|
||||
"relation_info": asyncio.create_task(
|
||||
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info")
|
||||
),
|
||||
"memory_block": asyncio.create_task(
|
||||
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block")
|
||||
),
|
||||
"tool_info": asyncio.create_task(
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool),
|
||||
"tool_info",
|
||||
)
|
||||
),
|
||||
"prompt_info": asyncio.create_task(
|
||||
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info")
|
||||
),
|
||||
"cross_context": asyncio.create_task(
|
||||
self._time_and_run_task(
|
||||
Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info),
|
||||
"cross_context",
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
# 设置超时
|
||||
@@ -1512,13 +1532,8 @@ class DefaultReplyer:
|
||||
chat_target_name = (
|
||||
self.chat_target_info.get("person_name") or self.chat_target_info.get("user_nickname") or "对方"
|
||||
)
|
||||
await global_prompt_manager.format_prompt(
|
||||
"chat_target_private1", sender_name=chat_target_name
|
||||
)
|
||||
await global_prompt_manager.format_prompt(
|
||||
"chat_target_private2", sender_name=chat_target_name
|
||||
)
|
||||
|
||||
await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
|
||||
await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
|
||||
|
||||
# 使用新的统一Prompt系统 - Expressor模式,创建PromptParameters
|
||||
prompt_parameters = PromptParameters(
|
||||
|
||||
@@ -121,13 +121,14 @@ class VideoAnalyzer:
|
||||
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.llm_models.utils_model import RequestType
|
||||
|
||||
prompt = self.batch_analysis_prompt.format(
|
||||
personality_core=self.personality_core, personality_side=self.personality_side
|
||||
)
|
||||
if question:
|
||||
prompt += f"\n用户关注: {question}"
|
||||
desc = [
|
||||
(f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧")
|
||||
(f"第{i + 1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i + 1}帧")
|
||||
for i, (_b, ts) in enumerate(frames)
|
||||
]
|
||||
prompt += "\n帧列表: " + ", ".join(desc)
|
||||
@@ -151,16 +152,16 @@ class VideoAnalyzer:
|
||||
async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
|
||||
results: list[str] = []
|
||||
for i, (b64, ts) in enumerate(frames):
|
||||
prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
|
||||
prompt = f"分析第{i + 1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
|
||||
if question:
|
||||
prompt += f"\n关注: {question}"
|
||||
try:
|
||||
text, _ = await self.video_llm.generate_response_for_image(
|
||||
prompt=prompt, image_base64=b64, image_format="jpeg"
|
||||
)
|
||||
results.append(f"第{i+1}帧: {text}")
|
||||
results.append(f"第{i + 1}帧: {text}")
|
||||
except Exception as e: # pragma: no cover
|
||||
results.append(f"第{i+1}帧: 失败 {e}")
|
||||
results.append(f"第{i + 1}帧: 失败 {e}")
|
||||
if i < len(frames) - 1:
|
||||
await asyncio.sleep(self.frame_analysis_delay)
|
||||
summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results)
|
||||
@@ -182,7 +183,9 @@ class VideoAnalyzer:
|
||||
mode = self.analysis_mode
|
||||
if mode == "auto":
|
||||
mode = "batch" if len(frames) <= 20 else "sequential"
|
||||
text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question))
|
||||
text = await (
|
||||
self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question)
|
||||
)
|
||||
return True, text
|
||||
|
||||
async def analyze_video_from_bytes(
|
||||
|
||||
@@ -220,7 +220,9 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
def update_message_info(self, interest_value: float | None = None, actions: list | None = None, should_reply: bool | None = None):
|
||||
def update_message_info(
|
||||
self, interest_value: float | None = None, actions: list | None = None, should_reply: bool | None = None
|
||||
):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
|
||||
@@ -53,8 +53,6 @@ class StreamContext(BaseDataModel):
|
||||
priority_mode: str | None = None
|
||||
priority_info: dict | None = None
|
||||
|
||||
|
||||
|
||||
def add_action_to_message(self, message_id: str, action: str):
|
||||
"""
|
||||
向指定消息添加执行的动作
|
||||
@@ -75,9 +73,6 @@ class StreamContext(BaseDataModel):
|
||||
message.add_action(action)
|
||||
break
|
||||
|
||||
|
||||
|
||||
|
||||
def mark_message_as_read(self, message_id: str):
|
||||
"""标记消息为已读"""
|
||||
for msg in self.unread_messages:
|
||||
|
||||
@@ -78,7 +78,7 @@ class ConnectionPoolManager:
|
||||
"total_expired": 0,
|
||||
"active_connections": 0,
|
||||
"pool_hits": 0,
|
||||
"pool_misses": 0
|
||||
"pool_misses": 0,
|
||||
}
|
||||
|
||||
# 后台清理任务
|
||||
@@ -156,7 +156,9 @@ class ConnectionPoolManager:
|
||||
if connection_info:
|
||||
connection_info.mark_released()
|
||||
|
||||
async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> ConnectionInfo | None:
|
||||
async def _get_reusable_connection(
|
||||
self, session_factory: async_sessionmaker[AsyncSession]
|
||||
) -> ConnectionInfo | None:
|
||||
"""获取可复用的连接"""
|
||||
async with self._lock:
|
||||
# 清理过期连接
|
||||
@@ -164,9 +166,7 @@ class ConnectionPoolManager:
|
||||
|
||||
# 查找可复用的连接
|
||||
for connection_info in list(self._connections):
|
||||
if (not connection_info.in_use and
|
||||
not connection_info.is_expired(self.max_lifetime, self.max_idle)):
|
||||
|
||||
if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
|
||||
# 验证连接是否仍然有效
|
||||
try:
|
||||
# 执行一个简单的查询来验证连接
|
||||
@@ -191,8 +191,7 @@ class ConnectionPoolManager:
|
||||
expired_connections = []
|
||||
|
||||
for connection_info in list(self._connections):
|
||||
if (connection_info.is_expired(self.max_lifetime, self.max_idle) and
|
||||
not connection_info.in_use):
|
||||
if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use:
|
||||
expired_connections.append(connection_info)
|
||||
|
||||
for connection_info in expired_connections:
|
||||
@@ -238,7 +237,8 @@ class ConnectionPoolManager:
|
||||
"max_pool_size": self.max_pool_size,
|
||||
"pool_efficiency": (
|
||||
self._stats["pool_hits"] / max(1, self._stats["pool_hits"] + self._stats["pool_misses"])
|
||||
) * 100
|
||||
)
|
||||
* 100,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ T = TypeVar("T")
|
||||
@dataclass
|
||||
class BatchOperation:
|
||||
"""批量操作基础类"""
|
||||
|
||||
operation_type: str # 'select', 'insert', 'update', 'delete'
|
||||
model_class: Any
|
||||
conditions: dict[str, Any]
|
||||
@@ -40,6 +41,7 @@ class BatchOperation:
|
||||
@dataclass
|
||||
class BatchResult:
|
||||
"""批量操作结果"""
|
||||
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: str | None = None
|
||||
@@ -48,10 +50,12 @@ class BatchResult:
|
||||
class DatabaseBatchScheduler:
|
||||
"""数据库批量调度器"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size: int = 50,
|
||||
max_wait_time: float = 0.1, # 100ms
|
||||
max_queue_size: int = 1000):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = 50,
|
||||
max_wait_time: float = 0.1, # 100ms
|
||||
max_queue_size: int = 1000,
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.max_wait_time = max_wait_time
|
||||
self.max_queue_size = max_queue_size
|
||||
@@ -65,12 +69,7 @@ class DatabaseBatchScheduler:
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
"total_operations": 0,
|
||||
"batched_operations": 0,
|
||||
"cache_hits": 0,
|
||||
"execution_time": 0.0
|
||||
}
|
||||
self.stats = {"total_operations": 0, "batched_operations": 0, "cache_hits": 0, "execution_time": 0.0}
|
||||
|
||||
# 简单的结果缓存(用于频繁的查询)
|
||||
self._result_cache: dict[str, tuple[Any, float]] = {}
|
||||
@@ -105,11 +104,7 @@ class DatabaseBatchScheduler:
|
||||
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str:
|
||||
"""生成缓存键"""
|
||||
# 简单的缓存键生成,实际可以根据需要优化
|
||||
key_parts = [
|
||||
operation_type,
|
||||
model_class.__name__,
|
||||
str(sorted(conditions.items()))
|
||||
]
|
||||
key_parts = [operation_type, model_class.__name__, str(sorted(conditions.items()))]
|
||||
return "|".join(key_parts)
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Any | None:
|
||||
@@ -132,11 +127,7 @@ class DatabaseBatchScheduler:
|
||||
"""添加操作到队列"""
|
||||
# 检查是否可以立即返回缓存结果
|
||||
if operation.operation_type == "select":
|
||||
cache_key = self._generate_cache_key(
|
||||
operation.operation_type,
|
||||
operation.model_class,
|
||||
operation.conditions
|
||||
)
|
||||
cache_key = self._generate_cache_key(operation.operation_type, operation.model_class, operation.conditions)
|
||||
cached_result = self._get_from_cache(cache_key)
|
||||
if cached_result is not None:
|
||||
if operation.callback:
|
||||
@@ -180,10 +171,7 @@ class DatabaseBatchScheduler:
|
||||
return
|
||||
|
||||
# 复制队列内容,避免长时间占用锁
|
||||
queues_copy = {
|
||||
key: deque(operations)
|
||||
for key, operations in self.operation_queues.items()
|
||||
}
|
||||
queues_copy = {key: deque(operations) for key, operations in self.operation_queues.items()}
|
||||
# 清空原队列
|
||||
for queue in self.operation_queues.values():
|
||||
queue.clear()
|
||||
@@ -240,9 +228,7 @@ class DatabaseBatchScheduler:
|
||||
# 缓存查询结果
|
||||
if operation.operation_type == "select":
|
||||
cache_key = self._generate_cache_key(
|
||||
operation.operation_type,
|
||||
operation.model_class,
|
||||
operation.conditions
|
||||
operation.operation_type, operation.model_class, operation.conditions
|
||||
)
|
||||
self._set_cache(cache_key, result)
|
||||
|
||||
@@ -287,12 +273,9 @@ class DatabaseBatchScheduler:
|
||||
else:
|
||||
# 需要根据条件过滤结果
|
||||
op_result = [
|
||||
item for item in data
|
||||
if all(
|
||||
getattr(item, k) == v
|
||||
for k, v in op.conditions.items()
|
||||
if hasattr(item, k)
|
||||
)
|
||||
item
|
||||
for item in data
|
||||
if all(getattr(item, k) == v for k, v in op.conditions.items() if hasattr(item, k))
|
||||
]
|
||||
results.append(op_result)
|
||||
|
||||
@@ -429,7 +412,7 @@ class DatabaseBatchScheduler:
|
||||
**self.stats,
|
||||
"cache_size": len(self._result_cache),
|
||||
"queue_sizes": {k: len(v) for k, v in self.operation_queues.items()},
|
||||
"is_running": self._is_running
|
||||
"is_running": self._is_running,
|
||||
}
|
||||
|
||||
|
||||
@@ -452,43 +435,25 @@ async def get_batch_session():
|
||||
# 便捷函数
|
||||
async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any:
|
||||
"""批量查询"""
|
||||
operation = BatchOperation(
|
||||
operation_type="select",
|
||||
model_class=model_class,
|
||||
conditions=conditions
|
||||
)
|
||||
operation = BatchOperation(operation_type="select", model_class=model_class, conditions=conditions)
|
||||
return await db_batch_scheduler.add_operation(operation)
|
||||
|
||||
|
||||
async def batch_insert(model_class: Any, data: dict[str, Any]) -> int:
|
||||
"""批量插入"""
|
||||
operation = BatchOperation(
|
||||
operation_type="insert",
|
||||
model_class=model_class,
|
||||
conditions={},
|
||||
data=data
|
||||
)
|
||||
operation = BatchOperation(operation_type="insert", model_class=model_class, conditions={}, data=data)
|
||||
return await db_batch_scheduler.add_operation(operation)
|
||||
|
||||
|
||||
async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int:
|
||||
"""批量更新"""
|
||||
operation = BatchOperation(
|
||||
operation_type="update",
|
||||
model_class=model_class,
|
||||
conditions=conditions,
|
||||
data=data
|
||||
)
|
||||
operation = BatchOperation(operation_type="update", model_class=model_class, conditions=conditions, data=data)
|
||||
return await db_batch_scheduler.add_operation(operation)
|
||||
|
||||
|
||||
async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int:
|
||||
"""批量删除"""
|
||||
operation = BatchOperation(
|
||||
operation_type="delete",
|
||||
model_class=model_class,
|
||||
conditions=conditions
|
||||
)
|
||||
operation = BatchOperation(operation_type="delete", model_class=model_class, conditions=conditions)
|
||||
return await db_batch_scheduler.add_operation(operation)
|
||||
|
||||
|
||||
|
||||
@@ -304,8 +304,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
|
||||
"library_log_levels": {"aiohttp": "WARNING"},
|
||||
}
|
||||
|
||||
|
||||
# 误加的即刻线程启动已移除;真正的线程在 start_log_cleanup_task 中按午夜调度
|
||||
# 误加的即刻线程启动已移除;真正的线程在 start_log_cleanup_task 中按午夜调度
|
||||
|
||||
try:
|
||||
if config_path.exists():
|
||||
|
||||
@@ -37,7 +37,9 @@ class DatabaseConfig(ValidatedConfigBase):
|
||||
connection_timeout: int = Field(default=10, ge=1, description="连接超时时间")
|
||||
|
||||
# 批量动作记录存储配置
|
||||
batch_action_storage_enabled: bool = Field(default=True, description="是否启用批量保存动作记录(开启后将多个动作一次性写入数据库,提升性能)")
|
||||
batch_action_storage_enabled: bool = Field(
|
||||
default=True, description="是否启用批量保存动作记录(开启后将多个动作一次性写入数据库,提升性能)"
|
||||
)
|
||||
|
||||
|
||||
class BotConfig(ValidatedConfigBase):
|
||||
@@ -355,7 +357,7 @@ class MemoryConfig(ValidatedConfigBase):
|
||||
# 双峰分布配置 [近期均值, 近期标准差, 近期权重, 远期均值, 远期标准差, 远期权重]
|
||||
hippocampus_distribution_config: list[float] = Field(
|
||||
default=[12.0, 8.0, 0.7, 48.0, 24.0, 0.3],
|
||||
description="海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]"
|
||||
description="海马体双峰分布配置:[近期均值(h), 近期标准差(h), 近期权重, 远期均值(h), 远期标准差(h), 远期权重]",
|
||||
)
|
||||
|
||||
# 自适应采样配置
|
||||
@@ -690,7 +692,6 @@ class AffinityFlowConfig(ValidatedConfigBase):
|
||||
base_relationship_score: float = Field(default=0.5, description="基础人物关系分")
|
||||
|
||||
|
||||
|
||||
class ProactiveThinkingConfig(ValidatedConfigBase):
|
||||
"""主动思考(主动发起对话)功能配置"""
|
||||
|
||||
|
||||
@@ -189,11 +189,11 @@ class ClientRegistry:
|
||||
bool: 事件循环是否变化
|
||||
"""
|
||||
current_loop_id = self._get_current_loop_id()
|
||||
|
||||
|
||||
# 如果没有缓存的循环ID,说明是首次创建
|
||||
if provider_name not in self._event_loop_cache:
|
||||
return False
|
||||
|
||||
|
||||
# 比较当前循环ID与缓存的循环ID
|
||||
cached_loop_id = self._event_loop_cache[provider_name]
|
||||
return current_loop_id != cached_loop_id
|
||||
@@ -208,7 +208,7 @@ class ClientRegistry:
|
||||
BaseClient: 注册的API客户端实例
|
||||
"""
|
||||
provider_name = api_provider.name
|
||||
|
||||
|
||||
# 如果强制创建新实例,直接创建不使用缓存
|
||||
if force_new:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
@@ -224,7 +224,7 @@ class ClientRegistry:
|
||||
# 事件循环已变化,需要重新创建实例
|
||||
logger.debug(f"检测到事件循环变化,为 {provider_name} 重新创建客户端实例")
|
||||
self._loop_change_count += 1
|
||||
|
||||
|
||||
# 移除旧实例
|
||||
if provider_name in self.client_instance_cache:
|
||||
del self.client_instance_cache[provider_name]
|
||||
@@ -237,7 +237,7 @@ class ClientRegistry:
|
||||
self._event_loop_cache[provider_name] = self._get_current_loop_id()
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
|
||||
|
||||
return self.client_instance_cache[provider_name]
|
||||
|
||||
def get_cache_stats(self) -> dict:
|
||||
|
||||
@@ -50,14 +50,16 @@ def _convert_messages_to_mcp(messages: list[Message]) -> list[dict[str, Any]]:
|
||||
for item in message.content:
|
||||
if isinstance(item, tuple):
|
||||
# 图片内容
|
||||
content_parts.append({
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": f"image/{item[0].lower()}",
|
||||
"data": item[1],
|
||||
},
|
||||
})
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": f"image/{item[0].lower()}",
|
||||
"data": item[1],
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
# 文本内容
|
||||
content_parts.append({"type": "text", "text": item})
|
||||
@@ -138,9 +140,7 @@ async def _parse_sse_stream(
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RespNotOkException(
|
||||
response.status, f"MCP SSE请求失败: {error_text}"
|
||||
)
|
||||
raise RespNotOkException(response.status, f"MCP SSE请求失败: {error_text}")
|
||||
|
||||
# 解析SSE流
|
||||
async for line in response.content:
|
||||
@@ -258,10 +258,7 @@ async def _parse_sse_stream(
|
||||
response.reasoning_content = reasoning_buffer.getvalue()
|
||||
|
||||
if tool_calls_buffer:
|
||||
response.tool_calls = [
|
||||
ToolCall(call_id, func_name, args)
|
||||
for call_id, func_name, args in tool_calls_buffer
|
||||
]
|
||||
response.tool_calls = [ToolCall(call_id, func_name, args) for call_id, func_name, args in tool_calls_buffer]
|
||||
|
||||
# 关闭缓冲区
|
||||
content_buffer.close()
|
||||
@@ -351,9 +348,7 @@ class MCPSSEClient(BaseClient):
|
||||
url = f"{self.api_provider.base_url}/v1/messages"
|
||||
|
||||
try:
|
||||
response, usage_record = await _parse_sse_stream(
|
||||
session, url, payload, headers, interrupt_flag
|
||||
)
|
||||
response, usage_record = await _parse_sse_stream(session, url, payload, headers, interrupt_flag)
|
||||
except Exception as e:
|
||||
logger.error(f"MCP SSE请求失败: {e}")
|
||||
raise
|
||||
|
||||
@@ -378,7 +378,7 @@ class OpenaiClient(BaseClient):
|
||||
# 类级别的全局缓存:所有 OpenaiClient 实例共享
|
||||
_global_client_cache: dict[int, AsyncOpenAI] = {}
|
||||
"""全局 AsyncOpenAI 客户端缓存:config_hash -> AsyncOpenAI 实例"""
|
||||
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
super().__init__(api_provider)
|
||||
self._config_hash = self._calculate_config_hash()
|
||||
@@ -396,33 +396,31 @@ class OpenaiClient(BaseClient):
|
||||
def _create_client(self) -> AsyncOpenAI:
|
||||
"""
|
||||
获取或创建 OpenAI 客户端实例(全局缓存)
|
||||
|
||||
|
||||
多个 OpenaiClient 实例如果配置相同(base_url + api_key + timeout),
|
||||
将共享同一个 AsyncOpenAI 客户端实例,最大化连接池复用。
|
||||
"""
|
||||
# 检查全局缓存
|
||||
if self._config_hash in self._global_client_cache:
|
||||
return self._global_client_cache[self._config_hash]
|
||||
|
||||
|
||||
# 创建新的 AsyncOpenAI 实例
|
||||
logger.debug(
|
||||
f"创建新的 AsyncOpenAI 客户端实例 "
|
||||
f"(base_url={self.api_provider.base_url}, "
|
||||
f"config_hash={self._config_hash})"
|
||||
f"创建新的 AsyncOpenAI 客户端实例 (base_url={self.api_provider.base_url}, config_hash={self._config_hash})"
|
||||
)
|
||||
|
||||
|
||||
client = AsyncOpenAI(
|
||||
base_url=self.api_provider.base_url,
|
||||
api_key=self.api_provider.get_api_key(),
|
||||
max_retries=0,
|
||||
timeout=self.api_provider.timeout,
|
||||
)
|
||||
|
||||
|
||||
# 存入全局缓存
|
||||
self._global_client_cache[self._config_hash] = client
|
||||
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_cache_stats(cls) -> dict:
|
||||
"""获取全局缓存统计信息"""
|
||||
|
||||
@@ -280,7 +280,9 @@ class _PromptProcessor:
|
||||
这有助于我判断你的输出是否被截断。请不要在 `{self.end_marker}` 前后添加任何其他文字或标点。
|
||||
"""
|
||||
|
||||
async def prepare_prompt(self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str) -> str:
|
||||
async def prepare_prompt(
|
||||
self, prompt: str, model_info: ModelInfo, api_provider: APIProvider, task_name: str
|
||||
) -> str:
|
||||
"""
|
||||
为请求准备最终的提示词。
|
||||
|
||||
|
||||
31
src/main.py
31
src/main.py
@@ -88,6 +88,7 @@ class MainSystem:
|
||||
|
||||
def _setup_signal_handlers(self) -> None:
|
||||
"""设置信号处理器"""
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
if self._shutting_down:
|
||||
logger.warning("系统已经在关闭过程中,忽略重复信号")
|
||||
@@ -132,6 +133,7 @@ class MainSystem:
|
||||
try:
|
||||
from src.plugin_system.apis.component_manage_api import get_components_info_by_type
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
interest_calculators = get_components_info_by_type(ComponentType.INTEREST_CALCULATOR)
|
||||
logger.info(f"通过组件注册表发现 {len(interest_calculators)} 个兴趣计算器组件")
|
||||
except Exception as e:
|
||||
@@ -143,6 +145,7 @@ class MainSystem:
|
||||
|
||||
# 初始化兴趣度管理器
|
||||
from src.chat.interest_system.interest_manager import get_interest_manager
|
||||
|
||||
interest_manager = get_interest_manager()
|
||||
await interest_manager.initialize()
|
||||
|
||||
@@ -159,7 +162,10 @@ class MainSystem:
|
||||
|
||||
try:
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
component_class = component_registry.get_component_class(calc_name, ComponentType.INTEREST_CALCULATOR)
|
||||
|
||||
component_class = component_registry.get_component_class(
|
||||
calc_name, ComponentType.INTEREST_CALCULATOR
|
||||
)
|
||||
|
||||
if not component_class:
|
||||
logger.warning(f"无法找到 {calc_name} 的组件类")
|
||||
@@ -208,6 +214,7 @@ class MainSystem:
|
||||
# 停止数据库服务
|
||||
try:
|
||||
from src.common.database.database import stop_database
|
||||
|
||||
cleanup_tasks.append(("数据库服务", stop_database()))
|
||||
except Exception as e:
|
||||
logger.error(f"准备停止数据库服务时出错: {e}")
|
||||
@@ -215,6 +222,7 @@ class MainSystem:
|
||||
# 停止消息管理器
|
||||
try:
|
||||
from src.chat.message_manager import message_manager
|
||||
|
||||
cleanup_tasks.append(("消息管理器", message_manager.stop()))
|
||||
except Exception as e:
|
||||
logger.error(f"准备停止消息管理器时出错: {e}")
|
||||
@@ -222,6 +230,7 @@ class MainSystem:
|
||||
# 停止消息重组器
|
||||
try:
|
||||
from src.utils.message_chunker import reassembler
|
||||
|
||||
cleanup_tasks.append(("消息重组器", reassembler.stop_cleanup_task()))
|
||||
except Exception as e:
|
||||
logger.error(f"准备停止消息重组器时出错: {e}")
|
||||
@@ -236,15 +245,18 @@ class MainSystem:
|
||||
# 触发停止事件
|
||||
try:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
cleanup_tasks.append(("插件系统停止事件",
|
||||
event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")))
|
||||
|
||||
cleanup_tasks.append(
|
||||
("插件系统停止事件", event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"准备触发停止事件时出错: {e}")
|
||||
|
||||
# 停止表情管理器
|
||||
try:
|
||||
cleanup_tasks.append(("表情管理器",
|
||||
asyncio.get_event_loop().run_in_executor(None, get_emoji_manager().shutdown)))
|
||||
cleanup_tasks.append(
|
||||
("表情管理器", asyncio.get_event_loop().run_in_executor(None, get_emoji_manager().shutdown))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"准备停止表情管理器时出错: {e}")
|
||||
|
||||
@@ -275,7 +287,7 @@ class MainSystem:
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True),
|
||||
timeout=30.0 # 30秒超时
|
||||
timeout=30.0, # 30秒超时
|
||||
)
|
||||
|
||||
# 记录结果
|
||||
@@ -389,6 +401,7 @@ MoFox_Bot(第三方修改版)
|
||||
# 注册API路由
|
||||
try:
|
||||
from src.api.message_router import router as message_router
|
||||
|
||||
self.server.register_router(message_router, prefix="/api")
|
||||
logger.info("API路由注册成功")
|
||||
except Exception as e:
|
||||
@@ -405,6 +418,7 @@ MoFox_Bot(第三方修改版)
|
||||
mcp_config = global_config.get("mcp_servers", [])
|
||||
if mcp_config:
|
||||
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
|
||||
|
||||
await mcp_tool_provider.initialize(mcp_config)
|
||||
logger.info("MCP工具提供器初始化成功")
|
||||
except Exception as e:
|
||||
@@ -445,6 +459,7 @@ MoFox_Bot(第三方修改版)
|
||||
# 初始化LPMM知识库
|
||||
try:
|
||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
||||
|
||||
initialize_lpmm_knowledge()
|
||||
logger.info("LPMM知识库初始化成功")
|
||||
except Exception as e:
|
||||
@@ -456,6 +471,7 @@ MoFox_Bot(第三方修改版)
|
||||
# 启动消息重组器
|
||||
try:
|
||||
from src.utils.message_chunker import reassembler
|
||||
|
||||
await reassembler.start_cleanup_task()
|
||||
logger.info("消息重组器已启动")
|
||||
except Exception as e:
|
||||
@@ -464,6 +480,7 @@ MoFox_Bot(第三方修改版)
|
||||
# 启动消息管理器
|
||||
try:
|
||||
from src.chat.message_manager import message_manager
|
||||
|
||||
await message_manager.start()
|
||||
logger.info("消息管理器已启动")
|
||||
except Exception as e:
|
||||
@@ -504,6 +521,7 @@ MoFox_Bot(第三方修改版)
|
||||
|
||||
def _safe_init(self, component_name: str, init_func) -> callable:
|
||||
"""安全初始化组件,捕获异常"""
|
||||
|
||||
async def wrapper():
|
||||
try:
|
||||
result = init_func()
|
||||
@@ -514,6 +532,7 @@ MoFox_Bot(第三方修改版)
|
||||
except Exception as e:
|
||||
logger.error(f"{component_name}初始化失败: {e}")
|
||||
return False
|
||||
|
||||
return wrapper
|
||||
|
||||
async def schedule_tasks(self) -> None:
|
||||
|
||||
@@ -59,6 +59,7 @@ class ChatMood:
|
||||
"""异步初始化方法"""
|
||||
if not self._initialized:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
self.chat_stream = await chat_manager.get_stream(self.chat_id)
|
||||
|
||||
|
||||
@@ -69,6 +69,7 @@ class RelationshipBuilder:
|
||||
if not self._log_prefix_initialized:
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_name = await get_chat_manager().get_stream_name(self.chat_id)
|
||||
self.log_prefix = f"[{chat_name}]"
|
||||
except Exception:
|
||||
|
||||
@@ -85,6 +85,7 @@ class RelationshipFetcher:
|
||||
"""异步初始化log_prefix"""
|
||||
if not self._log_prefix_initialized:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
name = await get_chat_manager().get_stream_name(self.chat_id)
|
||||
self.log_prefix = f"[{name}] 实时信息"
|
||||
self._log_prefix_initialized = True
|
||||
|
||||
@@ -59,6 +59,7 @@ async def get_replyer(
|
||||
logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
||||
# 动态导入避免循环依赖
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
|
||||
return await replyer_manager.get_replyer(
|
||||
chat_stream=chat_stream,
|
||||
chat_id=chat_id,
|
||||
|
||||
@@ -39,6 +39,7 @@ def get_llm_available_tool_definitions():
|
||||
# 添加MCP工具
|
||||
try:
|
||||
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
|
||||
|
||||
mcp_tools = mcp_tool_provider.get_mcp_tool_definitions()
|
||||
tool_definitions.extend(mcp_tools)
|
||||
if mcp_tools:
|
||||
|
||||
@@ -86,7 +86,9 @@ class HandlerResultsCollection:
|
||||
|
||||
|
||||
class BaseEvent:
|
||||
def __init__(self, name: str, allowed_subscribers: list[str] | None = None, allowed_triggers: list[str] | None = None):
|
||||
def __init__(
|
||||
self, name: str, allowed_subscribers: list[str] | None = None, allowed_triggers: list[str] | None = None
|
||||
):
|
||||
self.name = name
|
||||
self.enabled = True
|
||||
self.allowed_subscribers = allowed_subscribers # 记录事件处理器名
|
||||
|
||||
@@ -28,7 +28,7 @@ class InterestCalculationResult:
|
||||
should_reply: bool = False,
|
||||
should_act: bool = False,
|
||||
error_message: str | None = None,
|
||||
calculation_time: float = 0.0
|
||||
calculation_time: float = 0.0,
|
||||
):
|
||||
self.success = success
|
||||
self.message_id = message_id
|
||||
@@ -51,17 +51,19 @@ class InterestCalculationResult:
|
||||
"should_act": self.should_act,
|
||||
"error_message": self.error_message,
|
||||
"calculation_time": self.calculation_time,
|
||||
"timestamp": self.timestamp
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"InterestCalculationResult("
|
||||
f"success={self.success}, "
|
||||
f"message_id={self.message_id}, "
|
||||
f"interest_value={self.interest_value:.3f}, "
|
||||
f"should_take_action={self.should_take_action}, "
|
||||
f"should_reply={self.should_reply}, "
|
||||
f"should_act={self.should_act})")
|
||||
return (
|
||||
f"InterestCalculationResult("
|
||||
f"success={self.success}, "
|
||||
f"message_id={self.message_id}, "
|
||||
f"interest_value={self.interest_value:.3f}, "
|
||||
f"should_take_action={self.should_take_action}, "
|
||||
f"should_reply={self.should_reply}, "
|
||||
f"should_act={self.should_act})"
|
||||
)
|
||||
|
||||
|
||||
class BaseInterestCalculator(ABC):
|
||||
@@ -144,7 +146,7 @@ class BaseInterestCalculator(ABC):
|
||||
"failed_calculations": self._failed_calculations,
|
||||
"success_rate": 1.0 - (self._failed_calculations / max(1, self._total_calculations)),
|
||||
"average_calculation_time": self._average_calculation_time,
|
||||
"last_calculation_time": self._last_calculation_time
|
||||
"last_calculation_time": self._last_calculation_time,
|
||||
}
|
||||
|
||||
def _update_statistics(self, result: InterestCalculationResult):
|
||||
@@ -159,8 +161,7 @@ class BaseInterestCalculator(ABC):
|
||||
else:
|
||||
alpha = 0.1 # 指数移动平均的平滑因子
|
||||
self._average_calculation_time = (
|
||||
alpha * result.calculation_time +
|
||||
(1 - alpha) * self._average_calculation_time
|
||||
alpha * result.calculation_time + (1 - alpha) * self._average_calculation_time
|
||||
)
|
||||
|
||||
self._last_calculation_time = result.timestamp
|
||||
@@ -172,7 +173,7 @@ class BaseInterestCalculator(ABC):
|
||||
success=False,
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
interest_value=0.0,
|
||||
error_message="组件未启用"
|
||||
error_message="组件未启用",
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
@@ -187,7 +188,7 @@ class BaseInterestCalculator(ABC):
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
interest_value=0.0,
|
||||
error_message=f"计算执行失败: {e!s}",
|
||||
calculation_time=time.time() - start_time
|
||||
calculation_time=time.time() - start_time,
|
||||
)
|
||||
self._update_statistics(result)
|
||||
return result
|
||||
@@ -214,7 +215,9 @@ class BaseInterestCalculator(ABC):
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"{self.__class__.__name__}("
|
||||
f"name={self.component_name}, "
|
||||
f"version={self.component_version}, "
|
||||
f"enabled={self._enabled})")
|
||||
return (
|
||||
f"{self.__class__.__name__}("
|
||||
f"name={self.component_name}, "
|
||||
f"version={self.component_version}, "
|
||||
f"enabled={self._enabled})"
|
||||
)
|
||||
|
||||
@@ -60,7 +60,9 @@ class BasePlugin(PluginBase):
|
||||
if hasattr(component_class, "get_interest_calculator_info"):
|
||||
return component_class.get_interest_calculator_info()
|
||||
else:
|
||||
logger.warning(f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法")
|
||||
logger.warning(
|
||||
f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法"
|
||||
)
|
||||
return None
|
||||
|
||||
elif component_type == ComponentType.PLUS_COMMAND:
|
||||
@@ -96,6 +98,7 @@ class BasePlugin(PluginBase):
|
||||
对应类型的ComponentInfo对象
|
||||
"""
|
||||
return cls._get_component_info_from_class(component_class, component_type)
|
||||
|
||||
@abstractmethod
|
||||
def get_plugin_components(
|
||||
self,
|
||||
|
||||
@@ -7,6 +7,7 @@ class PluginMetadata:
|
||||
"""
|
||||
插件元数据,用于存储插件的开发者信息和用户帮助信息。
|
||||
"""
|
||||
|
||||
name: str # 插件名称 (供用户查看)
|
||||
description: str # 插件功能描述
|
||||
usage: str # 插件使用方法
|
||||
|
||||
@@ -319,7 +319,9 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def _register_interest_calculator_component(
|
||||
self, interest_calculator_info: "InterestCalculatorInfo", interest_calculator_class: type["BaseInterestCalculator"]
|
||||
self,
|
||||
interest_calculator_info: "InterestCalculatorInfo",
|
||||
interest_calculator_class: type["BaseInterestCalculator"],
|
||||
) -> bool:
|
||||
"""注册InterestCalculator组件到特定注册表"""
|
||||
calculator_name = interest_calculator_info.name
|
||||
@@ -327,7 +329,9 @@ class ComponentRegistry:
|
||||
if not calculator_name:
|
||||
logger.error(f"InterestCalculator组件 {interest_calculator_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(interest_calculator_info, InterestCalculatorInfo) or not issubclass(interest_calculator_class, BaseInterestCalculator):
|
||||
if not isinstance(interest_calculator_info, InterestCalculatorInfo) or not issubclass(
|
||||
interest_calculator_class, BaseInterestCalculator
|
||||
):
|
||||
logger.error(f"注册失败: {calculator_name} 不是有效的InterestCalculator")
|
||||
return False
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ class ToolExecutor:
|
||||
"""异步初始化log_prefix和chat_stream"""
|
||||
if not self._log_prefix_initialized:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
|
||||
stream_name = await get_chat_manager().get_stream_name(self.chat_id)
|
||||
self.log_prefix = f"[{stream_name or self.chat_id}]"
|
||||
@@ -283,6 +284,7 @@ class ToolExecutor:
|
||||
# 检查是否是MCP工具
|
||||
try:
|
||||
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
|
||||
|
||||
if function_name in mcp_tool_provider.mcp_tools:
|
||||
logger.info(f"{self.log_prefix}执行MCP工具: {function_name}")
|
||||
result = await mcp_tool_provider.call_mcp_tool(function_name, function_args)
|
||||
|
||||
@@ -8,8 +8,5 @@ __plugin_meta__ = PluginMetadata(
|
||||
author="MoFox",
|
||||
keywords=["chatter", "affinity", "conversation"],
|
||||
categories=["Chat", "AI"],
|
||||
extra={
|
||||
"is_built_in": True
|
||||
}
|
||||
extra={"is_built_in": True},
|
||||
)
|
||||
|
||||
|
||||
@@ -149,7 +149,6 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
return self.planner.get_mood_stats()
|
||||
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计信息"""
|
||||
self.stats = {
|
||||
|
||||
@@ -111,9 +111,11 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
+ mentioned_score * self.score_weights["mentioned"]
|
||||
)
|
||||
|
||||
logger.debug(f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + "
|
||||
f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
|
||||
f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {total_score:.3f}")
|
||||
logger.debug(
|
||||
f"[Affinity兴趣计算] 综合得分计算: {interest_match_score:.3f}*{self.score_weights['interest_match']} + "
|
||||
f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
|
||||
f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {total_score:.3f}"
|
||||
)
|
||||
|
||||
# 5. 考虑连续不回复的概率提升
|
||||
adjusted_score = self._apply_no_reply_boost(total_score)
|
||||
@@ -135,8 +137,10 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
|
||||
calculation_time = time.time() - start_time
|
||||
|
||||
logger.debug(f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} "
|
||||
f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})")
|
||||
logger.debug(
|
||||
f"Affinity兴趣值计算完成 - 消息 {message_id}: {adjusted_score:.3f} "
|
||||
f"(匹配:{interest_match_score:.2f}, 关系:{relationship_score:.2f}, 提及:{mentioned_score:.2f})"
|
||||
)
|
||||
|
||||
return InterestCalculationResult(
|
||||
success=True,
|
||||
@@ -145,16 +149,13 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
||||
should_take_action=should_take_action,
|
||||
should_reply=should_reply,
|
||||
should_act=should_take_action,
|
||||
calculation_time=calculation_time
|
||||
calculation_time=calculation_time,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Affinity兴趣值计算失败: {e}", exc_info=True)
|
||||
return InterestCalculationResult(
|
||||
success=False,
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
interest_value=0.0,
|
||||
error_message=str(e)
|
||||
success=False, message_id=getattr(message, "message_id", ""), interest_value=0.0, error_message=str(e)
|
||||
)
|
||||
|
||||
async def _calculate_interest_match_score(self, content: str, keywords: list[str] | None = None) -> float:
|
||||
|
||||
@@ -405,7 +405,6 @@ class ChatterPlanExecutor:
|
||||
# 移除执行时间列表以避免返回过大数据
|
||||
stats.pop("execution_times", None)
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
def reset_stats(self):
|
||||
@@ -434,12 +433,12 @@ class ChatterPlanExecutor:
|
||||
for i, time_val in enumerate(recent_times)
|
||||
]
|
||||
|
||||
|
||||
async def _flush_action_manager_batch_storage(self, plan: Plan):
|
||||
"""使用 action_manager 的批量存储功能存储所有待处理的动作"""
|
||||
try:
|
||||
# 通过 chat_id 获取真实的 chat_stream 对象
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = await chat_manager.get_stream(plan.chat_id)
|
||||
|
||||
@@ -455,4 +454,3 @@ class ChatterPlanExecutor:
|
||||
logger.error(f"批量存储动作记录时发生错误: {e}")
|
||||
# 确保在出错时也禁用批量存储模式
|
||||
self.action_manager.disable_batch_storage()
|
||||
|
||||
|
||||
@@ -64,7 +64,6 @@ class ChatterPlanFilter:
|
||||
|
||||
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
if llm_content:
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"LLM规划器原始响应:{llm_content}")
|
||||
|
||||
@@ -132,7 +132,6 @@ class ChatterActionPlanner:
|
||||
if message_should_act:
|
||||
aggregate_should_act = True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"处理消息 {message.message_id} 失败: {e}")
|
||||
message.interest_value = 0.0
|
||||
|
||||
@@ -39,6 +39,7 @@ class AffinityChatterPlugin(BasePlugin):
|
||||
try:
|
||||
# 延迟导入 AffinityChatter
|
||||
from .affinity_chatter import AffinityChatter
|
||||
|
||||
components.append((AffinityChatter.get_chatter_info(), AffinityChatter))
|
||||
except Exception as e:
|
||||
logger.error(f"加载 AffinityChatter 时出错: {e}")
|
||||
@@ -46,9 +47,9 @@ class AffinityChatterPlugin(BasePlugin):
|
||||
try:
|
||||
# 延迟导入 AffinityInterestCalculator
|
||||
from .affinity_interest_calculator import AffinityInterestCalculator
|
||||
|
||||
components.append((AffinityInterestCalculator.get_interest_calculator_info(), AffinityInterestCalculator))
|
||||
except Exception as e:
|
||||
logger.error(f"加载 AffinityInterestCalculator 时出错: {e}")
|
||||
|
||||
return components
|
||||
|
||||
|
||||
@@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata(
|
||||
extra={
|
||||
"is_built_in": True,
|
||||
"plugin_type": "action_provider",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata(
|
||||
extra={
|
||||
"is_built_in": False,
|
||||
"plugin_type": "social",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -12,5 +12,5 @@ __plugin_meta__ = PluginMetadata(
|
||||
categories=["protocol"],
|
||||
extra={
|
||||
"is_built_in": False,
|
||||
}
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -13,8 +13,8 @@ def create_router(plugin_config: dict):
|
||||
"""创建路由器实例"""
|
||||
global router
|
||||
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq")
|
||||
host = os.getenv("HOST","127.0.0.1")
|
||||
port = os.getenv("PORT","8000")
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port = os.getenv("PORT", "8000")
|
||||
logger.debug(f"初始化MaiBot连接,使用地址:{host}:{port}")
|
||||
route_config = RouteConfig(
|
||||
route_config={
|
||||
|
||||
@@ -12,5 +12,5 @@ __plugin_meta__ = PluginMetadata(
|
||||
extra={
|
||||
"is_built_in": True,
|
||||
"plugin_type": "permission",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata(
|
||||
extra={
|
||||
"is_built_in": True,
|
||||
"plugin_type": "plugin_management",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -8,10 +8,7 @@ __plugin_meta__ = PluginMetadata(
|
||||
author="MoFox-Studio",
|
||||
license="GPL-v3.0-or-later",
|
||||
repository_url="https://github.com/MoFox-Studio",
|
||||
keywords=["主动思考","自己发消息"],
|
||||
keywords=["主动思考", "自己发消息"],
|
||||
categories=["Chat", "Integration"],
|
||||
extra={
|
||||
"is_built_in": True,
|
||||
"plugin_type": "functional"
|
||||
}
|
||||
extra={"is_built_in": True, "plugin_type": "functional"},
|
||||
)
|
||||
|
||||
@@ -63,7 +63,9 @@ class ColdStartTask(AsyncTask):
|
||||
logger.info(f"【冷启动】发现全新用户 {chat_id},准备发起第一次问候。")
|
||||
elif stream.last_active_time < self.bot_start_time:
|
||||
should_wake_up = True
|
||||
logger.info(f"【冷启动】发现沉睡的聊天流 {chat_id} (最后活跃于 {datetime.fromtimestamp(stream.last_active_time)}),准备唤醒。")
|
||||
logger.info(
|
||||
f"【冷启动】发现沉睡的聊天流 {chat_id} (最后活跃于 {datetime.fromtimestamp(stream.last_active_time)}),准备唤醒。"
|
||||
)
|
||||
|
||||
if should_wake_up:
|
||||
person_id = person_api.get_person_id(platform, user_id)
|
||||
@@ -166,7 +168,9 @@ class ProactiveThinkingTask(AsyncTask):
|
||||
continue
|
||||
|
||||
# 检查冷却时间
|
||||
recent_messages = await message_api.get_recent_messages(chat_id=stream.stream_id, limit=1,limit_mode="latest")
|
||||
recent_messages = await message_api.get_recent_messages(
|
||||
chat_id=stream.stream_id, limit=1, limit_mode="latest"
|
||||
)
|
||||
last_message_time = recent_messages[0]["time"] if recent_messages else stream.create_time
|
||||
time_since_last_active = time.time() - last_message_time
|
||||
if time_since_last_active > next_interval:
|
||||
@@ -209,7 +213,7 @@ class ProactiveThinkingTask(AsyncTask):
|
||||
logger.info("日常唤醒任务被正常取消。")
|
||||
break
|
||||
except Exception as e:
|
||||
traceback.print_exc() # 打印完整的堆栈跟踪
|
||||
traceback.print_exc() # 打印完整的堆栈跟踪
|
||||
logger.error(f"【日常唤醒】任务出现错误,将在60秒后重试: {e}", exc_info=True)
|
||||
await asyncio.sleep(60)
|
||||
|
||||
|
||||
@@ -143,14 +143,16 @@ class ProactiveThinkerExecutor:
|
||||
else "今天没有日程安排。"
|
||||
)
|
||||
|
||||
recent_messages = await message_api.get_recent_messages(stream.stream_id,limit=50,limit_mode="latest",hours=12)
|
||||
recent_messages = await message_api.get_recent_messages(
|
||||
stream.stream_id, limit=50, limit_mode="latest", hours=12
|
||||
)
|
||||
recent_chat_history = (
|
||||
await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无"
|
||||
)
|
||||
|
||||
action_history_list = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=stream.stream_id,
|
||||
timestamp_start=time.time() - 3600 * 24, #过去24小时
|
||||
timestamp_start=time.time() - 3600 * 24, # 过去24小时
|
||||
timestamp_end=time.time(),
|
||||
limit=7,
|
||||
)
|
||||
@@ -195,11 +197,9 @@ class ProactiveThinkerExecutor:
|
||||
person_id = person_api.get_person_id(user_info.platform, int(user_info.user_id))
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_info = await person_info_manager.get_values(person_id, ["user_id", "platform", "person_name"])
|
||||
cross_context_block = await Prompt.build_cross_context(
|
||||
stream.stream_id, "s4u", person_info
|
||||
)
|
||||
cross_context_block = await Prompt.build_cross_context(stream.stream_id, "s4u", person_info)
|
||||
|
||||
# 获取关系信息
|
||||
# 获取关系信息
|
||||
short_impression = await person_info_manager.get_value(person_id, "short_impression") or "无"
|
||||
impression = await person_info_manager.get_value(person_id, "impression") or "无"
|
||||
attitude = await person_info_manager.get_value(person_id, "attitude") or 50
|
||||
|
||||
@@ -10,8 +10,5 @@ __plugin_meta__ = PluginMetadata(
|
||||
repository_url="https://github.com/MoFox-Studio",
|
||||
keywords=["emoji", "reaction", "like", "表情", "回应", "点赞"],
|
||||
categories=["Chat", "Integration"],
|
||||
extra={
|
||||
"is_built_in": "true",
|
||||
"plugin_type": "functional"
|
||||
}
|
||||
extra={"is_built_in": "true", "plugin_type": "functional"},
|
||||
)
|
||||
|
||||
@@ -548,7 +548,7 @@ class SetEmojiLikePlugin(BasePlugin):
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "components": "插件组件"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: ClassVar[dict ]= {
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="set_emoji_like", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
|
||||
@@ -13,5 +13,5 @@ __plugin_meta__ = PluginMetadata(
|
||||
extra={
|
||||
"is_built_in": True,
|
||||
"plugin_type": "audio_processor",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -12,5 +12,5 @@ __plugin_meta__ = PluginMetadata(
|
||||
categories=["Tools"],
|
||||
extra={
|
||||
"is_built_in": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -43,7 +43,9 @@ class SearXNGSearchEngine(BaseSearchEngine):
|
||||
|
||||
api_keys = config_api.get_global_config("web_search.searxng_api_keys", None)
|
||||
if isinstance(api_keys, list):
|
||||
self.api_keys: list[str | None] = [k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys]
|
||||
self.api_keys: list[str | None] = [
|
||||
k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys
|
||||
]
|
||||
else:
|
||||
self.api_keys = []
|
||||
|
||||
@@ -51,9 +53,7 @@ class SearXNGSearchEngine(BaseSearchEngine):
|
||||
if self.api_keys and len(self.api_keys) < len(self.instances):
|
||||
self.api_keys.extend([None] * (len(self.instances) - len(self.api_keys)))
|
||||
|
||||
logger.debug(
|
||||
f"SearXNG 引擎配置: instances={self.instances}, api_keys={'yes' if any(self.api_keys) else 'no'}"
|
||||
)
|
||||
logger.debug(f"SearXNG 引擎配置: instances={self.instances}, api_keys={'yes' if any(self.api_keys) else 'no'}")
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return bool(self.instances)
|
||||
|
||||
Reference in New Issue
Block a user