typing fix

This commit is contained in:
UnCLAS-Prommer
2025-07-17 00:10:41 +08:00
parent 6e838ccc74
commit 1aa2734d62
26 changed files with 329 additions and 293 deletions

27
bot.py
View File

@@ -146,7 +146,7 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str:
if not file_path.exists():
logger.error(f"{file_type} 文件不存在")
raise FileNotFoundError(f"{file_type} 文件不存在")
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
return hashlib.md5(content.encode("utf-8")).hexdigest()
@@ -154,21 +154,21 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str:
def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]:
"""检查协议确认状态
Returns:
tuple[bool, bool]: (已确认, 未更新)
"""
# 检查环境变量确认
if file_hash == os.getenv(env_var):
return True, False
# 检查确认文件
if confirm_file.exists():
with open(confirm_file, "r", encoding="utf-8") as f:
confirmed_content = f.read()
if file_hash == confirmed_content:
return True, False
return False, True
@@ -178,7 +178,7 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
confirm_logger.critical(
f'输入"同意""confirmed"或设置环境变量"EULA_AGREE={eula_hash}""PRIVACY_AGREE={privacy_hash}"继续运行'
)
while True:
user_input = input().strip().lower()
if user_input in ["同意", "confirmed"]:
@@ -186,13 +186,12 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
confirm_logger.critical('请输入"同意""confirmed"以继续运行')
def _save_confirmations(eula_updated: bool, privacy_updated: bool,
eula_hash: str, privacy_hash: str) -> None:
def _save_confirmations(eula_updated: bool, privacy_updated: bool, eula_hash: str, privacy_hash: str) -> None:
"""保存用户确认结果"""
if eula_updated:
logger.info(f"更新EULA确认文件{eula_hash}")
Path("eula.confirmed").write_text(eula_hash, encoding="utf-8")
if privacy_updated:
logger.info(f"更新隐私条款确认文件{privacy_hash}")
Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8")
@@ -203,19 +202,17 @@ def check_eula():
# 计算文件哈希值
eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md")
privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md")
# 检查确认状态
eula_confirmed, eula_updated = _check_agreement_status(
eula_hash, Path("eula.confirmed"), "EULA_AGREE"
)
eula_confirmed, eula_updated = _check_agreement_status(eula_hash, Path("eula.confirmed"), "EULA_AGREE")
privacy_confirmed, privacy_updated = _check_agreement_status(
privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE"
)
# 早期返回:如果都已确认且未更新
if eula_confirmed and privacy_confirmed:
return
# 如果有更新,需要重新确认
if eula_updated or privacy_updated:
_prompt_user_confirmation(eula_hash, privacy_hash)
@@ -225,7 +222,7 @@ def check_eula():
def raw_main():
# 利用 TZ 环境变量设定程序工作的时区
if platform.system().lower() != "windows":
time.tzset()
time.tzset() # type: ignore
check_eula()
logger.info("检查EULA和隐私条款完成")

View File

@@ -107,11 +107,12 @@ class ExpressionLearner:
last_active_time = expr.get("last_active_time", time.time())
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == type_str) &
(Expression.situation == situation) &
(Expression.style == style_val)
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)
if query.exists():
expr_obj = query.get()
@@ -125,7 +126,7 @@ class ExpressionLearner:
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str
type=type_str,
)
logger.info(f"已迁移 {expr_file} 到数据库")
except Exception as e:
@@ -149,24 +150,28 @@ class ExpressionLearner:
# 直接从数据库查询
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
for expr in style_query:
learnt_style_expressions.append({
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "style"
})
learnt_style_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "style",
}
)
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
for expr in grammar_query:
learnt_grammar_expressions.append({
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "grammar"
})
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "grammar",
}
)
return learnt_style_expressions, learnt_grammar_expressions
def is_similar(self, s1: str, s2: str) -> bool:
@@ -213,14 +218,16 @@ class ExpressionLearner:
logger.error(f"全局衰减{type}表达方式失败: {e}")
continue
learnt_style: Optional[List[Tuple[str, str, str]]] = []
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
# 学习新的表达方式(这里会进行局部衰减)
for _ in range(3):
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
learnt_style = await self.learn_and_store(type="style", num=25)
if not learnt_style:
return [], []
for _ in range(1):
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
if not learnt_grammar:
return [], []
@@ -321,10 +328,10 @@ class ExpressionLearner:
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == type) &
(Expression.situation == new_expr["situation"]) &
(Expression.style == new_expr["style"])
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
if query.exists():
expr_obj = query.get()
@@ -342,13 +349,17 @@ class ExpressionLearner:
count=1,
last_active_time=current_time,
chat_id=chat_id,
type=type
type=type,
)
# 限制最大数量
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc()))
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == type))
.order_by(Expression.count.asc())
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]:
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
return learnt_expressions

