perf(methods): 通过移除不必要的 self 参数优化方法签名
在包括 chat、plugin_system、schedule 和 mais4u 在内的多个模块中，消除冗余的实例引用。此次改动将无需访问实例状态的实用函数转换为静态方法，从而提升了内存效率，并使方法依赖关系更加清晰。
This commit is contained in:
@@ -4,7 +4,7 @@ import orjson
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from typing import List, Dict, Optional, Any, Tuple, Coroutine
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
@@ -112,7 +112,7 @@ class ExpressionLearner:
|
||||
logger.error(f"检查学习权限失败: {e}")
|
||||
return False
|
||||
|
||||
def should_trigger_learning(self) -> bool:
|
||||
async def should_trigger_learning(self) -> bool:
|
||||
"""
|
||||
检查是否应该触发学习
|
||||
|
||||
@@ -146,7 +146,7 @@ class ExpressionLearner:
|
||||
return False
|
||||
|
||||
# 检查消息数量(只检查指定聊天流的消息)
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
recent_messages = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=time.time(),
|
||||
@@ -193,7 +193,7 @@ class ExpressionLearner:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
return False
|
||||
|
||||
def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
async def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
@@ -202,8 +202,8 @@ class ExpressionLearner:
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
# 直接从数据库查询
|
||||
with get_db_session() as session:
|
||||
style_query = session.execute(
|
||||
async with get_db_session() as session:
|
||||
style_query = await session.execute(
|
||||
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
|
||||
)
|
||||
for expr in style_query.scalars():
|
||||
@@ -220,7 +220,7 @@ class ExpressionLearner:
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
grammar_query = session.execute(
|
||||
grammar_query = await session.execute(
|
||||
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar"))
|
||||
)
|
||||
for expr in grammar_query.scalars():
|
||||
@@ -239,14 +239,15 @@ class ExpressionLearner:
|
||||
)
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
async def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 获取所有表达方式
|
||||
all_expressions = session.execute(select(Expression)).scalars()
|
||||
all_expressions = await session.execute(select(Expression))
|
||||
all_expressions = all_expressions.scalars().all()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
@@ -263,7 +264,7 @@ class ExpressionLearner:
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
session.delete(expr)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
@@ -276,7 +277,8 @@ class ExpressionLearner:
|
||||
except Exception as e:
|
||||
logger.error(f"数据库全局衰减失败: {e}")
|
||||
|
||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||
@staticmethod
|
||||
def calculate_decay_factor(time_diff_days: float) -> float:
|
||||
"""
|
||||
计算衰减值
|
||||
当时间差为0天时,衰减值为0(最近活跃的不衰减)
|
||||
@@ -298,7 +300,7 @@ class ExpressionLearner:
|
||||
|
||||
return min(0.01, decay)
|
||||
|
||||
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
async def learn_and_store(self, type: str, num: int = 10) -> None | list[Any] | list[tuple[str, str, str]]:
|
||||
# sourcery skip: use-join
|
||||
"""
|
||||
学习并存储表达方式
|
||||
@@ -349,19 +351,20 @@ class ExpressionLearner:
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
with get_db_session() as session:
|
||||
query = session.execute(
|
||||
async with get_db_session() as session:
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = await session.execute(
|
||||
select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type)
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)
|
||||
).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
)
|
||||
existing_expr = query.scalar()
|
||||
if existing_expr:
|
||||
expr_obj = existing_expr
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
@@ -378,23 +381,22 @@ class ExpressionLearner:
|
||||
type=type,
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
)
|
||||
session.add(new_expression)
|
||||
session.commit()
|
||||
await session.add(new_expression)
|
||||
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
session.execute(
|
||||
select(Expression)
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())
|
||||
).scalars()
|
||||
exprs_result = await session.execute(
|
||||
select(Expression)
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())
|
||||
)
|
||||
exprs = list(exprs_result.scalars())
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
session.delete(expr)
|
||||
session.commit()
|
||||
await session.delete(expr)
|
||||
|
||||
return learnt_expressions
|
||||
return None
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
@@ -414,7 +416,7 @@ class ExpressionLearner:
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习时间
|
||||
random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
random_msg: Optional[List[Dict[str, Any]]] = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=current_time,
|
||||
@@ -449,7 +451,8 @@ class ExpressionLearner:
|
||||
|
||||
return expressions, chat_id
|
||||
|
||||
def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]:
|
||||
@staticmethod
|
||||
def parse_expression_response(response: str, chat_id: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
"""
|
||||
@@ -488,15 +491,18 @@ class ExpressionLearnerManager:
|
||||
self.expression_learners = {}
|
||||
|
||||
self._ensure_expression_directories()
|
||||
self._auto_migrate_json_to_db()
|
||||
self._migrate_old_data_create_date()
|
||||
|
||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
|
||||
async def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
await self._auto_migrate_json_to_db()
|
||||
await self._migrate_old_data_create_date()
|
||||
|
||||
if chat_id not in self.expression_learners:
|
||||
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||
return self.expression_learners[chat_id]
|
||||
|
||||
def _ensure_expression_directories(self):
|
||||
@staticmethod
|
||||
def _ensure_expression_directories():
|
||||
"""
|
||||
确保表达方式相关的目录结构存在
|
||||
"""
|
||||
@@ -514,7 +520,8 @@ class ExpressionLearnerManager:
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
def _auto_migrate_json_to_db(self):
|
||||
@staticmethod
|
||||
async def _auto_migrate_json_to_db():
|
||||
"""
|
||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
||||
@@ -577,33 +584,33 @@ class ExpressionLearnerManager:
|
||||
continue
|
||||
|
||||
# 查重:同chat_id+type+situation+style
|
||||
with get_db_session() as session:
|
||||
query = session.execute(
|
||||
async with get_db_session() as session:
|
||||
query = await session.execute(
|
||||
select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)
|
||||
).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
else:
|
||||
new_expression = Expression(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
session.add(new_expression)
|
||||
session.commit()
|
||||
existing_expr = query.scalar()
|
||||
if existing_expr:
|
||||
expr_obj = existing_expr
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
else:
|
||||
new_expression = Expression(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
await session.add(new_expression)
|
||||
|
||||
migrated_count += 1
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except orjson.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||
@@ -628,15 +635,17 @@ class ExpressionLearnerManager:
|
||||
except Exception as e:
|
||||
logger.error(f"写入done.done标记文件失败: {e}")
|
||||
|
||||
def _migrate_old_data_create_date(self):
|
||||
@staticmethod
|
||||
async def _migrate_old_data_create_date():
|
||||
"""
|
||||
为没有create_date的老数据设置创建日期
|
||||
使用last_active_time作为create_date的默认值
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
# 查找所有create_date为空的表达方式
|
||||
old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
|
||||
old_expressions_result = await session.execute(select(Expression).where(Expression.create_date.is_(None)))
|
||||
old_expressions = old_expressions_result.scalars().all()
|
||||
updated_count = 0
|
||||
|
||||
for expr in old_expressions:
|
||||
@@ -646,7 +655,6 @@ class ExpressionLearnerManager:
|
||||
|
||||
if updated_count > 0:
|
||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"迁移老数据创建日期失败: {e}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user