typing fix

This commit is contained in:
UnCLAS-Prommer
2025-07-17 00:10:41 +08:00
parent 6e838ccc74
commit 1aa2734d62
26 changed files with 329 additions and 293 deletions

27
bot.py
View File

@@ -146,7 +146,7 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str:
if not file_path.exists(): if not file_path.exists():
logger.error(f"{file_type} 文件不存在") logger.error(f"{file_type} 文件不存在")
raise FileNotFoundError(f"{file_type} 文件不存在") raise FileNotFoundError(f"{file_type} 文件不存在")
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
return hashlib.md5(content.encode("utf-8")).hexdigest() return hashlib.md5(content.encode("utf-8")).hexdigest()
@@ -154,21 +154,21 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str:
def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]: def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]:
"""检查协议确认状态 """检查协议确认状态
Returns: Returns:
tuple[bool, bool]: (已确认, 未更新) tuple[bool, bool]: (已确认, 未更新)
""" """
# 检查环境变量确认 # 检查环境变量确认
if file_hash == os.getenv(env_var): if file_hash == os.getenv(env_var):
return True, False return True, False
# 检查确认文件 # 检查确认文件
if confirm_file.exists(): if confirm_file.exists():
with open(confirm_file, "r", encoding="utf-8") as f: with open(confirm_file, "r", encoding="utf-8") as f:
confirmed_content = f.read() confirmed_content = f.read()
if file_hash == confirmed_content: if file_hash == confirmed_content:
return True, False return True, False
return False, True return False, True
@@ -178,7 +178,7 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
confirm_logger.critical( confirm_logger.critical(
f'输入"同意""confirmed"或设置环境变量"EULA_AGREE={eula_hash}""PRIVACY_AGREE={privacy_hash}"继续运行' f'输入"同意""confirmed"或设置环境变量"EULA_AGREE={eula_hash}""PRIVACY_AGREE={privacy_hash}"继续运行'
) )
while True: while True:
user_input = input().strip().lower() user_input = input().strip().lower()
if user_input in ["同意", "confirmed"]: if user_input in ["同意", "confirmed"]:
@@ -186,13 +186,12 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
confirm_logger.critical('请输入"同意""confirmed"以继续运行') confirm_logger.critical('请输入"同意""confirmed"以继续运行')
def _save_confirmations(eula_updated: bool, privacy_updated: bool, def _save_confirmations(eula_updated: bool, privacy_updated: bool, eula_hash: str, privacy_hash: str) -> None:
eula_hash: str, privacy_hash: str) -> None:
"""保存用户确认结果""" """保存用户确认结果"""
if eula_updated: if eula_updated:
logger.info(f"更新EULA确认文件{eula_hash}") logger.info(f"更新EULA确认文件{eula_hash}")
Path("eula.confirmed").write_text(eula_hash, encoding="utf-8") Path("eula.confirmed").write_text(eula_hash, encoding="utf-8")
if privacy_updated: if privacy_updated:
logger.info(f"更新隐私条款确认文件{privacy_hash}") logger.info(f"更新隐私条款确认文件{privacy_hash}")
Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8") Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8")
@@ -203,19 +202,17 @@ def check_eula():
# 计算文件哈希值 # 计算文件哈希值
eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md") eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md")
privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md") privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md")
# 检查确认状态 # 检查确认状态
eula_confirmed, eula_updated = _check_agreement_status( eula_confirmed, eula_updated = _check_agreement_status(eula_hash, Path("eula.confirmed"), "EULA_AGREE")
eula_hash, Path("eula.confirmed"), "EULA_AGREE"
)
privacy_confirmed, privacy_updated = _check_agreement_status( privacy_confirmed, privacy_updated = _check_agreement_status(
privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE" privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE"
) )
# 早期返回:如果都已确认且未更新 # 早期返回:如果都已确认且未更新
if eula_confirmed and privacy_confirmed: if eula_confirmed and privacy_confirmed:
return return
# 如果有更新,需要重新确认 # 如果有更新,需要重新确认
if eula_updated or privacy_updated: if eula_updated or privacy_updated:
_prompt_user_confirmation(eula_hash, privacy_hash) _prompt_user_confirmation(eula_hash, privacy_hash)
@@ -225,7 +222,7 @@ def check_eula():
def raw_main(): def raw_main():
# 利用 TZ 环境变量设定程序工作的时区 # 利用 TZ 环境变量设定程序工作的时区
if platform.system().lower() != "windows": if platform.system().lower() != "windows":
time.tzset() time.tzset() # type: ignore
check_eula() check_eula()
logger.info("检查EULA和隐私条款完成") logger.info("检查EULA和隐私条款完成")

View File