View File

@@ -9,51 +9,49 @@ from src.common.logger import get_logger
import traceback
from src.config.config import global_config
from src.common.database.database_model import Memory # Peewee Models导入
from src.common.database.database_model import Memory # Peewee Models导入
logger = get_logger(__name__)
class MemoryItem:
def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]):
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
self.memory_id = memory_id
self.chat_id = chat_id
self.memory_text:str = memory_text
self.keywords:list[str] = keywords
self.create_time:float = time.time()
self.last_view_time:float = time.time()
self.memory_text: str = memory_text
self.keywords: list[str] = keywords
self.create_time: float = time.time()
self.last_view_time: float = time.time()
class MemoryManager:
def __init__(self):
# self.memory_items:list[MemoryItem] = []
pass
class InstantMemory:
def __init__(self,chat_id):
self.chat_id = chat_id
def __init__(self, chat_id):
self.chat_id = chat_id
self.last_view_time = time.time()
self.summary_model = LLMRequest(
model=global_config.model.memory,
temperature=0.5,
request_type="memory.summary",
)
async def if_need_build(self,text):
async def if_need_build(self, text):
prompt = f"""
请判断以下内容中是否有值得记忆的信息如果有请输出1否则输出0
{text}
请只输出1或0就好
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if "1" in response:
return True
else:
@@ -61,8 +59,8 @@ class InstantMemory:
except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False
async def build_memory(self,text):
async def build_memory(self, text):
prompt = f"""
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
{text}
@@ -73,7 +71,7 @@ class InstantMemory:
}}
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if not response:
@@ -81,53 +79,53 @@ class InstantMemory:
try:
repaired = repair_json(response)
result = json.loads(repaired)
memory_text = result.get('memory_text', '')
keywords = result.get('keywords', '')
memory_text = result.get("memory_text", "")
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
return {'memory_text': memory_text, 'keywords': keywords_list}
return {"memory_text": memory_text, "keywords": keywords_list}
except Exception as parse_e:
logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}")
return None
except Exception as e:
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
return None
async def create_and_store_memory(self,text):
async def create_and_store_memory(self, text):
if_need = await self.if_need_build(text)
if if_need:
logger.info(f"需要记忆:{text}")
memory = await self.build_memory(text)
if memory and memory.get('memory_text'):
memory = await self.build_memory(text)
if memory and memory.get("memory_text"):
memory_id = f"{self.chat_id}_{time.time()}"
memory_item = MemoryItem(
memory_id=memory_id,
chat_id=self.chat_id,
memory_text=memory['memory_text'],
keywords=memory.get('keywords', [])
memory_text=memory["memory_text"],
keywords=memory.get("keywords", []),
)
await self.store_memory(memory_item)
else:
logger.info(f"不需要记忆:{text}")
async def store_memory(self,memory_item:MemoryItem):
async def store_memory(self, memory_item: MemoryItem):
memory = Memory(
memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id,
memory_text=memory_item.memory_text,
keywords=memory_item.keywords,
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time
last_view_time=memory_item.last_view_time,
)
memory.save()
async def get_memory(self,target:str):
async def get_memory(self, target: str):
from json_repair import repair_json
prompt = f"""
请根据以下发言内容,判断是否需要提取记忆
{target}
@@ -144,7 +142,7 @@ class InstantMemory:
请只输出json格式不要输出其他多余内容
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if not response:
@@ -153,15 +151,15 @@ class InstantMemory:
repaired = repair_json(response)
result = json.loads(repaired)
# 解析keywords
keywords = result.get('keywords', '')
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
# 解析time为时间段
time_str = result.get('time', '').strip()
time_str = result.get("time", "").strip()
start_time, end_time = self._parse_time_range(time_str)
logger.info(f"start_time: {start_time}, end_time: {end_time}")
# 检索包含关键词的记忆
@@ -170,16 +168,15 @@ class InstantMemory:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = Memory.select().where(
(Memory.chat_id == self.chat_id) &
(Memory.create_time >= start_ts) &
(Memory.create_time < end_ts)
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts) # type: ignore
& (Memory.create_time < end_ts) # type: ignore
)
else:
query = Memory.select().where(Memory.chat_id == self.chat_id)
for mem in query:
#对每条记忆
# 对每条记忆
mem_keywords = mem.keywords or []
parsed = ast.literal_eval(mem_keywords)
if isinstance(parsed, list):
@@ -212,6 +209,7 @@ class InstantMemory:
- 空字符串:返回(None, None)
"""
from datetime import datetime, timedelta
now = datetime.now()
if not time_str:
return 0, now
@@ -251,8 +249,8 @@ class InstantMemory:
if m:
months = int(m.group(1))
# 近似每月30天
start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0)
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
# 其他无法解析
return 0, now
return 0, now

