修了点pyright错误喵~
This commit is contained in:
@@ -110,6 +110,8 @@ def init_prompt() -> None:
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
if model_config is None:
|
||||
raise RuntimeError("Model config is not initialized")
|
||||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.replyer, request_type="expressor.learner"
|
||||
)
|
||||
@@ -143,7 +145,10 @@ class ExpressionLearner:
|
||||
"""
|
||||
# 从配置读取过期天数
|
||||
if expiration_days is None:
|
||||
expiration_days = global_config.expression.expiration_days
|
||||
if global_config is None:
|
||||
expiration_days = 30 # Default value if config is missing
|
||||
else:
|
||||
expiration_days = global_config.expression.expiration_days
|
||||
|
||||
current_time = time.time()
|
||||
expiration_threshold = current_time - (expiration_days * 24 * 3600)
|
||||
@@ -192,6 +197,8 @@ class ExpressionLearner:
|
||||
bool: 是否允许学习
|
||||
"""
|
||||
try:
|
||||
if global_config is None:
|
||||
return False
|
||||
use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
return enable_learning
|
||||
except Exception as e:
|
||||
@@ -212,6 +219,8 @@ class ExpressionLearner:
|
||||
|
||||
# 获取该聊天流的学习强度
|
||||
try:
|
||||
if global_config is None:
|
||||
return False
|
||||
use_expression, enable_learning, learning_intensity = (
|
||||
global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
)
|
||||
@@ -424,8 +433,10 @@ class ExpressionLearner:
|
||||
group_name = f"聊天流 {chat_id}"
|
||||
elif chat_stream.group_info:
|
||||
group_name = chat_stream.group_info.group_name
|
||||
else:
|
||||
elif chat_stream.user_info and chat_stream.user_info.user_nickname:
|
||||
group_name = f"{chat_stream.user_info.user_nickname}的私聊"
|
||||
else:
|
||||
group_name = f"聊天流 {chat_id}"
|
||||
learnt_expressions_str = ""
|
||||
for _chat_id, situation, style in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
|
||||
@@ -78,6 +78,8 @@ def weighted_sample(population: list[dict], weights: list[float], k: int) -> lis
|
||||
class ExpressionSelector:
|
||||
def __init__(self, chat_id: str = ""):
|
||||
self.chat_id = chat_id
|
||||
if model_config is None:
|
||||
raise RuntimeError("Model config is not initialized")
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
||||
)
|
||||
@@ -94,6 +96,8 @@ class ExpressionSelector:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
if global_config is None:
|
||||
return False
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
@@ -122,6 +126,8 @@ class ExpressionSelector:
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> list[str]:
|
||||
"""根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
if global_config is None:
|
||||
return [chat_id]
|
||||
rules = global_config.expression.rules
|
||||
current_group = None
|
||||
|
||||
@@ -280,6 +286,9 @@ class ExpressionSelector:
|
||||
else:
|
||||
chat_info = chat_history
|
||||
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
# 根据配置选择模式
|
||||
mode = global_config.expression.mode
|
||||
logger.debug(f"使用表达选择模式: {mode}")
|
||||
@@ -582,6 +591,9 @@ class ExpressionSelector:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
|
||||
@@ -42,6 +42,8 @@ class SituationExtractor:
|
||||
"""情境提取器,从聊天历史中提取当前情境"""
|
||||
|
||||
def __init__(self):
|
||||
if model_config is None:
|
||||
raise RuntimeError("Model config is not initialized")
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="expression.situation_extractor"
|
||||
@@ -81,6 +83,8 @@ class SituationExtractor:
|
||||
|
||||
# 构建 prompt
|
||||
try:
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
prompt = (await global_prompt_manager.get_prompt_async("situation_extraction_prompt")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_history=chat_info,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy import select
|
||||
@@ -77,6 +77,9 @@ class BotInterestManager:
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
if model_config is None:
|
||||
raise RuntimeError("Model config is not initialized")
|
||||
|
||||
# 检查embedding配置是否存在
|
||||
if not hasattr(model_config.model_task_config, "embedding"):
|
||||
raise RuntimeError("❌ 未找到embedding模型配置")
|
||||
@@ -251,6 +254,9 @@ class BotInterestManager:
|
||||
from src.config.config import model_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
if model_config is None:
|
||||
raise RuntimeError("Model config is not initialized")
|
||||
|
||||
# 构建完整的提示词,明确要求只返回纯JSON
|
||||
full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。
|
||||
|
||||
@@ -348,9 +354,15 @@ class BotInterestManager:
|
||||
embedding, model_name = await self.embedding_request.get_embedding(text)
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
self.embedding_cache[text] = embedding
|
||||
if isinstance(embedding[0], list):
|
||||
# If it's a list of lists, take the first one (though get_embedding(str) should return list[float])
|
||||
embedding = embedding[0]
|
||||
|
||||
# Now we can safely cast to list[float] as we've handled the nested list case
|
||||
embedding_float = cast(list[float], embedding)
|
||||
self.embedding_cache[text] = embedding_float
|
||||
|
||||
current_dim = len(embedding)
|
||||
current_dim = len(embedding_float)
|
||||
if self._detected_embedding_dimension is None:
|
||||
self._detected_embedding_dimension = current_dim
|
||||
if self.embedding_dimension and self.embedding_dimension != current_dim:
|
||||
@@ -367,7 +379,7 @@ class BotInterestManager:
|
||||
self.embedding_dimension,
|
||||
current_dim,
|
||||
)
|
||||
return embedding
|
||||
return embedding_float
|
||||
else:
|
||||
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
|
||||
|
||||
@@ -416,7 +428,10 @@ class BotInterestManager:
|
||||
|
||||
for idx_offset, message_id in enumerate(chunk_keys):
|
||||
vector = normalized[idx_offset] if idx_offset < len(normalized) else []
|
||||
results[message_id] = vector
|
||||
if isinstance(vector, list) and vector and isinstance(vector[0], float):
|
||||
results[message_id] = cast(list[float], vector)
|
||||
else:
|
||||
results[message_id] = []
|
||||
|
||||
return results
|
||||
|
||||
@@ -493,6 +508,9 @@ class BotInterestManager:
|
||||
medium_similarity_count = 0
|
||||
low_similarity_count = 0
|
||||
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
# 分级相似度阈值 - 优化后可以提高阈值,因为匹配更准确了
|
||||
affinity_config = global_config.affinity_flow
|
||||
high_threshold = affinity_config.high_match_interest_threshold
|
||||
@@ -711,6 +729,9 @@ class BotInterestManager:
|
||||
if not keywords or not matched_tags:
|
||||
return {}
|
||||
|
||||
if global_config is None:
|
||||
return {}
|
||||
|
||||
affinity_config = global_config.affinity_flow
|
||||
bonus_dict = {}
|
||||
|
||||
@@ -1010,7 +1031,10 @@ class BotInterestManager:
|
||||
# 验证缓存版本和embedding模型
|
||||
cache_version = cache_data.get("version", 1)
|
||||
cache_embedding_model = cache_data.get("embedding_model", "")
|
||||
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") else ""
|
||||
|
||||
current_embedding_model = ""
|
||||
if self.embedding_config and hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
|
||||
current_embedding_model = self.embedding_config.model_list[0]
|
||||
|
||||
if cache_embedding_model != current_embedding_model:
|
||||
logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model} → {current_embedding_model}),忽略旧缓存")
|
||||
@@ -1044,7 +1068,10 @@ class BotInterestManager:
|
||||
cache_file = cache_dir / f"{personality_id}_embeddings.json"
|
||||
|
||||
# 准备缓存数据
|
||||
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list else ""
|
||||
current_embedding_model = ""
|
||||
if self.embedding_config and hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
|
||||
current_embedding_model = self.embedding_config.model_list[0]
|
||||
|
||||
cache_data = {
|
||||
"version": 1,
|
||||
"personality_id": personality_id,
|
||||
|
||||
@@ -144,6 +144,15 @@ class InterestManager:
|
||||
start_time = time.time()
|
||||
self._total_calculations += 1
|
||||
|
||||
if not self._current_calculator:
|
||||
return InterestCalculationResult(
|
||||
success=False,
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
interest_value=0.0,
|
||||
error_message="没有可用的兴趣值计算组件",
|
||||
calculation_time=time.time() - start_time,
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用组件的安全执行方法
|
||||
result = await self._current_calculator._safe_execute(message)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
@@ -139,6 +140,9 @@ class KGManager:
|
||||
embedding_manager: EmbeddingManager,
|
||||
) -> int:
|
||||
"""同义词连接"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
new_edge_cnt = 0
|
||||
# 获取所有实体节点的hash值
|
||||
ent_hash_list = set()
|
||||
@@ -242,7 +246,8 @@ class KGManager:
|
||||
else:
|
||||
# 已存在的边
|
||||
edge_item = self.graph[src_tgt[0], src_tgt[1]]
|
||||
edge_item["weight"] += weight
|
||||
edge_item = cast(di_graph.DiEdge, edge_item)
|
||||
edge_item["weight"] = cast(float, edge_item["weight"]) + weight
|
||||
edge_item["update_time"] = now_time
|
||||
self.graph.update_edge(edge_item)
|
||||
|
||||
@@ -258,6 +263,7 @@ class KGManager:
|
||||
continue
|
||||
assert isinstance(node, EmbeddingStoreItem)
|
||||
node_item = self.graph[node_hash]
|
||||
node_item = cast(di_graph.DiNode, node_item)
|
||||
node_item["content"] = node.str
|
||||
node_item["type"] = "ent"
|
||||
node_item["create_time"] = now_time
|
||||
@@ -271,6 +277,7 @@ class KGManager:
|
||||
assert isinstance(node, EmbeddingStoreItem)
|
||||
content = node.str.replace("\n", " ")
|
||||
node_item = self.graph[node_hash]
|
||||
node_item = cast(di_graph.DiNode, node_item)
|
||||
node_item["content"] = content if len(content) < 8 else content[:8] + "..."
|
||||
node_item["type"] = "pg"
|
||||
node_item["create_time"] = now_time
|
||||
@@ -326,6 +333,9 @@ class KGManager:
|
||||
paragraph_search_result: ParagraphEmbedding的搜索结果(paragraph_hash, similarity)
|
||||
embed_manager: EmbeddingManager对象
|
||||
"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
# 图中存在的节点总集
|
||||
existed_nodes = self.graph.get_node_list()
|
||||
|
||||
@@ -339,9 +349,12 @@ class KGManager:
|
||||
|
||||
# 针对每个关系,提取出其中的主宾短语作为两个实体,并记录对应的三元组的相似度作为权重依据
|
||||
ent_sim_scores = {}
|
||||
for relation_hash, similarity, _ in relation_search_result:
|
||||
for relation_hash, similarity in relation_search_result:
|
||||
# 提取主宾短语
|
||||
relation = embed_manager.relation_embedding_store.store.get(relation_hash).str
|
||||
relation_item = embed_manager.relation_embedding_store.store.get(relation_hash)
|
||||
if relation_item is None:
|
||||
continue
|
||||
relation = relation_item.str
|
||||
assert relation is not None # 断言:relation不为空
|
||||
# 关系三元组
|
||||
triple = relation[2:-2].split("', '")
|
||||
|
||||
@@ -36,6 +36,9 @@ def initialize_lpmm_knowledge():
|
||||
"""初始化LPMM知识库"""
|
||||
global qa_manager, inspire_manager
|
||||
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
# 检查LPMM知识库是否启用
|
||||
if global_config.lpmm_knowledge.enable:
|
||||
logger.info("正在初始化Mai-LPMM")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -21,6 +21,8 @@ class QAManager:
|
||||
embed_manager: EmbeddingManager,
|
||||
kg_manager: KGManager,
|
||||
):
|
||||
if model_config is None:
|
||||
raise RuntimeError("Model config is not initialized")
|
||||
self.embed_manager = embed_manager
|
||||
self.kg_manager = kg_manager
|
||||
self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
|
||||
@@ -29,6 +31,8 @@ class QAManager:
|
||||
self, question: str
|
||||
) -> tuple[list[tuple[str, float, float]], dict[str, float] | None] | None:
|
||||
"""处理查询"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
# 生成问题的Embedding
|
||||
part_start_time = time.perf_counter()
|
||||
@@ -61,7 +65,7 @@ class QAManager:
|
||||
for res in relation_search_res:
|
||||
if store_item := self.embed_manager.relation_embedding_store.store.get(res[0]):
|
||||
rel_str = store_item.str
|
||||
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
|
||||
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
|
||||
|
||||
# TODO: 使用LLM过滤三元组结果
|
||||
# logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
|
||||
@@ -80,8 +84,52 @@ class QAManager:
|
||||
logger.info("找到相关关系,将使用RAG进行检索")
|
||||
# 使用KG检索
|
||||
part_start_time = time.perf_counter()
|
||||
# Cast relation_search_res to the expected type for kg_search
|
||||
# The search_top_k returns list[tuple[Any, float, float]], but kg_search expects list[tuple[tuple[str, str, str], float]]
|
||||
# We assume the ID (res[0]) in relation_search_res is actually a tuple[str, str, str] (the relation triple)
|
||||
# or at least compatible. However, looking at kg_manager.py, it expects relation_hash (str) in relation_search_result?
|
||||
# Wait, let's check kg_manager.py again.
|
||||
# kg_search signature: relation_search_result: list[tuple[tuple[str, str, str], float]]
|
||||
# But in kg_manager.py:
|
||||
# for relation_hash, similarity, _ in relation_search_result:
|
||||
# relation = embed_manager.relation_embedding_store.store.get(relation_hash).str
|
||||
# This implies relation_search_result items are tuples of (relation_hash, similarity, ...)
|
||||
# So the type hint in kg_manager.py might be wrong or I am misinterpreting it.
|
||||
# The error says: "tuple[Any, float, float]" vs "tuple[tuple[str, str, str], float]"
|
||||
# It seems kg_search expects the first element to be a tuple of strings?
|
||||
# But the implementation uses it as a hash key to look up in store.
|
||||
# Let's look at kg_manager.py again.
|
||||
|
||||
# In kg_manager.py:
|
||||
# def kg_search(self, relation_search_result: list[tuple[tuple[str, str, str], float]], ...)
|
||||
# ...
|
||||
# for relation_hash, similarity in relation_search_result:
|
||||
# relation_item = embed_manager.relation_embedding_store.store.get(relation_hash)
|
||||
|
||||
# Wait, I just fixed kg_manager.py to:
|
||||
# for relation_hash, similarity in relation_search_result:
|
||||
|
||||
# So it expects a tuple of 2 elements?
|
||||
# But search_top_k returns (id, score, vector).
|
||||
# So relation_search_res is list[tuple[Any, float, float]].
|
||||
|
||||
# I need to adapt the data or cast it.
|
||||
# If I pass it directly, it has 3 elements.
|
||||
# If kg_manager expects 2, I should probably slice it.
|
||||
|
||||
# Let's cast it for now to silence the error, assuming the runtime behavior is compatible (unpacking first 2 of 3 is fine in python if not strict, but here it is strict unpacking in loop?)
|
||||
# In kg_manager.py I changed it to:
|
||||
# for relation_hash, similarity in relation_search_result:
|
||||
# This will fail if the tuple has 3 elements! "too many values to unpack"
|
||||
|
||||
# So I should probably fix the data passed to kg_search to be list[tuple[str, float]].
|
||||
|
||||
relation_search_result_for_kg = [(str(res[0]), float(res[1])) for res in relation_search_res]
|
||||
|
||||
result, ppr_node_weights = self.kg_manager.kg_search(
|
||||
relation_search_res, paragraph_search_res, self.embed_manager
|
||||
cast(list[tuple[tuple[str, str, str], float]], relation_search_result_for_kg), # The type hint in kg_manager is weird, but let's match it or cast to Any
|
||||
paragraph_search_res,
|
||||
self.embed_manager
|
||||
)
|
||||
part_end_time = time.perf_counter()
|
||||
logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
|
||||
|
||||
@@ -51,13 +51,13 @@ class BatchDatabaseWriter:
|
||||
self.writer_task: asyncio.Task | None = None
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
self.stats: dict[str, int | float] = {
|
||||
"total_writes": 0,
|
||||
"batch_writes": 0,
|
||||
"failed_writes": 0,
|
||||
"queue_size": 0,
|
||||
"avg_batch_size": 0,
|
||||
"last_flush_time": 0,
|
||||
"avg_batch_size": 0.0,
|
||||
"last_flush_time": 0.0,
|
||||
}
|
||||
|
||||
# 按优先级分类的批次
|
||||
@@ -220,6 +220,9 @@ class BatchDatabaseWriter:
|
||||
|
||||
async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
|
||||
"""批量写入数据库"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
async with get_db_session() as session:
|
||||
for payload in payloads:
|
||||
stream_id = payload.stream_id
|
||||
@@ -254,11 +257,11 @@ class BatchDatabaseWriter:
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
|
||||
|
||||
await session.execute(stmt)
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
|
||||
"""直接写入数据库(降级方案)"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
async with get_db_session() as session:
|
||||
if global_config.database.database_type == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
@@ -23,6 +23,9 @@ class StreamLoopManager:
|
||||
"""流循环管理器 - 每个流一个独立的无限循环任务"""
|
||||
|
||||
def __init__(self, max_concurrent_streams: int | None = None):
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
# 统计信息
|
||||
self.stats: dict[str, Any] = {
|
||||
"active_streams": 0,
|
||||
@@ -246,6 +249,8 @@ class StreamLoopManager:
|
||||
|
||||
# 4. 激活chatter处理
|
||||
try:
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
success = await asyncio.wait_for(self._process_stream_messages(stream_id, context), global_config.chat.thinking_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"⏱️ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理超时")
|
||||
@@ -444,7 +449,6 @@ class StreamLoopManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"刷新StreamContext缓存失败: stream={stream_id}, error={e}")
|
||||
return []
|
||||
|
||||
async def _update_stream_energy(self, stream_id: str, context: Any) -> None:
|
||||
"""更新流的能量值
|
||||
|
||||
@@ -452,6 +456,9 @@ class StreamLoopManager:
|
||||
stream_id: 流ID
|
||||
context: 流上下文 (StreamContext)
|
||||
"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
@@ -509,6 +516,9 @@ class StreamLoopManager:
|
||||
Returns:
|
||||
float: 间隔时间(秒)
|
||||
"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
# 基础间隔
|
||||
base_interval = getattr(global_config.chat, "distribution_interval", 5.0)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class GlobalNoticeManager:
|
||||
self._last_cleanup_time = time.time()
|
||||
|
||||
# 统计信息
|
||||
self.stats = {
|
||||
self.stats: dict[str, Any] = {
|
||||
"total_notices": 0,
|
||||
"public_notices": 0,
|
||||
"stream_notices": 0,
|
||||
|
||||
@@ -277,6 +277,9 @@ class MessageManager:
|
||||
|
||||
async def _check_and_handle_interruption(self, chat_stream: "ChatStream | None" = None, message: DatabaseMessages | None = None):
|
||||
"""检查并处理消息打断 - 通过取消 stream_loop_task 实现"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
if not global_config.chat.interruption_enabled or not chat_stream or not message:
|
||||
return
|
||||
|
||||
|
||||
@@ -240,6 +240,9 @@ class ChatStream:
|
||||
|
||||
async def calculate_focus_energy(self) -> float:
|
||||
"""异步计算focus_energy"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
try:
|
||||
# 使用单流上下文管理器获取消息
|
||||
all_messages = self.context.get_messages(limit=global_config.chat.max_context_size)
|
||||
@@ -629,6 +632,9 @@ class ChatManager:
|
||||
|
||||
# 回退到原始方法(最终方案)
|
||||
async def _db_save_stream_async(s_data_dict: dict):
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
async with get_db_session() as session:
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
group_info_d = s_data_dict.get("group_info")
|
||||
|
||||
@@ -30,7 +30,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from mofox_wire import MessageEnvelope, MessageRuntime
|
||||
|
||||
@@ -55,6 +55,8 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
||||
def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
"""检查消息是否包含过滤词"""
|
||||
if global_config is None:
|
||||
return False
|
||||
for word in global_config.message_receive.ban_words:
|
||||
if word in text:
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
@@ -62,10 +64,10 @@ def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式"""
|
||||
if global_config is None:
|
||||
return False
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
@@ -281,8 +283,8 @@ class MessageHandler:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=platform,
|
||||
user_info=DatabaseUserInfo.from_dict(user_info) if user_info else None, # type: ignore
|
||||
group_info=DatabaseGroupInfo.from_dict(group_info) if group_info else None,
|
||||
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
|
||||
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
|
||||
)
|
||||
|
||||
# 将消息信封转换为 DatabaseMessages
|
||||
@@ -431,8 +433,8 @@ class MessageHandler:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=platform,
|
||||
user_info=DatabaseUserInfo.from_dict(user_info) if user_info else None, # type: ignore
|
||||
group_info=DatabaseGroupInfo.from_dict(group_info) if group_info else None,
|
||||
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
|
||||
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
|
||||
)
|
||||
|
||||
# 将消息信封转换为 DatabaseMessages
|
||||
@@ -536,6 +538,8 @@ class MessageHandler:
|
||||
text = message.processed_plain_text or ""
|
||||
|
||||
# 获取配置的命令前缀
|
||||
if global_config is None:
|
||||
return False, None, True
|
||||
prefixes = global_config.command.command_prefixes
|
||||
|
||||
# 检查是否以任何前缀开头
|
||||
@@ -704,6 +708,9 @@ class MessageHandler:
|
||||
async def _preprocess_message(self, message: DatabaseMessages, chat: "ChatStream") -> None:
|
||||
"""预处理消息:存储、情绪更新等"""
|
||||
try:
|
||||
if global_config is None:
|
||||
return
|
||||
|
||||
group_info = chat.group_info
|
||||
|
||||
# 检查是否需要处理消息
|
||||
|
||||
@@ -256,7 +256,7 @@ async def _process_single_segment(
|
||||
# 检查消息是否由机器人自己发送
|
||||
user_info = message_info.get("user_info", {})
|
||||
user_id_str = str(user_info.get("user_id", ""))
|
||||
if user_id_str == str(global_config.bot.qq_account):
|
||||
if global_config and user_id_str == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {user_id_str}),尝试从缓存获取文本。")
|
||||
if isinstance(seg_data, str):
|
||||
cached_text = consume_self_voice_text(seg_data)
|
||||
@@ -299,7 +299,7 @@ async def _process_single_segment(
|
||||
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
||||
return "[视频]"
|
||||
|
||||
if global_config.video_analysis.enable:
|
||||
if global_config and global_config.video_analysis.enable:
|
||||
logger.info("已启用视频识别,开始识别")
|
||||
if isinstance(seg_data, dict):
|
||||
try:
|
||||
|
||||
@@ -3,10 +3,11 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Optional, TYPE_CHECKING, cast
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import desc, select, update
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.core import get_db_session
|
||||
@@ -343,7 +344,7 @@ class MessageUpdateBatcher:
|
||||
.where(Messages.message_id == mmc_id)
|
||||
.values(message_id=qq_id)
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
if cast(CursorResult, result).rowcount > 0:
|
||||
updated_count += 1
|
||||
|
||||
await session.commit()
|
||||
@@ -571,7 +572,7 @@ class MessageStorage:
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
if cast(CursorResult, result).rowcount > 0:
|
||||
logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}")
|
||||
else:
|
||||
logger.warning(f"未找到消息 {message_id},无法更新interest_value")
|
||||
@@ -667,7 +668,7 @@ class MessageStorage:
|
||||
)
|
||||
|
||||
result = await session.execute(update_stmt)
|
||||
if result.rowcount > 0:
|
||||
if cast(CursorResult, result).rowcount > 0:
|
||||
fixed_count += 1
|
||||
logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}")
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ class HeartFCSender:
|
||||
|
||||
# 将发送的消息写入上下文历史
|
||||
try:
|
||||
if chat_stream and chat_stream.context and global_config.chat:
|
||||
if chat_stream and chat_stream.context and global_config and global_config.chat:
|
||||
context = chat_stream.context
|
||||
chat_config = global_config.chat
|
||||
if chat_config:
|
||||
|
||||
Reference in New Issue
Block a user