@@ -107,11 +107,12 @@ class ExpressionLearner:
last_active_time = expr.get("last_active_time", time.time()) last_active_time = expr.get("last_active_time", time.time())
# 查重同chat_id+type+situation+style # 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
query = Expression.select().where( query = Expression.select().where(
(Expression.chat_id == chat_id) & (Expression.chat_id == chat_id)
(Expression.type == type_str) & & (Expression.type == type_str)
(Expression.situation == situation) & & (Expression.situation == situation)
(Expression.style == style_val) & (Expression.style == style_val)
) )
if query.exists(): if query.exists():
expr_obj = query.get() expr_obj = query.get()
@@ -125,7 +126,7 @@ class ExpressionLearner:
count=count, count=count,
last_active_time=last_active_time, last_active_time=last_active_time,
chat_id=chat_id, chat_id=chat_id,
type=type_str type=type_str,
) )
logger.info(f"已迁移 {expr_file} 到数据库") logger.info(f"已迁移 {expr_file} 到数据库")
except Exception as e: except Exception as e:
@@ -149,24 +150,28 @@ class ExpressionLearner:
# 直接从数据库查询 # 直接从数据库查询
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style")) style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
for expr in style_query: for expr in style_query:
learnt_style_expressions.append({ learnt_style_expressions.append(
"situation": expr.situation, {
"style": expr.style, "situation": expr.situation,
"count": expr.count, "style": expr.style,
"last_active_time": expr.last_active_time, "count": expr.count,
"source_id": chat_id, "last_active_time": expr.last_active_time,
"type": "style" "source_id": chat_id,
}) "type": "style",
}
)
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar")) grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
for expr in grammar_query: for expr in grammar_query:
learnt_grammar_expressions.append({ learnt_grammar_expressions.append(
"situation": expr.situation, {
"style": expr.style, "situation": expr.situation,
"count": expr.count, "style": expr.style,
"last_active_time": expr.last_active_time, "count": expr.count,
"source_id": chat_id, "last_active_time": expr.last_active_time,
"type": "grammar" "source_id": chat_id,
}) "type": "grammar",
}
)
return learnt_style_expressions, learnt_grammar_expressions return learnt_style_expressions, learnt_grammar_expressions
def is_similar(self, s1: str, s2: str) -> bool: def is_similar(self, s1: str, s2: str) -> bool:
@@ -213,14 +218,16 @@ class ExpressionLearner:
logger.error(f"全局衰减{type}表达方式失败: {e}") logger.error(f"全局衰减{type}表达方式失败: {e}")
continue continue
learnt_style: Optional[List[Tuple[str, str, str]]] = []
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
# 学习新的表达方式(这里会进行局部衰减) # 学习新的表达方式(这里会进行局部衰减)
for _ in range(3): for _ in range(3):
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25) learnt_style = await self.learn_and_store(type="style", num=25)
if not learnt_style: if not learnt_style:
return [], [] return [], []
for _ in range(1): for _ in range(1):
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10) learnt_grammar = await self.learn_and_store(type="grammar", num=10)
if not learnt_grammar: if not learnt_grammar:
return [], [] return [], []
@@ -321,10 +328,10 @@ class ExpressionLearner:
for new_expr in expr_list: for new_expr in expr_list:
# 查找是否已存在相似表达方式 # 查找是否已存在相似表达方式
query = Expression.select().where( query = Expression.select().where(
(Expression.chat_id == chat_id) & (Expression.chat_id == chat_id)
(Expression.type == type) & & (Expression.type == type)
(Expression.situation == new_expr["situation"]) & & (Expression.situation == new_expr["situation"])
(Expression.style == new_expr["style"]) & (Expression.style == new_expr["style"])
) )
if query.exists(): if query.exists():
expr_obj = query.get() expr_obj = query.get()
@@ -342,13 +349,17 @@ class ExpressionLearner:
count=1, count=1,
last_active_time=current_time, last_active_time=current_time,
chat_id=chat_id, chat_id=chat_id,
type=type type=type,
) )
# 限制最大数量 # 限制最大数量
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc())) exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == type))
.order_by(Expression.count.asc())
)
if len(exprs) > MAX_EXPRESSION_COUNT: if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式 # 删除count最小的多余表达方式
for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]: for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance() expr.delete_instance()
return learnt_expressions return learnt_expressions

View File

@@ -9,51 +9,49 @@ from src.common.logger import get_logger
import traceback import traceback
from src.config.config import global_config from src.config.config import global_config
from src.common.database.database_model import Memory # Peewee Models导入 from src.common.database.database_model import Memory # Peewee Models导入
logger = get_logger(__name__) logger = get_logger(__name__)
class MemoryItem: class MemoryItem:
def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]): def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
self.memory_id = memory_id self.memory_id = memory_id
self.chat_id = chat_id self.chat_id = chat_id
self.memory_text:str = memory_text self.memory_text: str = memory_text
self.keywords:list[str] = keywords self.keywords: list[str] = keywords
self.create_time:float = time.time() self.create_time: float = time.time()
self.last_view_time:float = time.time() self.last_view_time: float = time.time()
class MemoryManager: class MemoryManager:
def __init__(self): def __init__(self):
# self.memory_items:list[MemoryItem] = [] # self.memory_items:list[MemoryItem] = []
pass pass
class InstantMemory: class InstantMemory:
def __init__(self,chat_id): def __init__(self, chat_id):
self.chat_id = chat_id self.chat_id = chat_id
self.last_view_time = time.time() self.last_view_time = time.time()
self.summary_model = LLMRequest( self.summary_model = LLMRequest(
model=global_config.model.memory, model=global_config.model.memory,
temperature=0.5, temperature=0.5,
request_type="memory.summary", request_type="memory.summary",
) )
async def if_need_build(self,text): async def if_need_build(self, text):
prompt = f""" prompt = f"""
请判断以下内容中是否有值得记忆的信息如果有请输出1否则输出0 请判断以下内容中是否有值得记忆的信息如果有请输出1否则输出0
{text} {text}
请只输出1或0就好 请只输出1或0就好
""" """
try: try:
response,_ = await self.summary_model.generate_response_async(prompt) response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt) print(prompt)
print(response) print(response)
if "1" in response: if "1" in response:
return True return True
else: else:
@@ -61,8 +59,8 @@ class InstantMemory:
except Exception as e: except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}") logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False return False
async def build_memory(self,text): async def build_memory(self, text):
prompt = f""" prompt = f"""
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出 以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
{text} {text}
@@ -73,7 +71,7 @@ class InstantMemory:
}} }}
""" """
try: try:
response,_ = await self.summary_model.generate_response_async(prompt) response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt) print(prompt)
print(response) print(response)
if not response: if not response:
@@ -81,53 +79,53 @@ class InstantMemory:
try: try:
repaired = repair_json(response) repaired = repair_json(response)
result = json.loads(repaired) result = json.loads(repaired)
memory_text = result.get('memory_text', '') memory_text = result.get("memory_text", "")
keywords = result.get('keywords', '') keywords = result.get("keywords", "")
if isinstance(keywords, str): if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()] keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list): elif isinstance(keywords, list):
keywords_list = keywords keywords_list = keywords
else: else:
keywords_list = [] keywords_list = []
return {'memory_text': memory_text, 'keywords': keywords_list} return {"memory_text": memory_text, "keywords": keywords_list}
except Exception as parse_e: except Exception as parse_e:
logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}") logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}") logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
return None return None
async def create_and_store_memory(self,text): async def create_and_store_memory(self, text):
if_need = await self.if_need_build(text) if_need = await self.if_need_build(text)
if if_need: if if_need:
logger.info(f"需要记忆:{text}") logger.info(f"需要记忆:{text}")
memory = await self.build_memory(text) memory = await self.build_memory(text)
if memory and memory.get('memory_text'): if memory and memory.get("memory_text"):
memory_id = f"{self.chat_id}_{time.time()}" memory_id = f"{self.chat_id}_{time.time()}"
memory_item = MemoryItem( memory_item = MemoryItem(
memory_id=memory_id, memory_id=memory_id,
chat_id=self.chat_id, chat_id=self.chat_id,
memory_text=memory['memory_text'], memory_text=memory["memory_text"],
keywords=memory.get('keywords', []) keywords=memory.get("keywords", []),
) )
await self.store_memory(memory_item) await self.store_memory(memory_item)
else: else:
logger.info(f"不需要记忆:{text}") logger.info(f"不需要记忆:{text}")
async def store_memory(self,memory_item:MemoryItem): async def store_memory(self, memory_item: MemoryItem):
memory = Memory( memory = Memory(
memory_id=memory_item.memory_id, memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id, chat_id=memory_item.chat_id,
memory_text=memory_item.memory_text, memory_text=memory_item.memory_text,
keywords=memory_item.keywords, keywords=memory_item.keywords,
create_time=memory_item.create_time, create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time last_view_time=memory_item.last_view_time,
) )
memory.save() memory.save()
async def get_memory(self,target:str): async def get_memory(self, target: str):
from json_repair import repair_json from json_repair import repair_json
prompt = f""" prompt = f"""
请根据以下发言内容,判断是否需要提取记忆 请根据以下发言内容,判断是否需要提取记忆
{target} {target}
@@ -144,7 +142,7 @@ class InstantMemory:
请只输出json格式不要输出其他多余内容 请只输出json格式不要输出其他多余内容
""" """
try: try:
response,_ = await self.summary_model.generate_response_async(prompt) response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt) print(prompt)
print(response) print(response)
if not response: if not response:
@@ -153,15 +151,15 @@ class InstantMemory:
repaired = repair_json(response) repaired = repair_json(response)
result = json.loads(repaired) result = json.loads(repaired)
# 解析keywords # 解析keywords
keywords = result.get('keywords', '') keywords = result.get("keywords", "")
if isinstance(keywords, str): if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()] keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list): elif isinstance(keywords, list):
keywords_list = keywords keywords_list = keywords
else: else:
keywords_list = [] keywords_list = []
# 解析time为时间段 # 解析time为时间段
time_str = result.get('time', '').strip() time_str = result.get("time", "").strip()
start_time, end_time = self._parse_time_range(time_str) start_time, end_time = self._parse_time_range(time_str)
logger.info(f"start_time: {start_time}, end_time: {end_time}") logger.info(f"start_time: {start_time}, end_time: {end_time}")
# 检索包含关键词的记忆 # 检索包含关键词的记忆
@@ -170,16 +168,15 @@ class InstantMemory:
start_ts = start_time.timestamp() start_ts = start_time.timestamp()
end_ts = end_time.timestamp() end_ts = end_time.timestamp()
query = Memory.select().where( query = Memory.select().where(
(Memory.chat_id == self.chat_id) & (Memory.chat_id == self.chat_id)
(Memory.create_time >= start_ts) & & (Memory.create_time >= start_ts) # type: ignore
(Memory.create_time < end_ts) & (Memory.create_time < end_ts) # type: ignore
) )
else: else:
query = Memory.select().where(Memory.chat_id == self.chat_id) query = Memory.select().where(Memory.chat_id == self.chat_id)
for mem in query: for mem in query:
#对每条记忆 # 对每条记忆
mem_keywords = mem.keywords or [] mem_keywords = mem.keywords or []
parsed = ast.literal_eval(mem_keywords) parsed = ast.literal_eval(mem_keywords)
if isinstance(parsed, list): if isinstance(parsed, list):
@@ -212,6 +209,7 @@ class InstantMemory:
- 空字符串:返回(None, None) - 空字符串:返回(None, None)
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
now = datetime.now() now = datetime.now()
if not time_str: if not time_str:
return 0, now return 0, now
@@ -251,8 +249,8 @@ class InstantMemory:
if m: if m:
months = int(m.group(1)) months = int(m.group(1))
# 近似每月30天 # 近似每月30天
start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0) start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1) end = start + timedelta(days=1)
return start, end return start, end
# 其他无法解析 # 其他无法解析
return 0, now return 0, now