View File

@@ -30,7 +30,7 @@ class ChatMessageContext:
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
return self.message.message_info.template_info.template_name # type: ignore
return None
def get_last_message(self) -> "MessageRecv":

View File

@@ -107,9 +107,9 @@ class MessageRecv(Message):
self.is_picid = False
self.has_picid = False
self.is_mentioned = None
self.is_command = False
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = None # type: ignore
@@ -181,6 +181,7 @@ class MessageRecv(Message):
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageRecvS4U(MessageRecv):
def __init__(self, message_dict: dict[str, Any]):
@@ -194,10 +195,10 @@ class MessageRecvS4U(MessageRecv):
self.superchat_price = None
self.superchat_message_text = None
self.is_screen = False
async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
@@ -252,7 +253,7 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "gift":
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1)
name, count = segment.data.split(":", 1) # type: ignore
self.gift_info = segment.data
self.gift_name = name.strip()
self.gift_count = int(count.strip())
@@ -260,13 +261,15 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "superchat":
self.is_superchat = True
self.superchat_info = segment.data
price,message_text = segment.data.split(":", 1)
price, message_text = segment.data.split(":", 1) # type: ignore
self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
self.processed_plain_text += (
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
)
return self.processed_plain_text
elif segment.type == "screen":
self.is_screen = True

View File

