修复代码格式和文件名大小写问题
This commit is contained in:
@@ -23,6 +23,7 @@ DECAY_MIN = 0.01 # 最小衰减值
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
@@ -87,24 +88,20 @@ class ExpressionLearner:
|
||||
self.chat_id = chat_id
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
|
||||
# 维护每个chat的上次学习时间
|
||||
self.last_learning_time: float = time.time()
|
||||
|
||||
|
||||
# 学习参数
|
||||
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
|
||||
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
|
||||
|
||||
|
||||
|
||||
|
||||
def can_learn_for_chat(self) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许学习表达
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否允许学习
|
||||
"""
|
||||
@@ -118,10 +115,10 @@ class ExpressionLearner:
|
||||
def should_trigger_learning(self) -> bool:
|
||||
"""
|
||||
检查是否应该触发学习
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否应该触发学习
|
||||
"""
|
||||
@@ -129,67 +126,69 @@ class ExpressionLearner:
|
||||
|
||||
# 获取该聊天流的学习强度
|
||||
try:
|
||||
use_expression, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
use_expression, enable_learning, learning_intensity = (
|
||||
global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 检查是否允许学习
|
||||
if not enable_learning:
|
||||
return False
|
||||
|
||||
|
||||
# 根据学习强度计算最短学习时间间隔
|
||||
min_interval = self.min_learning_interval / learning_intensity
|
||||
|
||||
|
||||
# 检查时间间隔
|
||||
time_diff = current_time - self.last_learning_time
|
||||
if time_diff < min_interval:
|
||||
return False
|
||||
|
||||
|
||||
# 检查消息数量(只检查指定聊天流的消息)
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
|
||||
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
async def trigger_learning_for_chat(self) -> bool:
|
||||
"""
|
||||
为指定聊天流触发学习
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功触发学习
|
||||
"""
|
||||
if not self.should_trigger_learning():
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
|
||||
|
||||
|
||||
# 学习语言风格
|
||||
learnt_style = await self.learn_and_store(type="style", num=25)
|
||||
|
||||
|
||||
# 学习句法特点
|
||||
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
|
||||
|
||||
|
||||
# 更新学习时间
|
||||
self.last_learning_time = time.time()
|
||||
|
||||
|
||||
if learnt_style or learnt_grammar:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
return False
|
||||
@@ -204,7 +203,9 @@ class ExpressionLearner:
|
||||
|
||||
# 直接从数据库查询
|
||||
with get_db_session() as session:
|
||||
style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
|
||||
style_query = session.execute(
|
||||
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
|
||||
)
|
||||
for expr in style_query.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
@@ -219,7 +220,9 @@ class ExpressionLearner:
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
grammar_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")))
|
||||
grammar_query = session.execute(
|
||||
select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar"))
|
||||
)
|
||||
for expr in grammar_query.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
@@ -236,12 +239,6 @@ class ExpressionLearner:
|
||||
)
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
@@ -273,8 +270,6 @@ class ExpressionLearner:
|
||||
expr.count = new_count
|
||||
updated_count += 1
|
||||
|
||||
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
@@ -357,15 +352,17 @@ class ExpressionLearner:
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
with get_db_session() as session:
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type)
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)).scalar()
|
||||
query = session.execute(
|
||||
select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type)
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)
|
||||
).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
# 50%概率替换内容
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
@@ -385,16 +382,18 @@ class ExpressionLearner:
|
||||
session.commit()
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
session.execute(select(Expression)
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())).scalars()
|
||||
session.execute(
|
||||
select(Expression)
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())
|
||||
).scalars()
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
session.delete(expr)
|
||||
session.commit()
|
||||
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
@@ -413,7 +412,7 @@ class ExpressionLearner:
|
||||
raise ValueError(f"Invalid type: {type}")
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 获取上次学习时间
|
||||
random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
@@ -421,7 +420,7 @@ class ExpressionLearner:
|
||||
timestamp_end=current_time,
|
||||
limit=num,
|
||||
)
|
||||
|
||||
|
||||
# print(random_msg)
|
||||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
@@ -483,19 +482,20 @@ class ExpressionLearner:
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
class ExpressionLearnerManager:
|
||||
def __init__(self):
|
||||
self.expression_learners = {}
|
||||
|
||||
|
||||
self._ensure_expression_directories()
|
||||
self._auto_migrate_json_to_db()
|
||||
self._migrate_old_data_create_date()
|
||||
|
||||
|
||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
if chat_id not in self.expression_learners:
|
||||
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||
return self.expression_learners[chat_id]
|
||||
|
||||
|
||||
def _ensure_expression_directories(self):
|
||||
"""
|
||||
确保表达方式相关的目录结构存在
|
||||
@@ -514,7 +514,6 @@ class ExpressionLearnerManager:
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
|
||||
def _auto_migrate_json_to_db(self):
|
||||
"""
|
||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||
@@ -579,12 +578,14 @@ class ExpressionLearnerManager:
|
||||
|
||||
# 查重:同chat_id+type+situation+style
|
||||
with get_db_session() as session:
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)).scalar()
|
||||
query = session.execute(
|
||||
select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)
|
||||
).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
@@ -601,7 +602,7 @@ class ExpressionLearnerManager:
|
||||
)
|
||||
session.add(new_expression)
|
||||
session.commit()
|
||||
|
||||
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except orjson.JSONDecodeError as e:
|
||||
@@ -643,8 +644,6 @@ class ExpressionLearnerManager:
|
||||
expr.create_date = expr.last_active_time
|
||||
updated_count += 1
|
||||
|
||||
|
||||
|
||||
if updated_count > 0:
|
||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||
session.commit()
|
||||
|
||||
Reference in New Issue
Block a user