View File

@@ -30,7 +30,7 @@ class ChatMessageContext:
def get_template_name(self) -> Optional[str]: def get_template_name(self) -> Optional[str]:
"""获取模板名称""" """获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default: if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name return self.message.message_info.template_info.template_name # type: ignore
return None return None
def get_last_message(self) -> "MessageRecv": def get_last_message(self) -> "MessageRecv":

View File

@@ -107,9 +107,9 @@ class MessageRecv(Message):
self.is_picid = False self.is_picid = False
self.has_picid = False self.has_picid = False
self.is_mentioned = None self.is_mentioned = None
self.is_command = False self.is_command = False
self.priority_mode = "interest" self.priority_mode = "interest"
self.priority_info = None self.priority_info = None
self.interest_value: float = None # type: ignore self.interest_value: float = None # type: ignore
@@ -181,6 +181,7 @@ class MessageRecv(Message):
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]" return f"[处理失败的{segment.type}消息]"
@dataclass @dataclass
class MessageRecvS4U(MessageRecv): class MessageRecvS4U(MessageRecv):
def __init__(self, message_dict: dict[str, Any]): def __init__(self, message_dict: dict[str, Any]):
@@ -194,10 +195,10 @@ class MessageRecvS4U(MessageRecv):
self.superchat_price = None self.superchat_price = None
self.superchat_message_text = None self.superchat_message_text = None
self.is_screen = False self.is_screen = False
async def process(self) -> None: async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment) self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str: async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段 """处理单个消息段
@@ -252,7 +253,7 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "gift": elif segment.type == "gift":
self.is_gift = True self.is_gift = True
# 解析gift_info格式为"名称:数量" # 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1) name, count = segment.data.split(":", 1) # type: ignore
self.gift_info = segment.data self.gift_info = segment.data
self.gift_name = name.strip() self.gift_name = name.strip()
self.gift_count = int(count.strip()) self.gift_count = int(count.strip())
@@ -260,13 +261,15 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "superchat": elif segment.type == "superchat":
self.is_superchat = True self.is_superchat = True
self.superchat_info = segment.data self.superchat_info = segment.data
price,message_text = segment.data.split(":", 1) price, message_text = segment.data.split(":", 1) # type: ignore
self.superchat_price = price.strip() self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip() self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text) self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)" self.processed_plain_text += (
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
)
return self.processed_plain_text return self.processed_plain_text
elif segment.type == "screen": elif segment.type == "screen":
self.is_screen = True self.is_screen = True

View File

@@ -80,7 +80,7 @@ class ActionManager:
chat_stream: ChatStream, chat_stream: ChatStream,
log_prefix: str, log_prefix: str,
shutting_down: bool = False, shutting_down: bool = False,
action_message: dict = None, action_message: Optional[dict] = None,
) -> Optional[BaseAction]: ) -> Optional[BaseAction]:
""" """
创建动作处理器实例 创建动作处理器实例

View File