@@ -80,7 +80,7 @@ class ActionManager:
chat_stream: ChatStream,
log_prefix: str,
shutting_down: bool = False,
action_message: dict = None,
action_message: Optional[dict] = None,
) -> Optional[BaseAction]:
"""
创建动作处理器实例

View File

@@ -252,7 +252,7 @@ def _build_readable_messages_internal(
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1,
show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
"""
内部辅助函数,构建可读消息字符串和原始消息详情列表。
@@ -615,7 +615,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
for action in actions:
action_time = action.get("time", current_time)
action_name = action.get("action_name", "未知动作")
if action_name == "no_action" or action_name == "no_reply":
if action_name in ["no_action", "no_reply"]:
continue
action_prompt_display = action.get("action_prompt_display", "无具体内容")
@@ -697,7 +697,7 @@ def build_readable_messages(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。

View File

@@ -1211,7 +1211,7 @@ class StatisticOutputTask(AsyncTask):
f.write(html_template)
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: for-append-to-extend, list-comprehension, use-any
# sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next
"""生成Focus统计独立分页的HTML内容"""
# 为每个时间段准备Focus数据
@@ -1559,6 +1559,7 @@ class StatisticOutputTask(AsyncTask):
"""
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: use-named-expression, use-next
"""生成版本对比独立分页的HTML内容"""
# 为每个时间段准备版本对比数据
@@ -2306,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask):
# 复用 StatisticOutputTask 的所有方法
def _collect_all_statistics(self, now: datetime):
return StatisticOutputTask._collect_all_statistics(self, now)
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
return StatisticOutputTask._statistic_console_output(self, stats, now)
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
return StatisticOutputTask._generate_html_report(self, stats, now)
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
# 其他需要的方法也可以类似复用...
@staticmethod
@@ -2324,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask):
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_message_count_for_period(self, collect_period)
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period)
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore
def _process_focus_file_data(
self,
@@ -2336,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask):
collect_period: List[Tuple[str, datetime]],
file_time: datetime,
):
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time)
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore
def _calculate_focus_averages(self, stats: Dict[str, Any]):
return StatisticOutputTask._calculate_focus_averages(self, stats)
return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore
@staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str:
@@ -2347,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask):
@staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_model_classified_stat(stats)
return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_chat_stat(self, stats)
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
def _format_focus_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_focus_stat(self, stats)
return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
return StatisticOutputTask._generate_chart_data(self, stat)
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes)
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore
def _generate_chart_tab(self, chart_data: dict) -> str:
return StatisticOutputTask._generate_chart_tab(self, chart_data)
return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id)
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_focus_tab(self, stat)
return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_versions_tab(self, stat)
return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore
def _convert_defaultdict_to_dict(self, data):
return StatisticOutputTask._convert_defaultdict_to_dict(self, data)
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore

View File

@@ -2,14 +2,13 @@ import importlib
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Optional
from typing import Dict, Optional, Any
from rich.traceback import install
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
install(extra_lines=3)
@@ -54,7 +53,7 @@ class WillingInfo:
interested_rate (float): 兴趣度
"""
message: MessageRecv
message: Dict[str, Any] # 原始消息数据
chat: ChatStream
person_info_manager: PersonInfoManager
chat_id: str

View File

@@ -65,7 +65,7 @@ class ChatStreams(BaseModel):
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
user_cardname = TextField(null=True)
class Meta:
class Meta: # type: ignore
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# 如果不使用带有数据库实例的 BaseModel或者想覆盖它
# 请取消注释并在下面设置数据库实例:
@@ -89,7 +89,7 @@ class LLMUsage(BaseModel):
status = TextField()
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
class Meta:
class Meta: # type: ignore
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# database = db
table_name = "llm_usage"
@@ -112,7 +112,7 @@ class Emoji(BaseModel):
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
last_used_time = FloatField(null=True) # 上次使用时间
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "emoji"
@@ -162,7 +162,8 @@ class Messages(BaseModel):
is_emoji = BooleanField(default=False)
is_picid = BooleanField(default=False)
is_command = BooleanField(default=False)
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "messages"
@@ -186,7 +187,7 @@ class ActionRecords(BaseModel):
chat_info_stream_id = TextField()
chat_info_platform = TextField()
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "action_records"
@@ -206,7 +207,7 @@ class Images(BaseModel):
type = TextField() # 图像类型,例如 "emoji"
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
class Meta:
class Meta: # type: ignore
table_name = "images"
@@ -220,7 +221,7 @@ class ImageDescriptions(BaseModel):
description = TextField() # 图像的描述
timestamp = FloatField() # 时间戳
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "image_descriptions"
@@ -236,7 +237,7 @@ class OnlineTime(BaseModel):
start_timestamp = DateTimeField(default=datetime.datetime.now)
end_timestamp = DateTimeField(index=True)
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "online_time"
@@ -263,10 +264,11 @@ class PersonInfo(BaseModel):
last_know = FloatField(null=True) # 最后一次印象总结时间
attitude = IntegerField(null=True, default=50) # 态度0-100从非常厌恶到十分喜欢
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "person_info"
class Memory(BaseModel):
memory_id = TextField(index=True)
chat_id = TextField(null=True)
@@ -274,10 +276,11 @@ class Memory(BaseModel):
keywords = TextField(null=True)
create_time = FloatField(null=True)
last_view_time = FloatField(null=True)
class Meta:
class Meta: # type: ignore
table_name = "memory"
class Knowledges(BaseModel):
"""
用于存储知识库条目的模型。
@@ -287,10 +290,11 @@ class Knowledges(BaseModel):
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
# 可以添加其他元数据字段,如 source, create_time 等
class Meta:
class Meta: # type: ignore
# database = db # 继承自 BaseModel
table_name = "knowledges"
class Expression(BaseModel):
"""
用于存储表达风格的模型。
@@ -302,10 +306,11 @@ class Expression(BaseModel):
last_active_time = FloatField()
chat_id = TextField(index=True)
type = TextField()
class Meta:
class Meta: # type: ignore
table_name = "expression"
class ThinkingLog(BaseModel):
chat_id = TextField(index=True)
trigger_text = TextField(null=True)
@@ -326,7 +331,7 @@ class ThinkingLog(BaseModel):
# And: import datetime
created_at = DateTimeField(default=datetime.datetime.now)
class Meta:
class Meta: # type: ignore
table_name = "thinking_logs"
@@ -341,7 +346,7 @@ class GraphNodes(BaseModel):
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
class Meta: # type: ignore
table_name = "graph_nodes"
@@ -357,7 +362,7 @@ class GraphEdges(BaseModel):
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
class Meta: # type: ignore
table_name = "graph_edges"

View File

@@ -7,13 +7,13 @@ from datetime import datetime
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
return toml_table.trivia.comment
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
item = toml_table.value.get(key)
if item is not None and hasattr(item, 'trivia'):
if item is not None and hasattr(item, "trivia"):
return item.trivia.comment
if hasattr(toml_table, 'keys'):
if hasattr(toml_table, "keys"):
for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key:
return k.trivia.comment
@@ -36,16 +36,16 @@ def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, log
continue
if key not in old:
comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else ''}")
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path+[str(key)], new_comments, old_comments, logs)
compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
# 删减项
for key in old:
if key == "version":
continue
if key not in new:
comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else ''}")
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
return logs
@@ -95,7 +95,7 @@ def update_config():
if old_version and new_version and old_version == new_version:
print(f"检测到版本号相同 (v{old_version}),跳过更新")
# 如果version相同恢复旧配置文件并返回
shutil.move(old_backup_path, old_config_path)
shutil.move(old_backup_path, old_config_path) # type: ignore
return
else:
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")

