typing fix

This commit is contained in:
UnCLAS-Prommer
2025-07-17 00:10:41 +08:00
parent 6e838ccc74
commit 1aa2734d62
26 changed files with 329 additions and 293 deletions

View File

@@ -107,11 +107,12 @@ class ExpressionLearner:
last_active_time = expr.get("last_active_time", time.time())
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == type_str) &
(Expression.situation == situation) &
(Expression.style == style_val)
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)
if query.exists():
expr_obj = query.get()
@@ -125,7 +126,7 @@ class ExpressionLearner:
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str
type=type_str,
)
logger.info(f"已迁移 {expr_file} 到数据库")
except Exception as e:
@@ -149,24 +150,28 @@ class ExpressionLearner:
# 直接从数据库查询
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
for expr in style_query:
learnt_style_expressions.append({
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "style"
})
learnt_style_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "style",
}
)
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
for expr in grammar_query:
learnt_grammar_expressions.append({
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "grammar"
})
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "grammar",
}
)
return learnt_style_expressions, learnt_grammar_expressions
def is_similar(self, s1: str, s2: str) -> bool:
@@ -213,14 +218,16 @@ class ExpressionLearner:
logger.error(f"全局衰减{type}表达方式失败: {e}")
continue
learnt_style: Optional[List[Tuple[str, str, str]]] = []
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
# 学习新的表达方式(这里会进行局部衰减)
for _ in range(3):
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
learnt_style = await self.learn_and_store(type="style", num=25)
if not learnt_style:
return [], []
for _ in range(1):
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
if not learnt_grammar:
return [], []
@@ -321,10 +328,10 @@ class ExpressionLearner:
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == type) &
(Expression.situation == new_expr["situation"]) &
(Expression.style == new_expr["style"])
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
if query.exists():
expr_obj = query.get()
@@ -342,13 +349,17 @@ class ExpressionLearner:
count=1,
last_active_time=current_time,
chat_id=chat_id,
type=type
type=type,
)
# 限制最大数量
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc()))
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == type))
.order_by(Expression.count.asc())
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]:
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
return learnt_expressions

View File

@@ -9,51 +9,49 @@ from src.common.logger import get_logger
import traceback
from src.config.config import global_config
from src.common.database.database_model import Memory # Peewee Models导入
from src.common.database.database_model import Memory # Peewee Models导入
logger = get_logger(__name__)
class MemoryItem:
def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]):
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
self.memory_id = memory_id
self.chat_id = chat_id
self.memory_text:str = memory_text
self.keywords:list[str] = keywords
self.create_time:float = time.time()
self.last_view_time:float = time.time()
self.memory_text: str = memory_text
self.keywords: list[str] = keywords
self.create_time: float = time.time()
self.last_view_time: float = time.time()
class MemoryManager:
def __init__(self):
# self.memory_items:list[MemoryItem] = []
pass
class InstantMemory:
def __init__(self,chat_id):
self.chat_id = chat_id
def __init__(self, chat_id):
self.chat_id = chat_id
self.last_view_time = time.time()
self.summary_model = LLMRequest(
model=global_config.model.memory,
temperature=0.5,
request_type="memory.summary",
)
async def if_need_build(self,text):
async def if_need_build(self, text):
prompt = f"""
请判断以下内容中是否有值得记忆的信息如果有请输出1否则输出0
{text}
请只输出1或0就好
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if "1" in response:
return True
else:
@@ -61,8 +59,8 @@ class InstantMemory:
except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False
async def build_memory(self,text):
async def build_memory(self, text):
prompt = f"""
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
{text}
@@ -73,7 +71,7 @@ class InstantMemory:
}}
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if not response:
@@ -81,53 +79,53 @@ class InstantMemory:
try:
repaired = repair_json(response)
result = json.loads(repaired)
memory_text = result.get('memory_text', '')
keywords = result.get('keywords', '')
memory_text = result.get("memory_text", "")
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
return {'memory_text': memory_text, 'keywords': keywords_list}
return {"memory_text": memory_text, "keywords": keywords_list}
except Exception as parse_e:
logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}")
return None
except Exception as e:
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
return None
async def create_and_store_memory(self,text):
async def create_and_store_memory(self, text):
if_need = await self.if_need_build(text)
if if_need:
logger.info(f"需要记忆:{text}")
memory = await self.build_memory(text)
if memory and memory.get('memory_text'):
memory = await self.build_memory(text)
if memory and memory.get("memory_text"):
memory_id = f"{self.chat_id}_{time.time()}"
memory_item = MemoryItem(
memory_id=memory_id,
chat_id=self.chat_id,
memory_text=memory['memory_text'],
keywords=memory.get('keywords', [])
memory_text=memory["memory_text"],
keywords=memory.get("keywords", []),
)
await self.store_memory(memory_item)
else:
logger.info(f"不需要记忆:{text}")
async def store_memory(self,memory_item:MemoryItem):
async def store_memory(self, memory_item: MemoryItem):
memory = Memory(
memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id,
memory_text=memory_item.memory_text,
keywords=memory_item.keywords,
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time
last_view_time=memory_item.last_view_time,
)
memory.save()
async def get_memory(self,target:str):
async def get_memory(self, target: str):
from json_repair import repair_json
prompt = f"""
请根据以下发言内容,判断是否需要提取记忆
{target}
@@ -144,7 +142,7 @@ class InstantMemory:
请只输出json格式不要输出其他多余内容
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if not response:
@@ -153,15 +151,15 @@ class InstantMemory:
repaired = repair_json(response)
result = json.loads(repaired)
# 解析keywords
keywords = result.get('keywords', '')
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
# 解析time为时间段
time_str = result.get('time', '').strip()
time_str = result.get("time", "").strip()
start_time, end_time = self._parse_time_range(time_str)
logger.info(f"start_time: {start_time}, end_time: {end_time}")
# 检索包含关键词的记忆
@@ -170,16 +168,15 @@ class InstantMemory:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = Memory.select().where(
(Memory.chat_id == self.chat_id) &
(Memory.create_time >= start_ts) &
(Memory.create_time < end_ts)
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts) # type: ignore
& (Memory.create_time < end_ts) # type: ignore
)
else:
query = Memory.select().where(Memory.chat_id == self.chat_id)
for mem in query:
#对每条记忆
# 对每条记忆
mem_keywords = mem.keywords or []
parsed = ast.literal_eval(mem_keywords)
if isinstance(parsed, list):
@@ -212,6 +209,7 @@ class InstantMemory:
- 空字符串:返回(None, None)
"""
from datetime import datetime, timedelta
now = datetime.now()
if not time_str:
return 0, now
@@ -251,8 +249,8 @@ class InstantMemory:
if m:
months = int(m.group(1))
# 近似每月30天
start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0)
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
# 其他无法解析
return 0, now
return 0, now

