feat:更好的配置文件更新,表达方式迁移到数据库
This commit is contained in:
@@ -2,6 +2,7 @@ import time
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
@@ -11,6 +12,7 @@ from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
@@ -75,9 +77,69 @@ class ExpressionLearner:
|
||||
request_type="expressor.learner",
|
||||
)
|
||||
self.llm_model = None
|
||||
self._auto_migrate_json_to_db()
|
||||
|
||||
def _auto_migrate_json_to_db(self):
|
||||
"""
|
||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
||||
"""
|
||||
done_flag = os.path.join("data", "expression", "done.done")
|
||||
if os.path.exists(done_flag):
|
||||
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
||||
return
|
||||
base_dir = os.path.join("data", "expression")
|
||||
for type in ["learnt_style", "learnt_grammar"]:
|
||||
type_str = "style" if type == "learnt_style" else "grammar"
|
||||
type_dir = os.path.join(base_dir, type)
|
||||
if not os.path.exists(type_dir):
|
||||
continue
|
||||
for chat_id in os.listdir(type_dir):
|
||||
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(expr_file):
|
||||
continue
|
||||
try:
|
||||
with open(expr_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
situation = expr.get("situation")
|
||||
style_val = expr.get("style")
|
||||
count = expr.get("count", 1)
|
||||
last_active_time = expr.get("last_active_time", time.time())
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == type_str) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == style_val)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str
|
||||
)
|
||||
logger.info(f"已迁移 {expr_file} 到数据库")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
# 标记迁移完成
|
||||
try:
|
||||
with open(done_flag, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info("表达方式JSON迁移已完成,已写入done.done标记文件")
|
||||
except Exception as e:
|
||||
logger.error(f"写入done.done标记文件失败: {e}")
|
||||
|
||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
# sourcery skip: extract-duplicate-method, remove-unnecessary-cast
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
@@ -85,32 +147,27 @@ class ExpressionLearner:
|
||||
learnt_style_expressions = []
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
# 获取style表达方式
|
||||
style_dir = os.path.join("data", "expression", "learnt_style", str(chat_id))
|
||||
style_file = os.path.join(style_dir, "expressions.json")
|
||||
if os.path.exists(style_file):
|
||||
try:
|
||||
with open(style_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
expr["source_id"] = chat_id # 添加来源ID
|
||||
learnt_style_expressions.append(expr)
|
||||
except Exception as e:
|
||||
logger.error(f"读取style表达方式失败: {e}")
|
||||
|
||||
# 获取grammar表达方式
|
||||
grammar_dir = os.path.join("data", "expression", "learnt_grammar", str(chat_id))
|
||||
grammar_file = os.path.join(grammar_dir, "expressions.json")
|
||||
if os.path.exists(grammar_file):
|
||||
try:
|
||||
with open(grammar_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
expr["source_id"] = chat_id # 添加来源ID
|
||||
learnt_grammar_expressions.append(expr)
|
||||
except Exception as e:
|
||||
logger.error(f"读取grammar表达方式失败: {e}")
|
||||
|
||||
# 直接从数据库查询
|
||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
for expr in style_query:
|
||||
learnt_style_expressions.append({
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style"
|
||||
})
|
||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||
for expr in grammar_query:
|
||||
learnt_grammar_expressions.append({
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar"
|
||||
})
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
def is_similar(self, s1: str, s2: str) -> bool:
|
||||
@@ -237,7 +294,6 @@ class ExpressionLearner:
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if chat_stream is None:
|
||||
# 如果聊天流不在内存中,使用chat_id作为默认名称
|
||||
group_name = f"聊天流 {chat_id}"
|
||||
elif chat_stream.group_info:
|
||||
group_name = chat_stream.group_info.group_name
|
||||
@@ -261,80 +317,40 @@ class ExpressionLearner:
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到/data/expression/对应chat_id/expressions.json
|
||||
# 存储到数据库 Expression 表
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
dir_path = os.path.join("data", "expression", f"learnt_{type}", str(chat_id))
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
file_path = os.path.join(dir_path, "expressions.json")
|
||||
|
||||
# 若已存在,先读出合并
|
||||
old_data: List[Dict[str, Any]] = []
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
old_data = json.load(f)
|
||||
except Exception:
|
||||
old_data = []
|
||||
|
||||
# 应用衰减
|
||||
# old_data = self.apply_decay_to_expressions(old_data, current_time)
|
||||
|
||||
# 合并逻辑
|
||||
for new_expr in expr_list:
|
||||
found = False
|
||||
for old_expr in old_data:
|
||||
if self.is_similar(new_expr["situation"], old_expr.get("situation", "")) and self.is_similar(
|
||||
new_expr["style"], old_expr.get("style", "")
|
||||
):
|
||||
found = True
|
||||
# 50%概率替换
|
||||
if random.random() < 0.5:
|
||||
old_expr["situation"] = new_expr["situation"]
|
||||
old_expr["style"] = new_expr["style"]
|
||||
old_expr["count"] = old_expr.get("count", 1) + 1
|
||||
old_expr["last_active_time"] = current_time
|
||||
break
|
||||
if not found:
|
||||
new_expr["count"] = 1
|
||||
new_expr["last_active_time"] = current_time
|
||||
old_data.append(new_expr)
|
||||
|
||||
# 处理超限问题
|
||||
if len(old_data) > MAX_EXPRESSION_COUNT:
|
||||
# 计算每个表达方式的权重(count的倒数,这样count越小的越容易被选中)
|
||||
weights = [1 / (expr.get("count", 1) + 0.1) for expr in old_data]
|
||||
|
||||
# 随机选择要移除的表达方式,避免重复索引
|
||||
remove_count = len(old_data) - MAX_EXPRESSION_COUNT
|
||||
|
||||
# 使用一种不会选到重复索引的方法
|
||||
indices = list(range(len(old_data)))
|
||||
|
||||
# 方法1:使用numpy.random.choice
|
||||
# 把列表转成一个映射字典,保证不会有重复
|
||||
remove_set = set()
|
||||
total_attempts = 0
|
||||
|
||||
# 尝试按权重随机选择,直到选够数量
|
||||
while len(remove_set) < remove_count and total_attempts < len(old_data) * 2:
|
||||
idx = random.choices(indices, weights=weights, k=1)[0]
|
||||
remove_set.add(idx)
|
||||
total_attempts += 1
|
||||
|
||||
# 如果没选够,随机补充
|
||||
if len(remove_set) < remove_count:
|
||||
remaining = set(indices) - remove_set
|
||||
remove_set.update(random.sample(list(remaining), remove_count - len(remove_set)))
|
||||
|
||||
remove_indices = list(remove_set)
|
||||
|
||||
# 从后往前删除,避免索引变化
|
||||
for idx in sorted(remove_indices, reverse=True):
|
||||
old_data.pop(idx)
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(old_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == type) &
|
||||
(Expression.situation == new_expr["situation"]) &
|
||||
(Expression.style == new_expr["style"])
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type=type
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc()))
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
|
||||
Reference in New Issue
Block a user