View File

@@ -53,13 +53,13 @@ MMC_VERSION = "0.9.0-snapshot.2"
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
return toml_table.trivia.comment
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
item = toml_table.value.get(key)
if item is not None and hasattr(item, 'trivia'):
if item is not None and hasattr(item, "trivia"):
return item.trivia.comment
if hasattr(toml_table, 'keys'):
if hasattr(toml_table, "keys"):
for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key:
return k.trivia.comment
@@ -78,16 +78,16 @@ def compare_dicts(new, old, path=None, logs=None):
continue
if key not in old:
comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else ''}")
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path+[str(key)], logs)
compare_dicts(new[key], old[key], path + [str(key)], logs)
# 删减项
for key in old:
if key == "version":
continue
if key not in new:
comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else ''}")
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}")
return logs
@@ -99,6 +99,7 @@ def get_value_by_path(d, path):
return None
return d
def set_value_by_path(d, path, value):
for k in path[:-1]:
if k not in d or not isinstance(d[k], dict):
@@ -106,6 +107,7 @@ def set_value_by_path(d, path, value):
d = d[k]
d[path[-1]] = value
def compare_default_values(new, old, path=None, logs=None, changes=None):
# 递归比较两个dict找出默认值变化项
if path is None:
@@ -119,12 +121,14 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
continue
if key in old:
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
compare_default_values(new[key], old[key], path+[str(key)], logs, changes)
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
else:
# 只要值发生变化就记录
if new[key] != old[key]:
logs.append(f"默认值变化: {'.'.join(path+[str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
changes.append((path+[str(key)], old[key], new[key]))
logs.append(
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
)
changes.append((path + [str(key)], old[key], new[key]))
return logs, changes
@@ -148,8 +152,8 @@ def update_config():
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]:
return doc["inner"]["version"]
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
template_version = get_version_from_toml(template_path)
@@ -186,7 +190,9 @@ def update_config():
old_value = get_value_by_path(old_config, path)
if old_value == old_default:
set_value_by_path(old_config, path, new_default)
logger.info(f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}")
logger.info(
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
)
else:
logger.info("未检测到模板默认值变动")
# 保存旧配置的变更(后续合并逻辑会用到 old_config
@@ -229,7 +235,9 @@ def update_config():
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------")
logger.info(
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
)
else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
@@ -321,6 +329,7 @@ class Config(ConfigBase):
debug: DebugConfig
custom_prompt: CustomPromptConfig
def load_config(config_path: str) -> Config:
"""
加载配置文件

View File

@@ -39,7 +39,7 @@ class LLMRequestOff:
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3
@@ -89,7 +89,7 @@ class LLMRequestOff:
}
# 发送请求到完整的 chat/completions 端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
max_retries = 3

View File

@@ -83,8 +83,8 @@ class PersonalityEvaluatorDirect:
def __init__(self):
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
self.scenarios = []
self.final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
self.dimension_counts = {trait: 0 for trait in self.final_scores.keys()}
self.final_scores: Dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
self.dimension_counts = {trait: 0 for trait in self.final_scores}
# 为每个人格特质获取对应的场景
for trait in PERSONALITY_SCENES:
@@ -119,8 +119,7 @@ class PersonalityEvaluatorDirect:
# 构建维度描述
dimension_descriptions = []
for dim in dimensions:
desc = FACTOR_DESCRIPTIONS.get(dim, "")
if desc:
if desc := FACTOR_DESCRIPTIONS.get(dim, ""):
dimension_descriptions.append(f"- {dim}{desc}")
dimensions_text = "\n".join(dimension_descriptions)

View File

@@ -153,14 +153,14 @@ class MainSystem:
while True:
await asyncio.sleep(global_config.memory.memory_build_interval)
logger.info("正在进行记忆构建")
await self.hippocampus_manager.build_memory()
await self.hippocampus_manager.build_memory() # type: ignore
async def forget_memory_task(self):
"""记忆遗忘任务"""
while True:
await asyncio.sleep(global_config.memory.forget_memory_interval)
logger.info("[记忆遗忘] 开始遗忘记忆...")
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage)
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
logger.info("[记忆遗忘] 记忆遗忘完成")
async def consolidate_memory_task(self):
@@ -168,7 +168,7 @@ class MainSystem:
while True:
await asyncio.sleep(global_config.memory.consolidate_memory_interval)
logger.info("[记忆整合] 开始整合记忆...")
await self.hippocampus_manager.consolidate_memory()
await self.hippocampus_manager.consolidate_memory() # type: ignore
logger.info("[记忆整合] 记忆整合完成")
@staticmethod

View File

@@ -49,6 +49,9 @@ class ChatMood:
chat_manager = get_chat_manager()
self.chat_stream = chat_manager.get_stream(self.chat_id)
if not self.chat_stream:
raise ValueError(f"Chat stream for chat_id {chat_id} not found")
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"

View File

@@ -26,7 +26,7 @@ SEGMENT_CLEANUP_CONFIG = {
"cleanup_interval_hours": 0.5, # 清理间隔(小时)
}
MAX_MESSAGE_COUNT = 80 / global_config.relationship.relation_frequency
MAX_MESSAGE_COUNT = int(80 / global_config.relationship.relation_frequency)
class RelationshipBuilder:

View File

@@ -61,7 +61,7 @@ __all__ = [
"ConfigField",
# 工具函数
"ManifestValidator",
"ManifestGenerator",
"validate_plugin_manifest",
"generate_plugin_manifest",
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
]

View File

@@ -111,7 +111,7 @@ async def _send_to_target(
is_head=True,
is_emoji=(message_type == "emoji"),
thinking_start_time=current_time,
reply_to = reply_to_platform_id
reply_to=reply_to_platform_id,
)
# 发送消息
@@ -137,6 +137,7 @@ async def _send_to_target(
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
# sourcery skip: inline-variable, use-named-expression
"""查找要回复的消息
Args:
@@ -184,14 +185,11 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
# 检查是否有 回复<aaa:bbb> 字段
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, translate_text)
if match:
if match := re.search(reply_pattern, translate_text):
aaa = match.group(1)
bbb = match.group(2)
reply_person_id = get_person_info_manager().get_person_id(platform, bbb)
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name")
if not reply_person_name:
reply_person_name = aaa
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa
# 在内容前加上回复信息
translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1)
@@ -206,9 +204,7 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
aaa = m.group(1)
bbb = m.group(2)
at_person_id = get_person_info_manager().get_person_id(platform, bbb)
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name")
if not at_person_name:
at_person_name = aaa
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa
new_content += f"@{at_person_name}"
last_end = m.end()
new_content += translate_text[last_end:]
@@ -370,7 +366,14 @@ async def custom_to_stream(
bool: 是否发送成功
"""
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log
message_type,
content,
stream_id,
display_message,
typing,
reply_to,
storage_message=storage_message,
show_log=show_log,
)
@@ -396,7 +399,7 @@ async def text_to_group(
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def text_to_user(
@@ -420,7 +423,7 @@ async def text_to_user(
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
@@ -543,7 +546,9 @@ async def custom_to_group(
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_to_user(
@@ -571,7 +576,9 @@ async def custom_to_user(
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_message(
@@ -611,4 +618,6 @@ async def custom_message(
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
"""
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)