@@ -252,7 +252,7 @@ def _build_readable_messages_internal(
pic_id_mapping: Optional[Dict[str, str]] = None, pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1, pic_counter: int = 1,
show_pic: bool = True, show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None, message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
""" """
内部辅助函数,构建可读消息字符串和原始消息详情列表。 内部辅助函数,构建可读消息字符串和原始消息详情列表。
@@ -615,7 +615,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
for action in actions: for action in actions:
action_time = action.get("time", current_time) action_time = action.get("time", current_time)
action_name = action.get("action_name", "未知动作") action_name = action.get("action_name", "未知动作")
if action_name == "no_action" or action_name == "no_reply": if action_name in ["no_action", "no_reply"]:
continue continue
action_prompt_display = action.get("action_prompt_display", "无具体内容") action_prompt_display = action.get("action_prompt_display", "无具体内容")
@@ -697,7 +697,7 @@ def build_readable_messages(
truncate: bool = False, truncate: bool = False,
show_actions: bool = False, show_actions: bool = False,
show_pic: bool = True, show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None, message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> str: # sourcery skip: extract-method ) -> str: # sourcery skip: extract-method
""" """
将消息列表转换为可读的文本格式。 将消息列表转换为可读的文本格式。

View File

@@ -1211,7 +1211,7 @@ class StatisticOutputTask(AsyncTask):
f.write(html_template) f.write(html_template)
def _generate_focus_tab(self, stat: dict[str, Any]) -> str: def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: for-append-to-extend, list-comprehension, use-any # sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next
"""生成Focus统计独立分页的HTML内容""" """生成Focus统计独立分页的HTML内容"""
# 为每个时间段准备Focus数据 # 为每个时间段准备Focus数据
@@ -1559,6 +1559,7 @@ class StatisticOutputTask(AsyncTask):
""" """
def _generate_versions_tab(self, stat: dict[str, Any]) -> str: def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: use-named-expression, use-next
"""生成版本对比独立分页的HTML内容""" """生成版本对比独立分页的HTML内容"""
# 为每个时间段准备版本对比数据 # 为每个时间段准备版本对比数据
@@ -2306,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask):
# 复用 StatisticOutputTask 的所有方法 # 复用 StatisticOutputTask 的所有方法
def _collect_all_statistics(self, now: datetime): def _collect_all_statistics(self, now: datetime):
return StatisticOutputTask._collect_all_statistics(self, now) return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime): def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
return StatisticOutputTask._statistic_console_output(self, stats, now) return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
def _generate_html_report(self, stats: dict[str, Any], now: datetime): def _generate_html_report(self, stats: dict[str, Any], now: datetime):
return StatisticOutputTask._generate_html_report(self, stats, now) return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
# 其他需要的方法也可以类似复用... # 其他需要的方法也可以类似复用...
@staticmethod @staticmethod
@@ -2324,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask):
return StatisticOutputTask._collect_online_time_for_period(collect_period, now) return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore
def _process_focus_file_data( def _process_focus_file_data(
self, self,
@@ -2336,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask):
collect_period: List[Tuple[str, datetime]], collect_period: List[Tuple[str, datetime]],
file_time: datetime, file_time: datetime,
): ):
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore
def _calculate_focus_averages(self, stats: Dict[str, Any]): def _calculate_focus_averages(self, stats: Dict[str, Any]):
return StatisticOutputTask._calculate_focus_averages(self, stats) return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore
@staticmethod @staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str: def _format_total_stat(stats: Dict[str, Any]) -> str:
@@ -2347,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask):
@staticmethod @staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str: def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_model_classified_stat(stats) return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore
def _format_chat_stat(self, stats: Dict[str, Any]) -> str: def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_chat_stat(self, stats) return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
def _format_focus_stat(self, stats: Dict[str, Any]) -> str: def _format_focus_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_focus_stat(self, stats) return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore
def _generate_chart_data(self, stat: dict[str, Any]) -> dict: def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
return StatisticOutputTask._generate_chart_data(self, stat) return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore
def _generate_chart_tab(self, chart_data: dict) -> str: def _generate_chart_tab(self, chart_data: dict) -> str:
return StatisticOutputTask._generate_chart_tab(self, chart_data) return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore
def _get_chat_display_name_from_id(self, chat_id: str) -> str: def _get_chat_display_name_from_id(self, chat_id: str) -> str:
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore
def _generate_focus_tab(self, stat: dict[str, Any]) -> str: def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_focus_tab(self, stat) return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore
def _generate_versions_tab(self, stat: dict[str, Any]) -> str: def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_versions_tab(self, stat) return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore
def _convert_defaultdict_to_dict(self, data): def _convert_defaultdict_to_dict(self, data):
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore

View File

@@ -2,14 +2,13 @@ import importlib
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Optional from typing import Dict, Optional, Any
from rich.traceback import install from rich.traceback import install
from dataclasses import dataclass from dataclasses import dataclass
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.person_info.person_info import PersonInfoManager, get_person_info_manager
install(extra_lines=3) install(extra_lines=3)
@@ -54,7 +53,7 @@ class WillingInfo:
interested_rate (float): 兴趣度 interested_rate (float): 兴趣度
""" """
message: MessageRecv message: Dict[str, Any] # 原始消息数据
chat: ChatStream chat: ChatStream
person_info_manager: PersonInfoManager person_info_manager: PersonInfoManager
chat_id: str chat_id: str

View File

