typing fix
This commit is contained in:
27
bot.py
27
bot.py
@@ -146,7 +146,7 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str:
|
||||
if not file_path.exists():
|
||||
logger.error(f"{file_type} 文件不存在")
|
||||
raise FileNotFoundError(f"{file_type} 文件不存在")
|
||||
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
return hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||
@@ -154,21 +154,21 @@ def _calculate_file_hash(file_path: Path, file_type: str) -> str:
|
||||
|
||||
def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]:
|
||||
"""检查协议确认状态
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[bool, bool]: (已确认, 未更新)
|
||||
"""
|
||||
# 检查环境变量确认
|
||||
if file_hash == os.getenv(env_var):
|
||||
return True, False
|
||||
|
||||
|
||||
# 检查确认文件
|
||||
if confirm_file.exists():
|
||||
with open(confirm_file, "r", encoding="utf-8") as f:
|
||||
confirmed_content = f.read()
|
||||
if file_hash == confirmed_content:
|
||||
return True, False
|
||||
|
||||
|
||||
return False, True
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
|
||||
confirm_logger.critical(
|
||||
f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_hash}"和"PRIVACY_AGREE={privacy_hash}"继续运行'
|
||||
)
|
||||
|
||||
|
||||
while True:
|
||||
user_input = input().strip().lower()
|
||||
if user_input in ["同意", "confirmed"]:
|
||||
@@ -186,13 +186,12 @@ def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
|
||||
confirm_logger.critical('请输入"同意"或"confirmed"以继续运行')
|
||||
|
||||
|
||||
def _save_confirmations(eula_updated: bool, privacy_updated: bool,
|
||||
eula_hash: str, privacy_hash: str) -> None:
|
||||
def _save_confirmations(eula_updated: bool, privacy_updated: bool, eula_hash: str, privacy_hash: str) -> None:
|
||||
"""保存用户确认结果"""
|
||||
if eula_updated:
|
||||
logger.info(f"更新EULA确认文件{eula_hash}")
|
||||
Path("eula.confirmed").write_text(eula_hash, encoding="utf-8")
|
||||
|
||||
|
||||
if privacy_updated:
|
||||
logger.info(f"更新隐私条款确认文件{privacy_hash}")
|
||||
Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8")
|
||||
@@ -203,19 +202,17 @@ def check_eula():
|
||||
# 计算文件哈希值
|
||||
eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md")
|
||||
privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md")
|
||||
|
||||
|
||||
# 检查确认状态
|
||||
eula_confirmed, eula_updated = _check_agreement_status(
|
||||
eula_hash, Path("eula.confirmed"), "EULA_AGREE"
|
||||
)
|
||||
eula_confirmed, eula_updated = _check_agreement_status(eula_hash, Path("eula.confirmed"), "EULA_AGREE")
|
||||
privacy_confirmed, privacy_updated = _check_agreement_status(
|
||||
privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE"
|
||||
)
|
||||
|
||||
|
||||
# 早期返回:如果都已确认且未更新
|
||||
if eula_confirmed and privacy_confirmed:
|
||||
return
|
||||
|
||||
|
||||
# 如果有更新,需要重新确认
|
||||
if eula_updated or privacy_updated:
|
||||
_prompt_user_confirmation(eula_hash, privacy_hash)
|
||||
@@ -225,7 +222,7 @@ def check_eula():
|
||||
def raw_main():
|
||||
# 利用 TZ 环境变量设定程序工作的时区
|
||||
if platform.system().lower() != "windows":
|
||||
time.tzset()
|
||||
time.tzset() # type: ignore
|
||||
|
||||
check_eula()
|
||||
logger.info("检查EULA和隐私条款完成")
|
||||
|
||||
@@ -107,11 +107,12 @@ class ExpressionLearner:
|
||||
last_active_time = expr.get("last_active_time", time.time())
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == type_str) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == style_val)
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
@@ -125,7 +126,7 @@ class ExpressionLearner:
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str
|
||||
type=type_str,
|
||||
)
|
||||
logger.info(f"已迁移 {expr_file} 到数据库")
|
||||
except Exception as e:
|
||||
@@ -149,24 +150,28 @@ class ExpressionLearner:
|
||||
# 直接从数据库查询
|
||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
for expr in style_query:
|
||||
learnt_style_expressions.append({
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style"
|
||||
})
|
||||
learnt_style_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style",
|
||||
}
|
||||
)
|
||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||
for expr in grammar_query:
|
||||
learnt_grammar_expressions.append({
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar"
|
||||
})
|
||||
learnt_grammar_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar",
|
||||
}
|
||||
)
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
def is_similar(self, s1: str, s2: str) -> bool:
|
||||
@@ -213,14 +218,16 @@ class ExpressionLearner:
|
||||
logger.error(f"全局衰减{type}表达方式失败: {e}")
|
||||
continue
|
||||
|
||||
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
|
||||
# 学习新的表达方式(这里会进行局部衰减)
|
||||
for _ in range(3):
|
||||
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
|
||||
learnt_style = await self.learn_and_store(type="style", num=25)
|
||||
if not learnt_style:
|
||||
return [], []
|
||||
|
||||
for _ in range(1):
|
||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
|
||||
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
|
||||
if not learnt_grammar:
|
||||
return [], []
|
||||
|
||||
@@ -321,10 +328,10 @@ class ExpressionLearner:
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == type) &
|
||||
(Expression.situation == new_expr["situation"]) &
|
||||
(Expression.style == new_expr["style"])
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type)
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
@@ -342,13 +349,17 @@ class ExpressionLearner:
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type=type
|
||||
type=type,
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc()))
|
||||
exprs = list(
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
return learnt_expressions
|
||||
|
||||
|
||||
@@ -9,51 +9,49 @@ from src.common.logger import get_logger
|
||||
import traceback
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.database.database_model import Memory # Peewee Models导入
|
||||
from src.common.database.database_model import Memory # Peewee Models导入
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemoryItem:
|
||||
def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]):
|
||||
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
|
||||
self.memory_id = memory_id
|
||||
self.chat_id = chat_id
|
||||
self.memory_text:str = memory_text
|
||||
self.keywords:list[str] = keywords
|
||||
self.create_time:float = time.time()
|
||||
self.last_view_time:float = time.time()
|
||||
|
||||
self.memory_text: str = memory_text
|
||||
self.keywords: list[str] = keywords
|
||||
self.create_time: float = time.time()
|
||||
self.last_view_time: float = time.time()
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self):
|
||||
# self.memory_items:list[MemoryItem] = []
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class InstantMemory:
|
||||
def __init__(self,chat_id):
|
||||
self.chat_id = chat_id
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
self.last_view_time = time.time()
|
||||
self.summary_model = LLMRequest(
|
||||
model=global_config.model.memory,
|
||||
temperature=0.5,
|
||||
request_type="memory.summary",
|
||||
)
|
||||
|
||||
async def if_need_build(self,text):
|
||||
|
||||
async def if_need_build(self, text):
|
||||
prompt = f"""
|
||||
请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0
|
||||
{text}
|
||||
请只输出1或0就好
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
response,_ = await self.summary_model.generate_response_async(prompt)
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
|
||||
|
||||
|
||||
if "1" in response:
|
||||
return True
|
||||
else:
|
||||
@@ -61,8 +59,8 @@ class InstantMemory:
|
||||
except Exception as e:
|
||||
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
async def build_memory(self,text):
|
||||
|
||||
async def build_memory(self, text):
|
||||
prompt = f"""
|
||||
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
|
||||
{text}
|
||||
@@ -73,7 +71,7 @@ class InstantMemory:
|
||||
}}
|
||||
"""
|
||||
try:
|
||||
response,_ = await self.summary_model.generate_response_async(prompt)
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
if not response:
|
||||
@@ -81,53 +79,53 @@ class InstantMemory:
|
||||
try:
|
||||
repaired = repair_json(response)
|
||||
result = json.loads(repaired)
|
||||
memory_text = result.get('memory_text', '')
|
||||
keywords = result.get('keywords', '')
|
||||
memory_text = result.get("memory_text", "")
|
||||
keywords = result.get("keywords", "")
|
||||
if isinstance(keywords, str):
|
||||
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
|
||||
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
|
||||
elif isinstance(keywords, list):
|
||||
keywords_list = keywords
|
||||
else:
|
||||
keywords_list = []
|
||||
return {'memory_text': memory_text, 'keywords': keywords_list}
|
||||
return {"memory_text": memory_text, "keywords": keywords_list}
|
||||
except Exception as parse_e:
|
||||
logger.error(f"解析记忆json失败:{str(parse_e)} {traceback.format_exc()}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
|
||||
async def create_and_store_memory(self,text):
|
||||
async def create_and_store_memory(self, text):
|
||||
if_need = await self.if_need_build(text)
|
||||
if if_need:
|
||||
logger.info(f"需要记忆:{text}")
|
||||
memory = await self.build_memory(text)
|
||||
if memory and memory.get('memory_text'):
|
||||
memory = await self.build_memory(text)
|
||||
if memory and memory.get("memory_text"):
|
||||
memory_id = f"{self.chat_id}_{time.time()}"
|
||||
memory_item = MemoryItem(
|
||||
memory_id=memory_id,
|
||||
chat_id=self.chat_id,
|
||||
memory_text=memory['memory_text'],
|
||||
keywords=memory.get('keywords', [])
|
||||
memory_text=memory["memory_text"],
|
||||
keywords=memory.get("keywords", []),
|
||||
)
|
||||
await self.store_memory(memory_item)
|
||||
else:
|
||||
logger.info(f"不需要记忆:{text}")
|
||||
|
||||
async def store_memory(self,memory_item:MemoryItem):
|
||||
|
||||
async def store_memory(self, memory_item: MemoryItem):
|
||||
memory = Memory(
|
||||
memory_id=memory_item.memory_id,
|
||||
chat_id=memory_item.chat_id,
|
||||
memory_text=memory_item.memory_text,
|
||||
keywords=memory_item.keywords,
|
||||
create_time=memory_item.create_time,
|
||||
last_view_time=memory_item.last_view_time
|
||||
last_view_time=memory_item.last_view_time,
|
||||
)
|
||||
memory.save()
|
||||
|
||||
async def get_memory(self,target:str):
|
||||
|
||||
async def get_memory(self, target: str):
|
||||
from json_repair import repair_json
|
||||
|
||||
prompt = f"""
|
||||
请根据以下发言内容,判断是否需要提取记忆
|
||||
{target}
|
||||
@@ -144,7 +142,7 @@ class InstantMemory:
|
||||
请只输出json格式,不要输出其他多余内容
|
||||
"""
|
||||
try:
|
||||
response,_ = await self.summary_model.generate_response_async(prompt)
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
if not response:
|
||||
@@ -153,15 +151,15 @@ class InstantMemory:
|
||||
repaired = repair_json(response)
|
||||
result = json.loads(repaired)
|
||||
# 解析keywords
|
||||
keywords = result.get('keywords', '')
|
||||
keywords = result.get("keywords", "")
|
||||
if isinstance(keywords, str):
|
||||
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
|
||||
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
|
||||
elif isinstance(keywords, list):
|
||||
keywords_list = keywords
|
||||
else:
|
||||
keywords_list = []
|
||||
# 解析time为时间段
|
||||
time_str = result.get('time', '').strip()
|
||||
time_str = result.get("time", "").strip()
|
||||
start_time, end_time = self._parse_time_range(time_str)
|
||||
logger.info(f"start_time: {start_time}, end_time: {end_time}")
|
||||
# 检索包含关键词的记忆
|
||||
@@ -170,16 +168,15 @@ class InstantMemory:
|
||||
start_ts = start_time.timestamp()
|
||||
end_ts = end_time.timestamp()
|
||||
query = Memory.select().where(
|
||||
(Memory.chat_id == self.chat_id) &
|
||||
(Memory.create_time >= start_ts) &
|
||||
(Memory.create_time < end_ts)
|
||||
(Memory.chat_id == self.chat_id)
|
||||
& (Memory.create_time >= start_ts) # type: ignore
|
||||
& (Memory.create_time < end_ts) # type: ignore
|
||||
)
|
||||
else:
|
||||
query = Memory.select().where(Memory.chat_id == self.chat_id)
|
||||
|
||||
|
||||
for mem in query:
|
||||
#对每条记忆
|
||||
# 对每条记忆
|
||||
mem_keywords = mem.keywords or []
|
||||
parsed = ast.literal_eval(mem_keywords)
|
||||
if isinstance(parsed, list):
|
||||
@@ -212,6 +209,7 @@ class InstantMemory:
|
||||
- 空字符串:返回(None, None)
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
now = datetime.now()
|
||||
if not time_str:
|
||||
return 0, now
|
||||
@@ -251,8 +249,8 @@ class InstantMemory:
|
||||
if m:
|
||||
months = int(m.group(1))
|
||||
# 近似每月30天
|
||||
start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
# 其他无法解析
|
||||
return 0, now
|
||||
return 0, now
|
||||
|
||||
@@ -30,7 +30,7 @@ class ChatMessageContext:
|
||||
def get_template_name(self) -> Optional[str]:
|
||||
"""获取模板名称"""
|
||||
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
||||
return self.message.message_info.template_info.template_name
|
||||
return self.message.message_info.template_info.template_name # type: ignore
|
||||
return None
|
||||
|
||||
def get_last_message(self) -> "MessageRecv":
|
||||
|
||||
@@ -107,9 +107,9 @@ class MessageRecv(Message):
|
||||
self.is_picid = False
|
||||
self.has_picid = False
|
||||
self.is_mentioned = None
|
||||
|
||||
|
||||
self.is_command = False
|
||||
|
||||
|
||||
self.priority_mode = "interest"
|
||||
self.priority_info = None
|
||||
self.interest_value: float = None # type: ignore
|
||||
@@ -181,6 +181,7 @@ class MessageRecv(Message):
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecvS4U(MessageRecv):
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
@@ -194,10 +195,10 @@ class MessageRecvS4U(MessageRecv):
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
self.is_screen = False
|
||||
|
||||
|
||||
async def process(self) -> None:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
@@ -252,7 +253,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
elif segment.type == "gift":
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1)
|
||||
name, count = segment.data.split(":", 1) # type: ignore
|
||||
self.gift_info = segment.data
|
||||
self.gift_name = name.strip()
|
||||
self.gift_count = int(count.strip())
|
||||
@@ -260,13 +261,15 @@ class MessageRecvS4U(MessageRecv):
|
||||
elif segment.type == "superchat":
|
||||
self.is_superchat = True
|
||||
self.superchat_info = segment.data
|
||||
price,message_text = segment.data.split(":", 1)
|
||||
price, message_text = segment.data.split(":", 1) # type: ignore
|
||||
self.superchat_price = price.strip()
|
||||
self.superchat_message_text = message_text.strip()
|
||||
|
||||
|
||||
self.processed_plain_text = str(self.superchat_message_text)
|
||||
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||
|
||||
self.processed_plain_text += (
|
||||
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||
)
|
||||
|
||||
return self.processed_plain_text
|
||||
elif segment.type == "screen":
|
||||
self.is_screen = True
|
||||
|
||||
@@ -80,7 +80,7 @@ class ActionManager:
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
action_message: dict = None,
|
||||
action_message: Optional[dict] = None,
|
||||
) -> Optional[BaseAction]:
|
||||
"""
|
||||
创建动作处理器实例
|
||||
|
||||
@@ -252,7 +252,7 @@ def _build_readable_messages_internal(
|
||||
pic_id_mapping: Optional[Dict[str, str]] = None,
|
||||
pic_counter: int = 1,
|
||||
show_pic: bool = True,
|
||||
message_id_list: List[Dict[str, Any]] = None,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
@@ -615,7 +615,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
for action in actions:
|
||||
action_time = action.get("time", current_time)
|
||||
action_name = action.get("action_name", "未知动作")
|
||||
if action_name == "no_action" or action_name == "no_reply":
|
||||
if action_name in ["no_action", "no_reply"]:
|
||||
continue
|
||||
|
||||
action_prompt_display = action.get("action_prompt_display", "无具体内容")
|
||||
@@ -697,7 +697,7 @@ def build_readable_messages(
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
message_id_list: List[Dict[str, Any]] = None,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
|
||||
@@ -1211,7 +1211,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
f.write(html_template)
|
||||
|
||||
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
||||
# sourcery skip: for-append-to-extend, list-comprehension, use-any
|
||||
# sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next
|
||||
"""生成Focus统计独立分页的HTML内容"""
|
||||
|
||||
# 为每个时间段准备Focus数据
|
||||
@@ -1559,6 +1559,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
"""
|
||||
|
||||
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""生成版本对比独立分页的HTML内容"""
|
||||
|
||||
# 为每个时间段准备版本对比数据
|
||||
@@ -2306,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
|
||||
# 复用 StatisticOutputTask 的所有方法
|
||||
def _collect_all_statistics(self, now: datetime):
|
||||
return StatisticOutputTask._collect_all_statistics(self, now)
|
||||
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
|
||||
|
||||
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
||||
return StatisticOutputTask._statistic_console_output(self, stats, now)
|
||||
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
|
||||
|
||||
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
|
||||
return StatisticOutputTask._generate_html_report(self, stats, now)
|
||||
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
|
||||
|
||||
# 其他需要的方法也可以类似复用...
|
||||
@staticmethod
|
||||
@@ -2324,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
|
||||
|
||||
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
return StatisticOutputTask._collect_message_count_for_period(self, collect_period)
|
||||
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
|
||||
|
||||
def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period)
|
||||
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore
|
||||
|
||||
def _process_focus_file_data(
|
||||
self,
|
||||
@@ -2336,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
collect_period: List[Tuple[str, datetime]],
|
||||
file_time: datetime,
|
||||
):
|
||||
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time)
|
||||
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore
|
||||
|
||||
def _calculate_focus_averages(self, stats: Dict[str, Any]):
|
||||
return StatisticOutputTask._calculate_focus_averages(self, stats)
|
||||
return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
||||
@@ -2347,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_model_classified_stat(stats)
|
||||
return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore
|
||||
|
||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_chat_stat(self, stats)
|
||||
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
|
||||
|
||||
def _format_focus_stat(self, stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_focus_stat(self, stats)
|
||||
return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore
|
||||
|
||||
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||
return StatisticOutputTask._generate_chart_data(self, stat)
|
||||
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
|
||||
|
||||
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes)
|
||||
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore
|
||||
|
||||
def _generate_chart_tab(self, chart_data: dict) -> str:
|
||||
return StatisticOutputTask._generate_chart_tab(self, chart_data)
|
||||
return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore
|
||||
|
||||
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
|
||||
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id)
|
||||
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore
|
||||
|
||||
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._generate_focus_tab(self, stat)
|
||||
return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore
|
||||
|
||||
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._generate_versions_tab(self, stat)
|
||||
return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore
|
||||
|
||||
def _convert_defaultdict_to_dict(self, data):
|
||||
return StatisticOutputTask._convert_defaultdict_to_dict(self, data)
|
||||
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore
|
||||
|
||||
@@ -2,14 +2,13 @@ import importlib
|
||||
import asyncio
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Any
|
||||
from rich.traceback import install
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -54,7 +53,7 @@ class WillingInfo:
|
||||
interested_rate (float): 兴趣度
|
||||
"""
|
||||
|
||||
message: MessageRecv
|
||||
message: Dict[str, Any] # 原始消息数据
|
||||
chat: ChatStream
|
||||
person_info_manager: PersonInfoManager
|
||||
chat_id: str
|
||||
|
||||
@@ -65,7 +65,7 @@ class ChatStreams(BaseModel):
|
||||
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
|
||||
user_cardname = TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||
# 如果不使用带有数据库实例的 BaseModel,或者想覆盖它,
|
||||
# 请取消注释并在下面设置数据库实例:
|
||||
@@ -89,7 +89,7 @@ class LLMUsage(BaseModel):
|
||||
status = TextField()
|
||||
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||
# database = db
|
||||
table_name = "llm_usage"
|
||||
@@ -112,7 +112,7 @@ class Emoji(BaseModel):
|
||||
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
|
||||
last_used_time = FloatField(null=True) # 上次使用时间
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "emoji"
|
||||
|
||||
@@ -162,7 +162,8 @@ class Messages(BaseModel):
|
||||
is_emoji = BooleanField(default=False)
|
||||
is_picid = BooleanField(default=False)
|
||||
is_command = BooleanField(default=False)
|
||||
class Meta:
|
||||
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "messages"
|
||||
|
||||
@@ -186,7 +187,7 @@ class ActionRecords(BaseModel):
|
||||
chat_info_stream_id = TextField()
|
||||
chat_info_platform = TextField()
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "action_records"
|
||||
|
||||
@@ -206,7 +207,7 @@ class Images(BaseModel):
|
||||
type = TextField() # 图像类型,例如 "emoji"
|
||||
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
table_name = "images"
|
||||
|
||||
|
||||
@@ -220,7 +221,7 @@ class ImageDescriptions(BaseModel):
|
||||
description = TextField() # 图像的描述
|
||||
timestamp = FloatField() # 时间戳
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "image_descriptions"
|
||||
|
||||
@@ -236,7 +237,7 @@ class OnlineTime(BaseModel):
|
||||
start_timestamp = DateTimeField(default=datetime.datetime.now)
|
||||
end_timestamp = DateTimeField(index=True)
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "online_time"
|
||||
|
||||
@@ -263,10 +264,11 @@ class PersonInfo(BaseModel):
|
||||
last_know = FloatField(null=True) # 最后一次印象总结时间
|
||||
attitude = IntegerField(null=True, default=50) # 态度,0-100,从非常厌恶到十分喜欢
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "person_info"
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
memory_id = TextField(index=True)
|
||||
chat_id = TextField(null=True)
|
||||
@@ -274,10 +276,11 @@ class Memory(BaseModel):
|
||||
keywords = TextField(null=True)
|
||||
create_time = FloatField(null=True)
|
||||
last_view_time = FloatField(null=True)
|
||||
|
||||
class Meta:
|
||||
|
||||
class Meta: # type: ignore
|
||||
table_name = "memory"
|
||||
|
||||
|
||||
class Knowledges(BaseModel):
|
||||
"""
|
||||
用于存储知识库条目的模型。
|
||||
@@ -287,10 +290,11 @@ class Knowledges(BaseModel):
|
||||
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
|
||||
# 可以添加其他元数据字段,如 source, create_time 等
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "knowledges"
|
||||
|
||||
|
||||
class Expression(BaseModel):
|
||||
"""
|
||||
用于存储表达风格的模型。
|
||||
@@ -302,10 +306,11 @@ class Expression(BaseModel):
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
type = TextField()
|
||||
|
||||
class Meta:
|
||||
|
||||
class Meta: # type: ignore
|
||||
table_name = "expression"
|
||||
|
||||
|
||||
class ThinkingLog(BaseModel):
|
||||
chat_id = TextField(index=True)
|
||||
trigger_text = TextField(null=True)
|
||||
@@ -326,7 +331,7 @@ class ThinkingLog(BaseModel):
|
||||
# And: import datetime
|
||||
created_at = DateTimeField(default=datetime.datetime.now)
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
table_name = "thinking_logs"
|
||||
|
||||
|
||||
@@ -341,7 +346,7 @@ class GraphNodes(BaseModel):
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
table_name = "graph_nodes"
|
||||
|
||||
|
||||
@@ -357,7 +362,7 @@ class GraphEdges(BaseModel):
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
|
||||
class Meta:
|
||||
class Meta: # type: ignore
|
||||
table_name = "graph_edges"
|
||||
|
||||
|
||||
|
||||
@@ -7,13 +7,13 @@ from datetime import datetime
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
# 获取key的注释(如果有)
|
||||
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
|
||||
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
|
||||
return toml_table.trivia.comment
|
||||
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
|
||||
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
|
||||
item = toml_table.value.get(key)
|
||||
if item is not None and hasattr(item, 'trivia'):
|
||||
if item is not None and hasattr(item, "trivia"):
|
||||
return item.trivia.comment
|
||||
if hasattr(toml_table, 'keys'):
|
||||
if hasattr(toml_table, "keys"):
|
||||
for k in toml_table.keys():
|
||||
if isinstance(k, KeyType) and k.key == key:
|
||||
return k.trivia.comment
|
||||
@@ -36,16 +36,16 @@ def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, log
|
||||
continue
|
||||
if key not in old:
|
||||
comment = get_key_comment(new, key)
|
||||
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
||||
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||
compare_dicts(new[key], old[key], path+[str(key)], new_comments, old_comments, logs)
|
||||
compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
|
||||
# 删减项
|
||||
for key in old:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in new:
|
||||
comment = get_key_comment(old, key)
|
||||
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
||||
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||
return logs
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ def update_config():
|
||||
if old_version and new_version and old_version == new_version:
|
||||
print(f"检测到版本号相同 (v{old_version}),跳过更新")
|
||||
# 如果version相同,恢复旧配置文件并返回
|
||||
shutil.move(old_backup_path, old_config_path)
|
||||
shutil.move(old_backup_path, old_config_path) # type: ignore
|
||||
return
|
||||
else:
|
||||
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
|
||||
@@ -53,13 +53,13 @@ MMC_VERSION = "0.9.0-snapshot.2"
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
# 获取key的注释(如果有)
|
||||
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
|
||||
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
|
||||
return toml_table.trivia.comment
|
||||
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
|
||||
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
|
||||
item = toml_table.value.get(key)
|
||||
if item is not None and hasattr(item, 'trivia'):
|
||||
if item is not None and hasattr(item, "trivia"):
|
||||
return item.trivia.comment
|
||||
if hasattr(toml_table, 'keys'):
|
||||
if hasattr(toml_table, "keys"):
|
||||
for k in toml_table.keys():
|
||||
if isinstance(k, KeyType) and k.key == key:
|
||||
return k.trivia.comment
|
||||
@@ -78,16 +78,16 @@ def compare_dicts(new, old, path=None, logs=None):
|
||||
continue
|
||||
if key not in old:
|
||||
comment = get_key_comment(new, key)
|
||||
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
||||
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||
compare_dicts(new[key], old[key], path+[str(key)], logs)
|
||||
compare_dicts(new[key], old[key], path + [str(key)], logs)
|
||||
# 删减项
|
||||
for key in old:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in new:
|
||||
comment = get_key_comment(old, key)
|
||||
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
||||
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
|
||||
return logs
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ def get_value_by_path(d, path):
|
||||
return None
|
||||
return d
|
||||
|
||||
|
||||
def set_value_by_path(d, path, value):
|
||||
for k in path[:-1]:
|
||||
if k not in d or not isinstance(d[k], dict):
|
||||
@@ -106,6 +107,7 @@ def set_value_by_path(d, path, value):
|
||||
d = d[k]
|
||||
d[path[-1]] = value
|
||||
|
||||
|
||||
def compare_default_values(new, old, path=None, logs=None, changes=None):
|
||||
# 递归比较两个dict,找出默认值变化项
|
||||
if path is None:
|
||||
@@ -119,12 +121,14 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
|
||||
continue
|
||||
if key in old:
|
||||
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
|
||||
compare_default_values(new[key], old[key], path+[str(key)], logs, changes)
|
||||
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
|
||||
else:
|
||||
# 只要值发生变化就记录
|
||||
if new[key] != old[key]:
|
||||
logs.append(f"默认值变化: {'.'.join(path+[str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
|
||||
changes.append((path+[str(key)], old[key], new[key]))
|
||||
logs.append(
|
||||
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
|
||||
)
|
||||
changes.append((path + [str(key)], old[key], new[key]))
|
||||
return logs, changes
|
||||
|
||||
|
||||
@@ -148,8 +152,8 @@ def update_config():
|
||||
return None
|
||||
with open(toml_path, "r", encoding="utf-8") as f:
|
||||
doc = tomlkit.load(f)
|
||||
if "inner" in doc and "version" in doc["inner"]:
|
||||
return doc["inner"]["version"]
|
||||
if "inner" in doc and "version" in doc["inner"]: # type: ignore
|
||||
return doc["inner"]["version"] # type: ignore
|
||||
return None
|
||||
|
||||
template_version = get_version_from_toml(template_path)
|
||||
@@ -186,7 +190,9 @@ def update_config():
|
||||
old_value = get_value_by_path(old_config, path)
|
||||
if old_value == old_default:
|
||||
set_value_by_path(old_config, path, new_default)
|
||||
logger.info(f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}")
|
||||
logger.info(
|
||||
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
|
||||
)
|
||||
else:
|
||||
logger.info("未检测到模板默认值变动")
|
||||
# 保存旧配置的变更(后续合并逻辑会用到 old_config)
|
||||
@@ -229,7 +235,9 @@ def update_config():
|
||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------")
|
||||
logger.info(
|
||||
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
|
||||
)
|
||||
else:
|
||||
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
@@ -321,6 +329,7 @@ class Config(ConfigBase):
|
||||
debug: DebugConfig
|
||||
custom_prompt: CustomPromptConfig
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
"""
|
||||
加载配置文件
|
||||
|
||||
@@ -39,7 +39,7 @@ class LLMRequestOff:
|
||||
}
|
||||
|
||||
# 发送请求到完整的 chat/completions 端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
|
||||
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||
|
||||
max_retries = 3
|
||||
@@ -89,7 +89,7 @@ class LLMRequestOff:
|
||||
}
|
||||
|
||||
# 发送请求到完整的 chat/completions 端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions" # type: ignore
|
||||
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
||||
|
||||
max_retries = 3
|
||||
|
||||
@@ -83,8 +83,8 @@ class PersonalityEvaluatorDirect:
|
||||
def __init__(self):
|
||||
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
||||
self.scenarios = []
|
||||
self.final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
||||
self.dimension_counts = {trait: 0 for trait in self.final_scores.keys()}
|
||||
self.final_scores: Dict[str, float] = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
||||
self.dimension_counts = {trait: 0 for trait in self.final_scores}
|
||||
|
||||
# 为每个人格特质获取对应的场景
|
||||
for trait in PERSONALITY_SCENES:
|
||||
@@ -119,8 +119,7 @@ class PersonalityEvaluatorDirect:
|
||||
# 构建维度描述
|
||||
dimension_descriptions = []
|
||||
for dim in dimensions:
|
||||
desc = FACTOR_DESCRIPTIONS.get(dim, "")
|
||||
if desc:
|
||||
if desc := FACTOR_DESCRIPTIONS.get(dim, ""):
|
||||
dimension_descriptions.append(f"- {dim}:{desc}")
|
||||
|
||||
dimensions_text = "\n".join(dimension_descriptions)
|
||||
|
||||
@@ -153,14 +153,14 @@ class MainSystem:
|
||||
while True:
|
||||
await asyncio.sleep(global_config.memory.memory_build_interval)
|
||||
logger.info("正在进行记忆构建")
|
||||
await self.hippocampus_manager.build_memory()
|
||||
await self.hippocampus_manager.build_memory() # type: ignore
|
||||
|
||||
async def forget_memory_task(self):
|
||||
"""记忆遗忘任务"""
|
||||
while True:
|
||||
await asyncio.sleep(global_config.memory.forget_memory_interval)
|
||||
logger.info("[记忆遗忘] 开始遗忘记忆...")
|
||||
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage)
|
||||
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage) # type: ignore
|
||||
logger.info("[记忆遗忘] 记忆遗忘完成")
|
||||
|
||||
async def consolidate_memory_task(self):
|
||||
@@ -168,7 +168,7 @@ class MainSystem:
|
||||
while True:
|
||||
await asyncio.sleep(global_config.memory.consolidate_memory_interval)
|
||||
logger.info("[记忆整合] 开始整合记忆...")
|
||||
await self.hippocampus_manager.consolidate_memory()
|
||||
await self.hippocampus_manager.consolidate_memory() # type: ignore
|
||||
logger.info("[记忆整合] 记忆整合完成")
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -49,6 +49,9 @@ class ChatMood:
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
self.chat_stream = chat_manager.get_stream(self.chat_id)
|
||||
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"Chat stream for chat_id {chat_id} not found")
|
||||
|
||||
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ SEGMENT_CLEANUP_CONFIG = {
|
||||
"cleanup_interval_hours": 0.5, # 清理间隔(小时)
|
||||
}
|
||||
|
||||
MAX_MESSAGE_COUNT = 80 / global_config.relationship.relation_frequency
|
||||
MAX_MESSAGE_COUNT = int(80 / global_config.relationship.relation_frequency)
|
||||
|
||||
|
||||
class RelationshipBuilder:
|
||||
|
||||
@@ -61,7 +61,7 @@ __all__ = [
|
||||
"ConfigField",
|
||||
# 工具函数
|
||||
"ManifestValidator",
|
||||
"ManifestGenerator",
|
||||
"validate_plugin_manifest",
|
||||
"generate_plugin_manifest",
|
||||
# "ManifestGenerator",
|
||||
# "validate_plugin_manifest",
|
||||
# "generate_plugin_manifest",
|
||||
]
|
||||
|
||||
@@ -111,7 +111,7 @@ async def _send_to_target(
|
||||
is_head=True,
|
||||
is_emoji=(message_type == "emoji"),
|
||||
thinking_start_time=current_time,
|
||||
reply_to = reply_to_platform_id
|
||||
reply_to=reply_to_platform_id,
|
||||
)
|
||||
|
||||
# 发送消息
|
||||
@@ -137,6 +137,7 @@ async def _send_to_target(
|
||||
|
||||
|
||||
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
|
||||
# sourcery skip: inline-variable, use-named-expression
|
||||
"""查找要回复的消息
|
||||
|
||||
Args:
|
||||
@@ -184,14 +185,11 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
|
||||
|
||||
# 检查是否有 回复<aaa:bbb> 字段
|
||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||
match = re.search(reply_pattern, translate_text)
|
||||
if match:
|
||||
if match := re.search(reply_pattern, translate_text):
|
||||
aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
reply_person_id = get_person_info_manager().get_person_id(platform, bbb)
|
||||
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name")
|
||||
if not reply_person_name:
|
||||
reply_person_name = aaa
|
||||
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa
|
||||
# 在内容前加上回复信息
|
||||
translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1)
|
||||
|
||||
@@ -206,9 +204,7 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
at_person_id = get_person_info_manager().get_person_id(platform, bbb)
|
||||
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name")
|
||||
if not at_person_name:
|
||||
at_person_name = aaa
|
||||
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa
|
||||
new_content += f"@{at_person_name}"
|
||||
last_end = m.end()
|
||||
new_content += translate_text[last_end:]
|
||||
@@ -370,7 +366,14 @@ async def custom_to_stream(
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log
|
||||
message_type,
|
||||
content,
|
||||
stream_id,
|
||||
display_message,
|
||||
typing,
|
||||
reply_to,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
|
||||
@@ -396,7 +399,7 @@ async def text_to_group(
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
|
||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
|
||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
|
||||
|
||||
|
||||
async def text_to_user(
|
||||
@@ -420,7 +423,7 @@ async def text_to_user(
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message)
|
||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
|
||||
|
||||
|
||||
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
@@ -543,7 +546,9 @@ async def custom_to_group(
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
|
||||
return await _send_to_target(
|
||||
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
|
||||
)
|
||||
|
||||
|
||||
async def custom_to_user(
|
||||
@@ -571,7 +576,9 @@ async def custom_to_user(
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
|
||||
return await _send_to_target(
|
||||
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
|
||||
)
|
||||
|
||||
|
||||
async def custom_message(
|
||||
@@ -611,4 +618,6 @@ async def custom_message(
|
||||
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
|
||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message)
|
||||
return await _send_to_target(
|
||||
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ class BaseAction(ABC):
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str = "",
|
||||
plugin_config: Optional[dict] = None,
|
||||
action_message: dict = None,
|
||||
action_message: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化Action组件
|
||||
@@ -63,7 +63,7 @@ class BaseAction(ABC):
|
||||
self.cycle_timers = cycle_timers
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
|
||||
# 保存插件配置
|
||||
self.plugin_config = plugin_config or {}
|
||||
|
||||
@@ -92,10 +92,10 @@ class BaseAction(ABC):
|
||||
self.chat_stream = chat_stream or kwargs.get("chat_stream")
|
||||
self.chat_id = self.chat_stream.stream_id
|
||||
self.platform = getattr(self.chat_stream, "platform", None)
|
||||
|
||||
|
||||
# 初始化基础信息(带类型注解)
|
||||
self.action_message = action_message
|
||||
|
||||
|
||||
self.group_id = None
|
||||
self.group_name = None
|
||||
self.user_id = None
|
||||
@@ -103,15 +103,17 @@ class BaseAction(ABC):
|
||||
self.is_group = False
|
||||
self.target_id = None
|
||||
self.has_action_message = False
|
||||
|
||||
|
||||
if self.action_message:
|
||||
self.has_action_message = True
|
||||
|
||||
else:
|
||||
self.action_message = {}
|
||||
|
||||
if self.has_action_message:
|
||||
if self.action_name != "no_reply":
|
||||
self.group_id = str(self.action_message.get("chat_info_group_id", None))
|
||||
self.group_name = self.action_message.get("chat_info_group_name", None)
|
||||
|
||||
|
||||
self.user_id = str(self.action_message.get("user_id", None))
|
||||
self.user_nickname = self.action_message.get("user_nickname", None)
|
||||
if self.group_id:
|
||||
@@ -132,8 +134,6 @@ class BaseAction(ABC):
|
||||
self.is_group = False
|
||||
self.target_id = self.user_id
|
||||
|
||||
|
||||
|
||||
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||
@@ -199,7 +199,9 @@ class BaseAction(ABC):
|
||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||||
return False, f"等待新消息失败: {str(e)}"
|
||||
|
||||
async def send_text(self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False) -> bool:
|
||||
async def send_text(
|
||||
self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False
|
||||
) -> bool:
|
||||
"""发送文本消息
|
||||
|
||||
Args:
|
||||
@@ -299,7 +301,7 @@ class BaseAction(ABC):
|
||||
)
|
||||
|
||||
async def send_command(
|
||||
self, command_name: str, args: dict = None, display_message: str = None, storage_message: bool = True
|
||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
|
||||
) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ class BaseCommand(ABC):
|
||||
)
|
||||
|
||||
async def send_command(
|
||||
self, command_name: str, args: dict = None, display_message: str = "", storage_message: bool = True
|
||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
|
||||
) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
|
||||
@@ -346,67 +346,67 @@ class ComponentRegistry:
|
||||
|
||||
# === 状态管理方法 ===
|
||||
|
||||
def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||
# -------------------------------- NEED REFACTORING --------------------------------
|
||||
# -------------------------------- LOGIC ERROR -------------------------------------
|
||||
"""启用组件,支持命名空间解析"""
|
||||
# 首先尝试找到正确的命名空间化名称
|
||||
component_info = self.get_component_info(component_name, component_type)
|
||||
if not component_info:
|
||||
return False
|
||||
# def enable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||
# # -------------------------------- LOGIC ERROR -------------------------------------
|
||||
# """启用组件,支持命名空间解析"""
|
||||
# # 首先尝试找到正确的命名空间化名称
|
||||
# component_info = self.get_component_info(component_name, component_type)
|
||||
# if not component_info:
|
||||
# return False
|
||||
|
||||
# 根据组件类型构造正确的命名空间化名称
|
||||
if component_info.component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
||||
elif component_info.component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
||||
else:
|
||||
namespaced_name = (
|
||||
f"{component_info.component_type.value}.{component_name}"
|
||||
if "." not in component_name
|
||||
else component_name
|
||||
)
|
||||
# # 根据组件类型构造正确的命名空间化名称
|
||||
# if component_info.component_type == ComponentType.ACTION:
|
||||
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
||||
# elif component_info.component_type == ComponentType.COMMAND:
|
||||
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
||||
# else:
|
||||
# namespaced_name = (
|
||||
# f"{component_info.component_type.value}.{component_name}"
|
||||
# if "." not in component_name
|
||||
# else component_name
|
||||
# )
|
||||
|
||||
if namespaced_name in self._components:
|
||||
self._components[namespaced_name].enabled = True
|
||||
# 如果是Action,更新默认动作集
|
||||
# ---- HERE ----
|
||||
# if isinstance(component_info, ActionInfo):
|
||||
# self._action_descriptions[component_name] = component_info.description
|
||||
logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
|
||||
return True
|
||||
return False
|
||||
# if namespaced_name in self._components:
|
||||
# self._components[namespaced_name].enabled = True
|
||||
# # 如果是Action,更新默认动作集
|
||||
# # ---- HERE ----
|
||||
# # if isinstance(component_info, ActionInfo):
|
||||
# # self._action_descriptions[component_name] = component_info.description
|
||||
# logger.debug(f"已启用组件: {component_name} -> {namespaced_name}")
|
||||
# return True
|
||||
# return False
|
||||
|
||||
def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||
# -------------------------------- NEED REFACTORING --------------------------------
|
||||
# -------------------------------- LOGIC ERROR -------------------------------------
|
||||
"""禁用组件,支持命名空间解析"""
|
||||
# 首先尝试找到正确的命名空间化名称
|
||||
component_info = self.get_component_info(component_name, component_type)
|
||||
if not component_info:
|
||||
return False
|
||||
# def disable_component(self, component_name: str, component_type: ComponentType = None) -> bool:
|
||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||
# # -------------------------------- LOGIC ERROR -------------------------------------
|
||||
# """禁用组件,支持命名空间解析"""
|
||||
# # 首先尝试找到正确的命名空间化名称
|
||||
# component_info = self.get_component_info(component_name, component_type)
|
||||
# if not component_info:
|
||||
# return False
|
||||
|
||||
# 根据组件类型构造正确的命名空间化名称
|
||||
if component_info.component_type == ComponentType.ACTION:
|
||||
namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
||||
elif component_info.component_type == ComponentType.COMMAND:
|
||||
namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
||||
else:
|
||||
namespaced_name = (
|
||||
f"{component_info.component_type.value}.{component_name}"
|
||||
if "." not in component_name
|
||||
else component_name
|
||||
)
|
||||
# # 根据组件类型构造正确的命名空间化名称
|
||||
# if component_info.component_type == ComponentType.ACTION:
|
||||
# namespaced_name = f"action.{component_name}" if "." not in component_name else component_name
|
||||
# elif component_info.component_type == ComponentType.COMMAND:
|
||||
# namespaced_name = f"command.{component_name}" if "." not in component_name else component_name
|
||||
# else:
|
||||
# namespaced_name = (
|
||||
# f"{component_info.component_type.value}.{component_name}"
|
||||
# if "." not in component_name
|
||||
# else component_name
|
||||
# )
|
||||
|
||||
if namespaced_name in self._components:
|
||||
self._components[namespaced_name].enabled = False
|
||||
# 如果是Action,从默认动作集中移除
|
||||
# ---- HERE ----
|
||||
# if component_name in self._action_descriptions:
|
||||
# del self._action_descriptions[component_name]
|
||||
logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
|
||||
return True
|
||||
return False
|
||||
# if namespaced_name in self._components:
|
||||
# self._components[namespaced_name].enabled = False
|
||||
# # 如果是Action,从默认动作集中移除
|
||||
# # ---- HERE ----
|
||||
# # if component_name in self._action_descriptions:
|
||||
# # del self._action_descriptions[component_name]
|
||||
# logger.debug(f"已禁用组件: {component_name} -> {namespaced_name}")
|
||||
# return True
|
||||
# return False
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
"""获取注册中心统计信息"""
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import importlib
|
||||
from typing import List, Dict, Tuple
|
||||
from typing import List, Dict, Tuple, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import PythonDependency
|
||||
@@ -176,7 +176,7 @@ class DependencyManager:
|
||||
logger.error(f"生成requirements文件失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_install_summary(self) -> Dict[str, any]:
|
||||
def get_install_summary(self) -> Dict[str, Any]:
|
||||
"""获取安装摘要"""
|
||||
return {
|
||||
"install_log": self.install_log.copy(),
|
||||
|
||||
@@ -197,29 +197,29 @@ class PluginManager:
|
||||
"""获取所有启用的插件信息"""
|
||||
return list(component_registry.get_enabled_plugins().values())
|
||||
|
||||
def enable_plugin(self, plugin_name: str) -> bool:
|
||||
# -------------------------------- NEED REFACTORING --------------------------------
|
||||
"""启用插件"""
|
||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
plugin_info.enabled = True
|
||||
# 启用插件的所有组件
|
||||
for component in plugin_info.components:
|
||||
component_registry.enable_component(component.name)
|
||||
logger.debug(f"已启用插件: {plugin_name}")
|
||||
return True
|
||||
return False
|
||||
# def enable_plugin(self, plugin_name: str) -> bool:
|
||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||
# """启用插件"""
|
||||
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
# plugin_info.enabled = True
|
||||
# # 启用插件的所有组件
|
||||
# for component in plugin_info.components:
|
||||
# component_registry.enable_component(component.name)
|
||||
# logger.debug(f"已启用插件: {plugin_name}")
|
||||
# return True
|
||||
# return False
|
||||
|
||||
def disable_plugin(self, plugin_name: str) -> bool:
|
||||
# -------------------------------- NEED REFACTORING --------------------------------
|
||||
"""禁用插件"""
|
||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
plugin_info.enabled = False
|
||||
# 禁用插件的所有组件
|
||||
for component in plugin_info.components:
|
||||
component_registry.disable_component(component.name)
|
||||
logger.debug(f"已禁用插件: {plugin_name}")
|
||||
return True
|
||||
return False
|
||||
# def disable_plugin(self, plugin_name: str) -> bool:
|
||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||
# """禁用插件"""
|
||||
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
# plugin_info.enabled = False
|
||||
# # 禁用插件的所有组件
|
||||
# for component in plugin_info.components:
|
||||
# component_registry.disable_component(component.name)
|
||||
# logger.debug(f"已禁用插件: {plugin_name}")
|
||||
# return True
|
||||
# return False
|
||||
|
||||
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
||||
"""获取插件实例
|
||||
|
||||
@@ -28,10 +28,10 @@ class CompareNumbersTool(BaseTool):
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
num1 = function_args.get("num1")
|
||||
num2 = function_args.get("num2")
|
||||
num1: int | float = function_args.get("num1") # type: ignore
|
||||
num2: int | float = function_args.get("num2") # type: ignore
|
||||
|
||||
try:
|
||||
if num1 > num2:
|
||||
result = f"{num1} 大于 {num2}"
|
||||
elif num1 < num2:
|
||||
|
||||
@@ -68,10 +68,10 @@ class RenamePersonTool(BaseTool):
|
||||
)
|
||||
result = await person_info_manager.qv_person_name(
|
||||
person_id=person_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
user_avatar=user_avatar,
|
||||
request=request_context,
|
||||
user_nickname=user_nickname, # type: ignore
|
||||
user_cardname=user_cardname, # type: ignore
|
||||
user_avatar=user_avatar, # type: ignore
|
||||
request=request_context, # type: ignore
|
||||
)
|
||||
|
||||
# 3. 处理结果
|
||||
|
||||
Reference in New Issue
Block a user