View File

@@ -38,7 +38,7 @@ class BaseAction(ABC):
chat_stream: ChatStream,
log_prefix: str = "",
plugin_config: Optional[dict] = None,
action_message: dict = None,
action_message: Optional[dict] = None,
**kwargs,
):
"""初始化Action组件
@@ -63,7 +63,7 @@ class BaseAction(ABC):
self.cycle_timers = cycle_timers
self.thinking_id = thinking_id
self.log_prefix = log_prefix
# 保存插件配置
self.plugin_config = plugin_config or {}
@@ -92,10 +92,10 @@ class BaseAction(ABC):
self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.stream_id
self.platform = getattr(self.chat_stream, "platform", None)
# 初始化基础信息(带类型注解)
self.action_message = action_message
self.group_id = None
self.group_name = None
self.user_id = None
@@ -103,15 +103,17 @@ class BaseAction(ABC):
self.is_group = False
self.target_id = None
self.has_action_message = False
if self.action_message:
self.has_action_message = True
else:
self.action_message = {}
if self.has_action_message:
if self.action_name != "no_reply":
self.group_id = str(self.action_message.get("chat_info_group_id", None))
self.group_name = self.action_message.get("chat_info_group_name", None)
self.user_id = str(self.action_message.get("user_id", None))
self.user_nickname = self.action_message.get("user_nickname", None)
if self.group_id:
@@ -132,8 +134,6 @@ class BaseAction(ABC):
self.is_group = False
self.target_id = self.user_id
logger.debug(f"{self.log_prefix} Action组件初始化完成")
logger.info(
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
@@ -199,7 +199,9 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
return False, f"等待新消息失败: {str(e)}"
async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool:
async def send_text(
self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False
) -> bool:
"""发送文本消息
Args:
@@ -299,7 +301,7 @@ class BaseAction(ABC):
)
async def send_command(
self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息

View File

@@ -135,7 +135,7 @@ class BaseCommand(ABC):
)
async def send_command(
self, command_name: str, args: dict = None, display_message: str = "", storage_message: bool = True
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
) -> bool:
"""发送命令消息

View File

@@ -346,67 +346,67 @@ class ComponentRegistry:
# === 状态管理方法 ===
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR -------------------------------------
"""启用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type)
if not component_info:
return False
# def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# # -------------------------------- LOGIC ERROR -------------------------------------
# """启用组件,支持命名空间解析"""
# # 首先尝试找到正确的命名空间化名称
# component_info = self.get_component_info(component_name, component_type)
# if not component_info:
# return False
# 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else:
namespaced_name = (
f"{component_info.component_type.value}.{component_name}"
if "." not in component_name
else component_name
)
# # 根据组件类型构造正确的命名空间化名称
# if component_info.component_type == ComponentType.ACTION:
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
# elif component_info.component_type == ComponentType.COMMAND:
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
# else:
# namespaced_name = (
# f"{component_info.component_type.value}.{component_name}"
# if "." not in component_name
# else component_name
# )
if namespaced_name in self._components:
self._components[namespaced_name].enabled = True
# 如果是Action更新默认动作集
# ---- HERE ----
# if isinstance(component_info, ActionInfo):
# self._action_descriptions[component_name] = component_info.description
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
return True
return False
# if namespaced_name in self._components:
# self._components[namespaced_name].enabled = True
# # 如果是Action更新默认动作集
# # ---- HERE ----
# # if isinstance(component_info, ActionInfo):
# # self._action_descriptions[component_name] = component_info.description
# logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
# return True
# return False
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
# -------------------------------- LOGIC ERROR -------------------------------------
"""禁用组件,支持命名空间解析"""
# 首先尝试找到正确的命名空间化名称
component_info = self.get_component_info(component_name, component_type)
if not component_info:
return False
# def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# # -------------------------------- LOGIC ERROR -------------------------------------
# """禁用组件,支持命名空间解析"""
# # 首先尝试找到正确的命名空间化名称
# component_info = self.get_component_info(component_name, component_type)
# if not component_info:
# return False
# 根据组件类型构造正确的命名空间化名称
if component_info.component_type == ComponentType.ACTION:
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
elif component_info.component_type == ComponentType.COMMAND:
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
else:
namespaced_name = (
f"{component_info.component_type.value}.{component_name}"
if "." not in component_name
else component_name
)
# # 根据组件类型构造正确的命名空间化名称
# if component_info.component_type == ComponentType.ACTION:
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
# elif component_info.component_type == ComponentType.COMMAND:
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
# else:
# namespaced_name = (
# f"{component_info.component_type.value}.{component_name}"
# if "." not in component_name
# else component_name
# )
if namespaced_name in self._components:
self._components[namespaced_name].enabled = False
# 如果是Action从默认动作集中移除
# ---- HERE ----
# if component_name in self._action_descriptions:
# del self._action_descriptions[component_name]
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
return True
return False
# if namespaced_name in self._components:
# self._components[namespaced_name].enabled = False
# # 如果是Action从默认动作集中移除
# # ---- HERE ----
# # if component_name in self._action_descriptions:
# # del self._action_descriptions[component_name]
# logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
# return True
# return False
def get_registry_stats(self) -> Dict[str, Any]:
"""获取注册中心统计信息"""