@@ -65,7 +65,7 @@ class ChatStreams(BaseModel):
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。 # user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
user_cardname = TextField(null=True) user_cardname = TextField(null=True)
class Meta: class Meta: # type: ignore
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# 如果不使用带有数据库实例的 BaseModel或者想覆盖它 # 如果不使用带有数据库实例的 BaseModel或者想覆盖它
# 请取消注释并在下面设置数据库实例: # 请取消注释并在下面设置数据库实例:
@@ -89,7 +89,7 @@ class LLMUsage(BaseModel):
status = TextField() status = TextField()
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引 timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
class Meta: class Meta: # type: ignore
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。 # 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# database = db # database = db
table_name = "llm_usage" table_name = "llm_usage"
@@ -112,7 +112,7 @@ class Emoji(BaseModel):
usage_count = IntegerField(default=0) # 使用次数(被使用的次数) usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
last_used_time = FloatField(null=True) # 上次使用时间 last_used_time = FloatField(null=True) # 上次使用时间
class Meta: class Meta: # type: ignore
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "emoji" table_name = "emoji"
@@ -162,7 +162,8 @@ class Messages(BaseModel):
is_emoji = BooleanField(default=False) is_emoji = BooleanField(default=False)
is_picid = BooleanField(default=False) is_picid = BooleanField(default=False)
is_command = BooleanField(default=False) is_command = BooleanField(default=False)
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "messages" table_name = "messages"
@@ -186,7 +187,7 @@ class ActionRecords(BaseModel):
chat_info_stream_id = TextField() chat_info_stream_id = TextField()
chat_info_platform = TextField() chat_info_platform = TextField()
class Meta: class Meta: # type: ignore
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "action_records" table_name = "action_records"
@@ -206,7 +207,7 @@ class Images(BaseModel):
type = TextField() # 图像类型,例如 "emoji" type = TextField() # 图像类型,例如 "emoji"
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理 vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
class Meta: class Meta: # type: ignore
table_name = "images" table_name = "images"
@@ -220,7 +221,7 @@ class ImageDescriptions(BaseModel):
description = TextField() # 图像的描述 description = TextField() # 图像的描述
timestamp = FloatField() # 时间戳 timestamp = FloatField() # 时间戳
class Meta: class Meta: # type: ignore
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "image_descriptions" table_name = "image_descriptions"
@@ -236,7 +237,7 @@ class OnlineTime(BaseModel):
start_timestamp = DateTimeField(default=datetime.datetime.now) start_timestamp = DateTimeField(default=datetime.datetime.now)
end_timestamp = DateTimeField(index=True) end_timestamp = DateTimeField(index=True)
class Meta: class Meta: # type: ignore
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "online_time" table_name = "online_time"
@@ -263,10 +264,11 @@ class PersonInfo(BaseModel):
last_know = FloatField(null=True) # 最后一次印象总结时间 last_know = FloatField(null=True) # 最后一次印象总结时间
attitude = IntegerField(null=True, default=50) # 态度0-100从非常厌恶到十分喜欢 attitude = IntegerField(null=True, default=50) # 态度0-100从非常厌恶到十分喜欢
class Meta: class Meta: # type: ignore
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "person_info" table_name = "person_info"
class Memory(BaseModel): class Memory(BaseModel):
memory_id = TextField(index=True) memory_id = TextField(index=True)
chat_id = TextField(null=True) chat_id = TextField(null=True)
@@ -274,10 +276,11 @@ class Memory(BaseModel):
keywords = TextField(null=True) keywords = TextField(null=True)
create_time = FloatField(null=True) create_time = FloatField(null=True)
last_view_time = FloatField(null=True) last_view_time = FloatField(null=True)
class Meta: class Meta: # type: ignore
table_name = "memory" table_name = "memory"
class Knowledges(BaseModel): class Knowledges(BaseModel):
""" """
用于存储知识库条目的模型。 用于存储知识库条目的模型。
@@ -287,10 +290,11 @@ class Knowledges(BaseModel):
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
# 可以添加其他元数据字段,如 source, create_time 等 # 可以添加其他元数据字段,如 source, create_time 等
class Meta: class Meta: # type: ignore
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "knowledges" table_name = "knowledges"
class Expression(BaseModel): class Expression(BaseModel):
""" """
用于存储表达风格的模型。 用于存储表达风格的模型。
@@ -302,10 +306,11 @@ class Expression(BaseModel):
last_active_time = FloatField() last_active_time = FloatField()
chat_id = TextField(index=True) chat_id = TextField(index=True)
type = TextField() type = TextField()
class Meta: class Meta: # type: ignore
table_name = "expression" table_name = "expression"
class ThinkingLog(BaseModel): class ThinkingLog(BaseModel):
chat_id = TextField(index=True) chat_id = TextField(index=True)
trigger_text = TextField(null=True) trigger_text = TextField(null=True)
@@ -326,7 +331,7 @@ class ThinkingLog(BaseModel):
# And: import datetime # And: import datetime
created_at = DateTimeField(default=datetime.datetime.now) created_at = DateTimeField(default=datetime.datetime.now)
class Meta: class Meta: # type: ignore
table_name = "thinking_logs" table_name = "thinking_logs"
@@ -341,7 +346,7 @@ class GraphNodes(BaseModel):
created_time = FloatField() # 创建时间戳 created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳 last_modified = FloatField() # 最后修改时间戳
class Meta: class Meta: # type: ignore
table_name = "graph_nodes" table_name = "graph_nodes"
@@ -357,7 +362,7 @@ class GraphEdges(BaseModel):
created_time = FloatField() # 创建时间戳 created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳 last_modified = FloatField() # 最后修改时间戳
class Meta: class Meta: # type: ignore
table_name = "graph_edges" table_name = "graph_edges"

View File

@@ -7,13 +7,13 @@ from datetime import datetime
def get_key_comment(toml_table, key): def get_key_comment(toml_table, key):
# 获取key的注释如果有 # 获取key的注释如果有
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'): if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
return toml_table.trivia.comment return toml_table.trivia.comment
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict): if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
item = toml_table.value.get(key) item = toml_table.value.get(key)
if item is not None and hasattr(item, 'trivia'): if item is not None and hasattr(item, "trivia"):
return item.trivia.comment return item.trivia.comment
if hasattr(toml_table, 'keys'): if hasattr(toml_table, "keys"):
for k in toml_table.keys(): for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key: if isinstance(k, KeyType) and k.key == key:
return k.trivia.comment return k.trivia.comment
@@ -36,16 +36,16 @@ def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, log
continue continue
if key not in old: if key not in old:
comment = get_key_comment(new, key) comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else ''}") logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path+[str(key)], new_comments, old_comments, logs) compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
# 删减项 # 删减项
for key in old: for key in old:
if key == "version": if key == "version":
continue continue
if key not in new: if key not in new:
comment = get_key_comment(old, key) comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else ''}") logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
return logs return logs
@@ -95,7 +95,7 @@ def update_config():
if old_version and new_version and old_version == new_version: if old_version and new_version and old_version == new_version:
print(f"检测到版本号相同 (v{old_version}),跳过更新") print(f"检测到版本号相同 (v{old_version}),跳过更新")
# 如果version相同恢复旧配置文件并返回 # 如果version相同恢复旧配置文件并返回
shutil.move(old_backup_path, old_config_path) shutil.move(old_backup_path, old_config_path) # type: ignore
return return
else: else:
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")

View File