View File

@@ -30,7 +30,7 @@ class ChatMessageContext:
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
return self.message.message_info.template_info.template_name # type: ignore
return None
def get_last_message(self) -> "MessageRecv":

View File

@@ -107,9 +107,9 @@ class MessageRecv(Message):
self.is_picid = False
self.has_picid = False
self.is_mentioned = None
self.is_command = False
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = None # type: ignore
@@ -181,6 +181,7 @@ class MessageRecv(Message):
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageRecvS4U(MessageRecv):
def __init__(self, message_dict: dict[str, Any]):
@@ -194,10 +195,10 @@ class MessageRecvS4U(MessageRecv):
self.superchat_price = None
self.superchat_message_text = None
self.is_screen = False
async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
@@ -252,7 +253,7 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "gift":
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1)
name, count = segment.data.split(":", 1) # type: ignore
self.gift_info = segment.data
self.gift_name = name.strip()
self.gift_count = int(count.strip())
@@ -260,13 +261,15 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "superchat":
self.is_superchat = True
self.superchat_info = segment.data
price,message_text = segment.data.split(":", 1)
price, message_text = segment.data.split(":", 1) # type: ignore
self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
self.processed_plain_text += (
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
)
return self.processed_plain_text
elif segment.type == "screen":
self.is_screen = True

View File

@@ -80,7 +80,7 @@ class ActionManager:
chat_stream: ChatStream,
log_prefix: str,
shutting_down: bool = False,
action_message: dict = None,
action_message: Optional[dict] = None,
) -> Optional[BaseAction]:
"""
创建动作处理器实例

View File

@@ -252,7 +252,7 @@ def _build_readable_messages_internal(
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1,
show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
"""
内部辅助函数,构建可读消息字符串和原始消息详情列表。
@@ -615,7 +615,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
for action in actions:
action_time = action.get("time", current_time)
action_name = action.get("action_name", "未知动作")
if action_name == "no_action" or action_name == "no_reply":
if action_name in ["no_action", "no_reply"]:
continue
action_prompt_display = action.get("action_prompt_display", "无具体内容")
@@ -697,7 +697,7 @@ def build_readable_messages(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。

View File

@@ -1211,7 +1211,7 @@ class StatisticOutputTask(AsyncTask):
f.write(html_template)
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: for-append-to-extend, list-comprehension, use-any
# sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next
"""生成Focus统计独立分页的HTML内容"""
# 为每个时间段准备Focus数据
@@ -1559,6 +1559,7 @@ class StatisticOutputTask(AsyncTask):
"""
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: use-named-expression, use-next
"""生成版本对比独立分页的HTML内容"""
# 为每个时间段准备版本对比数据
@@ -2306,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask):
# 复用 StatisticOutputTask 的所有方法
def _collect_all_statistics(self, now: datetime):
return StatisticOutputTask._collect_all_statistics(self, now)
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
return StatisticOutputTask._statistic_console_output(self, stats, now)
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
return StatisticOutputTask._generate_html_report(self, stats, now)
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
# 其他需要的方法也可以类似复用...
@staticmethod
@@ -2324,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask):
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_message_count_for_period(self, collect_period)
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period)
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore
def _process_focus_file_data(
self,
@@ -2336,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask):
collect_period: List[Tuple[str, datetime]],
file_time: datetime,
):
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time)
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore
def _calculate_focus_averages(self, stats: Dict[str, Any]):
return StatisticOutputTask._calculate_focus_averages(self, stats)
return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore
@staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str:
@@ -2347,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask):
@staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_model_classified_stat(stats)
return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_chat_stat(self, stats)
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
def _format_focus_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_focus_stat(self, stats)
return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
return StatisticOutputTask._generate_chart_data(self, stat)
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes)
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore
def _generate_chart_tab(self, chart_data: dict) -> str:
return StatisticOutputTask._generate_chart_tab(self, chart_data)
return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id)
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_focus_tab(self, stat)
return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_versions_tab(self, stat)
return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore
def _convert_defaultdict_to_dict(self, data):
return StatisticOutputTask._convert_defaultdict_to_dict(self, data)
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore

View File

@@ -2,14 +2,13 @@ import importlib
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Optional
from typing import Dict, Optional, Any
from rich.traceback import install
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
install(extra_lines=3)
@@ -54,7 +53,7 @@ class WillingInfo:
interested_rate (float): 兴趣度
"""
message: MessageRecv
message: Dict[str, Any] # 原始消息数据
chat: ChatStream
person_info_manager: PersonInfoManager
chat_id: str