View File

@@ -7,7 +7,7 @@
import subprocess
import sys
import importlib
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Any
from src.common.logger import get_logger
from src.plugin_system.base.component_types import PythonDependency
@@ -176,7 +176,7 @@ class DependencyManager:
logger.error(f"生成requirements文件失败: {str(e)}")
return False
def get_install_summary(self) -> Dict[str, any]:
def get_install_summary(self) -> Dict[str, Any]:
"""获取安装摘要"""
return {
"install_log": self.install_log.copy(),

View File

@@ -197,29 +197,29 @@ class PluginManager:
"""获取所有启用的插件信息"""
return list(component_registry.get_enabled_plugins().values())
def enable_plugin(self, plugin_name: str) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
"""启用插件"""
if plugin_info := component_registry.get_plugin_info(plugin_name):
plugin_info.enabled = True
# 启用插件的所有组件
for component in plugin_info.components:
component_registry.enable_component(component.name)
logger.debug(f"已启用插件: {plugin_name}")
return True
return False
# def enable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """启用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = True
# # 启用插件的所有组件
# for component in plugin_info.components:
# component_registry.enable_component(component.name)
# logger.debug(f"已启用插件: {plugin_name}")
# return True
# return False
def disable_plugin(self, plugin_name: str) -> bool:
# -------------------------------- NEED REFACTORING --------------------------------
"""禁用插件"""
if plugin_info := component_registry.get_plugin_info(plugin_name):
plugin_info.enabled = False
# 禁用插件的所有组件
for component in plugin_info.components:
component_registry.disable_component(component.name)
logger.debug(f"已禁用插件: {plugin_name}")
return True
return False
# def disable_plugin(self, plugin_name: str) -> bool:
# # -------------------------------- NEED REFACTORING --------------------------------
# """禁用插件"""
# if plugin_info := component_registry.get_plugin_info(plugin_name):
# plugin_info.enabled = False
# # 禁用插件的所有组件
# for component in plugin_info.components:
# component_registry.disable_component(component.name)
# logger.debug(f"已禁用插件: {plugin_name}")
# return True
# return False
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
"""获取插件实例

View File

@@ -28,10 +28,10 @@ class CompareNumbersTool(BaseTool):
Returns:
dict: 工具执行结果
"""
try:
num1 = function_args.get("num1")
num2 = function_args.get("num2")
num1: int | float = function_args.get("num1") # type: ignore
num2: int | float = function_args.get("num2") # type: ignore
try:
if num1 > num2:
result = f"{num1} 大于 {num2}"
elif num1 < num2:

View File

@@ -68,10 +68,10 @@ class RenamePersonTool(BaseTool):
)
result = await person_info_manager.qv_person_name(
person_id=person_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_avatar=user_avatar,
request=request_context,
user_nickname=user_nickname, # type: ignore
user_cardname=user_cardname, # type: ignore
user_avatar=user_avatar, # type: ignore
request=request_context, # type: ignore
)
# 3. 处理结果