@@ -53,13 +53,13 @@ MMC_VERSION = "0.9.0-snapshot.2"
def get_key_comment(toml_table, key): def get_key_comment(toml_table, key):
# 获取key的注释如果有 # 获取key的注释如果有
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'): if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
return toml_table.trivia.comment return toml_table.trivia.comment
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict): if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
item = toml_table.value.get(key) item = toml_table.value.get(key)
if item is not None and hasattr(item, 'trivia'): if item is not None and hasattr(item, "trivia"):
return item.trivia.comment return item.trivia.comment
if hasattr(toml_table, 'keys'): if hasattr(toml_table, "keys"):
for k in toml_table.keys(): for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key: if isinstance(k, KeyType) and k.key == key:
return k.trivia.comment return k.trivia.comment
@@ -78,16 +78,16 @@ def compare_dicts(new, old, path=None, logs=None):
continue continue
if key not in old: if key not in old:
comment = get_key_comment(new, key) comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else ''}") logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path+[str(key)], logs) compare_dicts(new[key], old[key], path + [str(key)], logs)
# 删减项 # 删减项
for key in old: for key in old:
if key == "version": if key == "version":
continue continue
if key not in new: if key not in new:
comment = get_key_comment(old, key) comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else ''}") logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
return logs return logs
@@ -99,6 +99,7 @@ def get_value_by_path(d, path):
return None return None
return d return d
def set_value_by_path(d, path, value): def set_value_by_path(d, path, value):
for k in path[:-1]: for k in path[:-1]:
if k not in d or not isinstance(d[k], dict): if k not in d or not isinstance(d[k], dict):
@@ -106,6 +107,7 @@ def set_value_by_path(d, path, value):
d = d[k] d = d[k]
d[path[-1]] = value d[path[-1]] = value
def compare_default_values(new, old, path=None, logs=None, changes=None): def compare_default_values(new, old, path=None, logs=None, changes=None):
# 递归比较两个dict找出默认值变化项 # 递归比较两个dict找出默认值变化项
if path is None: if path is None:
@@ -119,12 +121,14 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
continue continue
if key in old: if key in old:
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)): if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
compare_default_values(new[key], old[key], path+[str(key)], logs, changes) compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
else: else:
# 只要值发生变化就记录 # 只要值发生变化就记录
if new[key] != old[key]: if new[key] != old[key]:
logs.append(f"默认值变化: {'.'.join(path+[str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}") logs.append(
changes.append((path+[str(key)], old[key], new[key])) f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
)
changes.append((path + [str(key)], old[key], new[key]))
return logs, changes return logs, changes
@@ -148,8 +152,8 @@ def update_config():
return None return None
with open(toml_path, "r", encoding="utf-8") as f: with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f) doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] return doc["inner"]["version"] # type: ignore
return None return None
template_version = get_version_from_toml(template_path) template_version = get_version_from_toml(template_path)
@@ -186,7 +190,9 @@ def update_config():
old_value = get_value_by_path(old_config, path) old_value = get_value_by_path(old_config, path)
if old_value == old_default: if old_value == old_default:
set_value_by_path(old_config, path, new_default) set_value_by_path(old_config, path, new_default)
logger.info(f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}") logger.info(
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
)
else: else:
logger.info("未检测到模板默认值变动") logger.info("未检测到模板默认值变动")
# 保存旧配置的变更(后续合并逻辑会用到 old_config # 保存旧配置的变更(后续合并逻辑会用到 old_config
@@ -229,7 +235,9 @@ def update_config():
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
return return
else: else:
logger.info(f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------") logger.info(
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
)
else: else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
@@ -321,6 +329,7 @@ class Config(ConfigBase):
debug: DebugConfig debug: DebugConfig
custom_prompt: CustomPromptConfig custom_prompt: CustomPromptConfig
def load_config(config_path: str) -> Config: def load_config(config_path: str) -> Config:
""" """
加载配置文件 加载配置文件

View File

@@ -39,7 +39,7 @@ class LLMRequestOff:
} }
# 发送请求到完整的 chat/completions 端点 # 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions" api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
logger.info(f"Request URL: {api_url}") # 记录请求的 URL logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3 max_retries = 3
@@ -89,7 +89,7 @@ class LLMRequestOff:
} }
# 发送请求到完整的 chat/completions 端点 # 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions" api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
logger.info(f"Request URL: {api_url}") # 记录请求的 URL logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3 max_retries = 3

View File

@@ -83,8 +83,8 @@ class PersonalityEvaluatorDirect:
def __init__(self): def __init__(self):
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
self.scenarios = [] self.scenarios = []
self.final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0} self.final_scores: Dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
self.dimension_counts = {trait: 0 for trait in self.final_scores.keys()} self.dimension_counts = {trait: 0 for trait in self.final_scores}
# 为每个人格特质获取对应的场景 # 为每个人格特质获取对应的场景
for trait in PERSONALITY_SCENES: for trait in PERSONALITY_SCENES:
@@ -119,8 +119,7 @@ class PersonalityEvaluatorDirect:
# 构建维度描述 # 构建维度描述
dimension_descriptions = [] dimension_descriptions = []
for dim in dimensions: for dim in dimensions:
desc = FACTOR_DESCRIPTIONS.get(dim, "") if desc := FACTOR_DESCRIPTIONS.get(dim, ""):
if desc:
dimension_descriptions.append(f"- {dim}{desc}") dimension_descriptions.append(f"- {dim}{desc}")
dimensions_text = "\n".join(dimension_descriptions) dimensions_text = "\n".join(dimension_descriptions)

View File

@@ -153,14 +153,14 @@ class MainSystem:
while True: while True:
await asyncio.sleep(global_config.memory.memory_build_interval) await asyncio.sleep(global_config.memory.memory_build_interval)
logger.info("正在进行记忆构建") logger.info("正在进行记忆构建")
await self.hippocampus_manager.build_memory() await self.hippocampus_manager.build_memory() # type: ignore
async def forget_memory_task(self): async def forget_memory_task(self):
"""记忆遗忘任务""" """记忆遗忘任务"""
while True: while True:
await asyncio.sleep(global_config.memory.forget_memory_interval) await asyncio.sleep(global_config.memory.forget_memory_interval)
logger.info("[记忆遗忘] 开始遗忘记忆...") logger.info("[记忆遗忘] 开始遗忘记忆...")
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
logger.info("[记忆遗忘] 记忆遗忘完成") logger.info("[记忆遗忘] 记忆遗忘完成")
async def consolidate_memory_task(self): async def consolidate_memory_task(self):
@@ -168,7 +168,7 @@ class MainSystem:
while True: while True:
await asyncio.sleep(global_config.memory.consolidate_memory_interval) await asyncio.sleep(global_config.memory.consolidate_memory_interval)
logger.info("[记忆整合] 开始整合记忆...") logger.info("[记忆整合] 开始整合记忆...")
await self.hippocampus_manager.consolidate_memory() await self.hippocampus_manager.consolidate_memory() # type: ignore
logger.info("[记忆整合] 记忆整合完成") logger.info("[记忆整合] 记忆整合完成")
@staticmethod @staticmethod

View File

@@ -49,6 +49,9 @@ class ChatMood:
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
self.chat_stream = chat_manager.get_stream(self.chat_id) self.chat_stream = chat_manager.get_stream(self.chat_id)
if not self.chat_stream:
raise ValueError(f"Chat stream for chat_id {chat_id} not found")
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]" self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"

View File

@@ -26,7 +26,7 @@ SEGMENT_CLEANUP_CONFIG = {
"cleanup_interval_hours": 0.5, # 清理间隔(小时) "cleanup_interval_hours": 0.5, # 清理间隔(小时)
} }
MAX_MESSAGE_COUNT = 80 / global_config.relationship.relation_frequency MAX_MESSAGE_COUNT = int(80 / global_config.relationship.relation_frequency)
class RelationshipBuilder: class RelationshipBuilder:

View File

@@ -61,7 +61,7 @@ __all__ = [
"ConfigField", "ConfigField",
# 工具函数 # 工具函数
"ManifestValidator", "ManifestValidator",
"ManifestGenerator", # "ManifestGenerator",
"validate_plugin_manifest", # "validate_plugin_manifest",
"generate_plugin_manifest", # "generate_plugin_manifest",
] ]

View File

@@ -111,7 +111,7 @@ async def _send_to_target(
is_head=True, is_head=True,
is_emoji=(message_type == "emoji"), is_emoji=(message_type == "emoji"),
thinking_start_time=current_time, thinking_start_time=current_time,
reply_to = reply_to_platform_id reply_to=reply_to_platform_id,
) )
# 发送消息 # 发送消息
@@ -137,6 +137,7 @@ async def _send_to_target(
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]: async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
# sourcery skip: inline-variable, use-named-expression
"""查找要回复的消息 """查找要回复的消息
Args: Args:
@@ -184,14 +185,11 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
# 检查是否有 回复<aaa:bbb> 字段 # 检查是否有 回复<aaa:bbb> 字段
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>" reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, translate_text) if match := re.search(reply_pattern, translate_text):
if match:
aaa = match.group(1) aaa = match.group(1)
bbb = match.group(2) bbb = match.group(2)
reply_person_id = get_person_info_manager().get_person_id(platform, bbb) reply_person_id = get_person_info_manager().get_person_id(platform, bbb)
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa
if not reply_person_name:
reply_person_name = aaa
# 在内容前加上回复信息 # 在内容前加上回复信息
translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1) translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1)
@@ -206,9 +204,7 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
aaa = m.group(1) aaa = m.group(1)
bbb = m.group(2) bbb = m.group(2)
at_person_id = get_person_info_manager().get_person_id(platform, bbb) at_person_id = get_person_info_manager().get_person_id(platform, bbb)
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa
if not at_person_name:
at_person_name = aaa
new_content += f"@{at_person_name}" new_content += f"@{at_person_name}"
last_end = m.end() last_end = m.end()
new_content += translate_text[last_end:] new_content += translate_text[last_end:]
@@ -370,7 +366,14 @@ async def custom_to_stream(
bool: 是否发送成功 bool: 是否发送成功
""" """
return await _send_to_target( return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log message_type,
content,
stream_id,
display_message,
typing,
reply_to,
storage_message=storage_message,
show_log=show_log,
) )
@@ -396,7 +399,7 @@ async def text_to_group(
""" """
stream_id = get_chat_manager().get_stream_id(platform, group_id, True) stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message) return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def text_to_user( async def text_to_user(
@@ -420,7 +423,7 @@ async def text_to_user(
bool: 是否发送成功 bool: 是否发送成功
""" """
stream_id = get_chat_manager().get_stream_id(platform, user_id, False) stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message) return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool: async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
@@ -543,7 +546,9 @@ async def custom_to_group(
bool: 是否发送成功 bool: 是否发送成功
""" """
stream_id = get_chat_manager().get_stream_id(platform, group_id, True) stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message) return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_to_user( async def custom_to_user(
@@ -571,7 +576,9 @@ async def custom_to_user(
bool: 是否发送成功 bool: 是否发送成功
""" """
stream_id = get_chat_manager().get_stream_id(platform, user_id, False) stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message) return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_message( async def custom_message(
@@ -611,4 +618,6 @@ async def custom_message(
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好") await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
""" """
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group) stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message) return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)

View File

@@ -38,7 +38,7 @@ class BaseAction(ABC):
chat_stream: ChatStream, chat_stream: ChatStream,
log_prefix: str = "", log_prefix: str = "",
plugin_config: Optional[dict] = None, plugin_config: Optional[dict] = None,
action_message: dict = None, action_message: Optional[dict] = None,
**kwargs, **kwargs,
): ):
"""初始化Action组件 """初始化Action组件
@@ -63,7 +63,7 @@ class BaseAction(ABC):
self.cycle_timers = cycle_timers self.cycle_timers = cycle_timers
self.thinking_id = thinking_id self.thinking_id = thinking_id
self.log_prefix = log_prefix self.log_prefix = log_prefix
# 保存插件配置 # 保存插件配置
self.plugin_config = plugin_config or {} self.plugin_config = plugin_config or {}
@@ -92,10 +92,10 @@ class BaseAction(ABC):
self.chat_stream = chat_stream or kwargs.get("chat_stream") self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.stream_id self.chat_id = self.chat_stream.stream_id
self.platform = getattr(self.chat_stream, "platform", None) self.platform = getattr(self.chat_stream, "platform", None)
# 初始化基础信息(带类型注解) # 初始化基础信息(带类型注解)
self.action_message = action_message self.action_message = action_message
self.group_id = None self.group_id = None
self.group_name = None self.group_name = None
self.user_id = None self.user_id = None
@@ -103,15 +103,17 @@ class BaseAction(ABC):
self.is_group = False self.is_group = False
self.target_id = None self.target_id = None
self.has_action_message = False self.has_action_message = False
if self.action_message: if self.action_message:
self.has_action_message = True self.has_action_message = True
else:
self.action_message = {}
if self.has_action_message: if self.has_action_message:
if self.action_name != "no_reply": if self.action_name != "no_reply":
self.group_id = str(self.action_message.get("chat_info_group_id", None)) self.group_id = str(self.action_message.get("chat_info_group_id", None))
self.group_name = self.action_message.get("chat_info_group_name", None) self.group_name = self.action_message.get("chat_info_group_name", None)
self.user_id = str(self.action_message.get("user_id", None)) self.user_id = str(self.action_message.get("user_id", None))
self.user_nickname = self.action_message.get("user_nickname", None) self.user_nickname = self.action_message.get("user_nickname", None)
if self.group_id: if self.group_id:
@@ -132,8 +134,6 @@ class BaseAction(ABC):
self.is_group = False self.is_group = False
self.target_id = self.user_id self.target_id = self.user_id
logger.debug(f"{self.log_prefix} Action组件初始化完成") logger.debug(f"{self.log_prefix} Action组件初始化完成")
logger.info( logger.info(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}" f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
@@ -199,7 +199,9 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}") logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}" return False, f"等待新消息失败: {str(e)}"
async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool: async def send_text(
self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False
) -> bool:
"""发送文本消息 """发送文本消息
Args: Args:
@@ -299,7 +301,7 @@ class BaseAction(ABC):
) )
async def send_command( async def send_command(
self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool: ) -> bool:
"""发送命令消息 """发送命令消息

View File

@@ -135,7 +135,7 @@ class BaseCommand(ABC):
) )
async def send_command( async def send_command(
self, command_name: str, args: dict = None, display_message: str = "", storage_message: bool = True self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool: ) -> bool:
"""发送命令消息 """发送命令消息

View File

@@ -346,67 +346,67 @@ class ComponentRegistry:
# === 状态管理方法 === # === 状态管理方法 ===
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool: # def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# -------------------------------- NEED REFACTORING -------------------------------- # # -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR ------------------------------------- # # -------------------------------- LOGIC ERROR -------------------------------------
"""启用组件,支持命名空间解析""" # """启用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称 # # 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type) # component_info = self.get_component_info(component_name, component_type)
if not component_info: # if not component_info:
return False # return False
# 根据组件类型构造正确的命名空间化名称 # # 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION: # if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name # namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND: # elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name # namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else: # else:
namespaced_name = ( # namespaced_name = (
f"{component_info.component_type.value}.{component_name}" # f"{component_info.component_type.value}.{component_name}"
if "." not in component_name # if "." not in component_name
else component_name # else component_name
) # )
if namespaced_name in self._components: # if namespaced_name in self._components:
self._components[namespaced_name].enabled = True # self._components[namespaced_name].enabled = True
# 如果是Action更新默认动作集 # # 如果是Action更新默认动作集
# ---- HERE ---- # # ---- HERE ----
# if isinstance(component_info, ActionInfo): # # if isinstance(component_info, ActionInfo):
# self._action_descriptions[component_name] = component_info.description # # self._action_descriptions[component_name] = component_info.description
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}") # logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
return True # return True
return False # return False
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool: # def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# -------------------------------- NEED REFACTORING -------------------------------- # # -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR ------------------------------------- # # -------------------------------- LOGIC ERROR -------------------------------------
"""禁用组件,支持命名空间解析""" # """禁用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称 # # 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type) # component_info = self.get_component_info(component_name, component_type)
if not component_info: # if not component_info:
return False # return False
# 根据组件类型构造正确的命名空间化名称 # # 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION: # if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name # namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND: # elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name # namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else: # else:
namespaced_name = ( # namespaced_name = (
f"{component_info.component_type.value}.{component_name}" # f"{component_info.component_type.value}.{component_name}"
if "." not in component_name # if "." not in component_name
else component_name # else component_name
) # )
if namespaced_name in self._components: # if namespaced_name in self._components:
self._components[namespaced_name].enabled = False # self._components[namespaced_name].enabled = False
# 如果是Action从默认动作集中移除 # # 如果是Action从默认动作集中移除
# ---- HERE ---- # # ---- HERE ----
# if component_name in self._action_descriptions: # # if component_name in self._action_descriptions:
# del self._action_descriptions[component_name] # # del self._action_descriptions[component_name]
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}") # logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
return True # return True
return False # return False
def get_registry_stats(self) -> Dict[str, Any]: def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息""" """获取注册中心统计信息"""

View File

@@ -7,7 +7,7 @@
import subprocess import subprocess
import sys import sys
import importlib import importlib
from typing import List, Dict, Tuple from typing import List, Dict, Tuple, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import PythonDependency from src.plugin_system.base.component_types import PythonDependency
@@ -176,7 +176,7 @@ class DependencyManager:
logger.error(f"生成requirements文件失败: {str(e)}") logger.error(f"生成requirements文件失败: {str(e)}")
return False return False
def get_install_summary(self) -> Dict[str, any]: def get_install_summary(self) -> Dict[str, Any]:
"""获取安装摘要""" """获取安装摘要"""
return { return {
"install_log": self.install_log.copy(), "install_log": self.install_log.copy(),

View File

@@ -197,29 +197,29 @@ class PluginManager:
"""获取所有启用的插件信息""" """获取所有启用的插件信息"""
return list(component_registry.get_enabled_plugins().values()) return list(component_registry.get_enabled_plugins().values())
def enable_plugin(self, plugin_name: str) -> bool: # def enable_plugin(self, plugin_name: str) -> bool:
# -------------------------------- NEED REFACTORING -------------------------------- # # -------------------------------- NEED REFACTORING --------------------------------
"""启用插件""" # """启用插件"""
if plugin_info := component_registry.get_plugin_info(plugin_name): # if plugin_info := component_registry.get_plugin_info(plugin_name):
plugin_info.enabled = True # plugin_info.enabled = True
# 启用插件的所有组件 # # 启用插件的所有组件
for component in plugin_info.components: # for component in plugin_info.components:
component_registry.enable_component(component.name) # component_registry.enable_component(component.name)
logger.debug(f"已启用插件: {plugin_name}") # logger.debug(f"已启用插件: {plugin_name}")
return True # return True
return False # return False
def disable_plugin(self, plugin_name: str) -> bool: # def disable_plugin(self, plugin_name: str) -> bool:
# -------------------------------- NEED REFACTORING -------------------------------- # # -------------------------------- NEED REFACTORING --------------------------------
"""禁用插件""" # """禁用插件"""
if plugin_info := component_registry.get_plugin_info(plugin_name): # if plugin_info := component_registry.get_plugin_info(plugin_name):
plugin_info.enabled = False # plugin_info.enabled = False
# 禁用插件的所有组件 # # 禁用插件的所有组件
for component in plugin_info.components: # for component in plugin_info.components:
component_registry.disable_component(component.name) # component_registry.disable_component(component.name)
logger.debug(f"已禁用插件: {plugin_name}") # logger.debug(f"已禁用插件: {plugin_name}")
return True # return True
return False # return False
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]: def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例 """获取插件实例

View File

@@ -28,10 +28,10 @@ class CompareNumbersTool(BaseTool):
Returns: Returns:
dict: 工具执行结果 dict: 工具执行结果
""" """
try: num1: int | float = function_args.get("num1") # type: ignore
num1 = function_args.get("num1") num2: int | float = function_args.get("num2") # type: ignore
num2 = function_args.get("num2")
try:
if num1 > num2: if num1 > num2:
result = f"{num1} 大于 {num2}" result = f"{num1} 大于 {num2}"
elif num1 < num2: elif num1 < num2:

View File

@@ -68,10 +68,10 @@ class RenamePersonTool(BaseTool):
) )
result = await person_info_manager.qv_person_name( result = await person_info_manager.qv_person_name(
person_id=person_id, person_id=person_id,
user_nickname=user_nickname, user_nickname=user_nickname, # type: ignore
user_cardname=user_cardname, user_cardname=user_cardname, # type: ignore
user_avatar=user_avatar, user_avatar=user_avatar, # type: ignore
request=request_context, request=request_context, # type: ignore
) )
# 3. 处理结果 # 3. 处理结果