feat:更好的配置文件更新,表达方式迁移到数据库
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -40,6 +40,7 @@ config/bot_config.toml
|
||||
config/bot_config.toml.bak
|
||||
config/lpmm_config.toml
|
||||
config/lpmm_config.toml.bak
|
||||
template/compare/bot_config_template.toml
|
||||
(测试版)麦麦生成人格.bat
|
||||
(临时版)麦麦开始学习.bat
|
||||
src/plugins/utils/statistic.py
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
@@ -11,6 +12,7 @@ from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
@@ -75,9 +77,69 @@ class ExpressionLearner:
|
||||
request_type="expressor.learner",
|
||||
)
|
||||
self.llm_model = None
|
||||
self._auto_migrate_json_to_db()
|
||||
|
||||
def _auto_migrate_json_to_db(self):
|
||||
"""
|
||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
||||
"""
|
||||
done_flag = os.path.join("data", "expression", "done.done")
|
||||
if os.path.exists(done_flag):
|
||||
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
||||
return
|
||||
base_dir = os.path.join("data", "expression")
|
||||
for type in ["learnt_style", "learnt_grammar"]:
|
||||
type_str = "style" if type == "learnt_style" else "grammar"
|
||||
type_dir = os.path.join(base_dir, type)
|
||||
if not os.path.exists(type_dir):
|
||||
continue
|
||||
for chat_id in os.listdir(type_dir):
|
||||
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(expr_file):
|
||||
continue
|
||||
try:
|
||||
with open(expr_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
situation = expr.get("situation")
|
||||
style_val = expr.get("style")
|
||||
count = expr.get("count", 1)
|
||||
last_active_time = expr.get("last_active_time", time.time())
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == type_str) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == style_val)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str
|
||||
)
|
||||
logger.info(f"已迁移 {expr_file} 到数据库")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
# 标记迁移完成
|
||||
try:
|
||||
with open(done_flag, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info("表达方式JSON迁移已完成,已写入done.done标记文件")
|
||||
except Exception as e:
|
||||
logger.error(f"写入done.done标记文件失败: {e}")
|
||||
|
||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
# sourcery skip: extract-duplicate-method, remove-unnecessary-cast
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
@@ -85,32 +147,27 @@ class ExpressionLearner:
|
||||
learnt_style_expressions = []
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
# 获取style表达方式
|
||||
style_dir = os.path.join("data", "expression", "learnt_style", str(chat_id))
|
||||
style_file = os.path.join(style_dir, "expressions.json")
|
||||
if os.path.exists(style_file):
|
||||
try:
|
||||
with open(style_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
expr["source_id"] = chat_id # 添加来源ID
|
||||
learnt_style_expressions.append(expr)
|
||||
except Exception as e:
|
||||
logger.error(f"读取style表达方式失败: {e}")
|
||||
|
||||
# 获取grammar表达方式
|
||||
grammar_dir = os.path.join("data", "expression", "learnt_grammar", str(chat_id))
|
||||
grammar_file = os.path.join(grammar_dir, "expressions.json")
|
||||
if os.path.exists(grammar_file):
|
||||
try:
|
||||
with open(grammar_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
for expr in expressions:
|
||||
expr["source_id"] = chat_id # 添加来源ID
|
||||
learnt_grammar_expressions.append(expr)
|
||||
except Exception as e:
|
||||
logger.error(f"读取grammar表达方式失败: {e}")
|
||||
|
||||
# 直接从数据库查询
|
||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
for expr in style_query:
|
||||
learnt_style_expressions.append({
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style"
|
||||
})
|
||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||
for expr in grammar_query:
|
||||
learnt_grammar_expressions.append({
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar"
|
||||
})
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
def is_similar(self, s1: str, s2: str) -> bool:
|
||||
@@ -237,7 +294,6 @@ class ExpressionLearner:
|
||||
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if chat_stream is None:
|
||||
# 如果聊天流不在内存中,使用chat_id作为默认名称
|
||||
group_name = f"聊天流 {chat_id}"
|
||||
elif chat_stream.group_info:
|
||||
group_name = chat_stream.group_info.group_name
|
||||
@@ -261,80 +317,40 @@ class ExpressionLearner:
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到/data/expression/对应chat_id/expressions.json
|
||||
# 存储到数据库 Expression 表
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
dir_path = os.path.join("data", "expression", f"learnt_{type}", str(chat_id))
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
file_path = os.path.join(dir_path, "expressions.json")
|
||||
|
||||
# 若已存在,先读出合并
|
||||
old_data: List[Dict[str, Any]] = []
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
old_data = json.load(f)
|
||||
except Exception:
|
||||
old_data = []
|
||||
|
||||
# 应用衰减
|
||||
# old_data = self.apply_decay_to_expressions(old_data, current_time)
|
||||
|
||||
# 合并逻辑
|
||||
for new_expr in expr_list:
|
||||
found = False
|
||||
for old_expr in old_data:
|
||||
if self.is_similar(new_expr["situation"], old_expr.get("situation", "")) and self.is_similar(
|
||||
new_expr["style"], old_expr.get("style", "")
|
||||
):
|
||||
found = True
|
||||
# 50%概率替换
|
||||
if random.random() < 0.5:
|
||||
old_expr["situation"] = new_expr["situation"]
|
||||
old_expr["style"] = new_expr["style"]
|
||||
old_expr["count"] = old_expr.get("count", 1) + 1
|
||||
old_expr["last_active_time"] = current_time
|
||||
break
|
||||
if not found:
|
||||
new_expr["count"] = 1
|
||||
new_expr["last_active_time"] = current_time
|
||||
old_data.append(new_expr)
|
||||
|
||||
# 处理超限问题
|
||||
if len(old_data) > MAX_EXPRESSION_COUNT:
|
||||
# 计算每个表达方式的权重(count的倒数,这样count越小的越容易被选中)
|
||||
weights = [1 / (expr.get("count", 1) + 0.1) for expr in old_data]
|
||||
|
||||
# 随机选择要移除的表达方式,避免重复索引
|
||||
remove_count = len(old_data) - MAX_EXPRESSION_COUNT
|
||||
|
||||
# 使用一种不会选到重复索引的方法
|
||||
indices = list(range(len(old_data)))
|
||||
|
||||
# 方法1:使用numpy.random.choice
|
||||
# 把列表转成一个映射字典,保证不会有重复
|
||||
remove_set = set()
|
||||
total_attempts = 0
|
||||
|
||||
# 尝试按权重随机选择,直到选够数量
|
||||
while len(remove_set) < remove_count and total_attempts < len(old_data) * 2:
|
||||
idx = random.choices(indices, weights=weights, k=1)[0]
|
||||
remove_set.add(idx)
|
||||
total_attempts += 1
|
||||
|
||||
# 如果没选够,随机补充
|
||||
if len(remove_set) < remove_count:
|
||||
remaining = set(indices) - remove_set
|
||||
remove_set.update(random.sample(list(remaining), remove_count - len(remove_set)))
|
||||
|
||||
remove_indices = list(remove_set)
|
||||
|
||||
# 从后往前删除,避免索引变化
|
||||
for idx in sorted(remove_indices, reverse=True):
|
||||
old_data.pop(idx)
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(old_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == type) &
|
||||
(Expression.situation == new_expr["situation"]) &
|
||||
(Expression.style == new_expr["style"])
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type=type
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc()))
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from .expression_learner import get_expression_learner
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -84,88 +85,77 @@ class ExpressionSelector:
|
||||
def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
(
|
||||
learnt_style_expressions,
|
||||
learnt_grammar_expressions,
|
||||
) = self.expression_learner.get_expression_by_chat_id(chat_id)
|
||||
|
||||
# 直接数据库查询
|
||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||
style_exprs = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style"
|
||||
} for expr in style_query
|
||||
]
|
||||
grammar_exprs = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar"
|
||||
} for expr in grammar_query
|
||||
]
|
||||
style_num = int(total_num * style_percentage)
|
||||
grammar_num = int(total_num * grammar_percentage)
|
||||
|
||||
# 按权重抽样(使用count作为权重)
|
||||
if learnt_style_expressions:
|
||||
style_weights = [expr.get("count", 1) for expr in learnt_style_expressions]
|
||||
selected_style = weighted_sample(learnt_style_expressions, style_weights, style_num)
|
||||
if style_exprs:
|
||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
||||
selected_style = weighted_sample(style_exprs, style_weights, style_num)
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
if learnt_grammar_expressions:
|
||||
grammar_weights = [expr.get("count", 1) for expr in learnt_grammar_expressions]
|
||||
selected_grammar = weighted_sample(learnt_grammar_expressions, grammar_weights, grammar_num)
|
||||
if grammar_exprs:
|
||||
grammar_weights = [expr.get("count", 1) for expr in grammar_exprs]
|
||||
selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
|
||||
else:
|
||||
selected_grammar = []
|
||||
|
||||
return selected_style, selected_grammar
|
||||
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按文件分组后一次性写入"""
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
|
||||
updates_by_file = {}
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id = expr.get("source_id")
|
||||
if not source_id:
|
||||
logger.warning(f"表达方式缺少source_id,无法更新: {expr}")
|
||||
expr_type = expr.get("type", "style")
|
||||
situation = expr.get("situation")
|
||||
style = expr.get("style")
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
|
||||
file_path = ""
|
||||
if source_id == "personality":
|
||||
file_path = os.path.join("data", "expression", "personality", "expressions.json")
|
||||
else:
|
||||
chat_id = source_id
|
||||
expr_type = expr.get("type", "style")
|
||||
if expr_type == "style":
|
||||
file_path = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
|
||||
elif expr_type == "grammar":
|
||||
file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
|
||||
|
||||
if file_path:
|
||||
if file_path not in updates_by_file:
|
||||
updates_by_file[file_path] = []
|
||||
updates_by_file[file_path].append(expr)
|
||||
|
||||
for file_path, updates in updates_by_file.items():
|
||||
if not os.path.exists(file_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
all_expressions = json.load(f)
|
||||
|
||||
# Create a dictionary for quick lookup
|
||||
expr_map = {(e.get("situation"), e.get("style")): e for e in all_expressions}
|
||||
|
||||
# Update counts in memory
|
||||
for expr_to_update in updates:
|
||||
key = (expr_to_update.get("situation"), expr_to_update.get("style"))
|
||||
if key in expr_map:
|
||||
expr_in_map = expr_map[key]
|
||||
current_count = expr_in_map.get("count", 1)
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_in_map["count"] = new_count
|
||||
expr_in_map["last_active_time"] = time.time()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in {file_path}"
|
||||
)
|
||||
|
||||
# Save the updated list once for this file
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(all_expressions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新表达方式count失败 for {file_path}: {e}")
|
||||
key = (source_id, expr_type, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for (chat_id, expr_type, situation, style), expr in updates_by_key.items():
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) &
|
||||
(Expression.type == expr_type) &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self,
|
||||
|
||||
@@ -2,7 +2,7 @@ import re
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
from src.common.database.database_model import Messages, RecalledMessages, Images
|
||||
from src.common.database.database_model import Messages, Images
|
||||
from src.common.logger import get_logger
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
@@ -104,29 +104,6 @@ class MessageStorage:
|
||||
logger.exception("存储消息失败")
|
||||
traceback.print_exc()
|
||||
|
||||
@staticmethod
|
||||
async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
|
||||
"""存储撤回消息到数据库"""
|
||||
# Table creation is handled by initialize_database in database_model.py
|
||||
try:
|
||||
RecalledMessages.create(
|
||||
message_id=message_id,
|
||||
time=float(time), # Assuming time is a string representing a float timestamp
|
||||
stream_id=chat_stream.stream_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("存储撤回消息失败")
|
||||
|
||||
@staticmethod
|
||||
async def remove_recalled_message(time: str) -> None:
|
||||
"""删除撤回消息"""
|
||||
try:
|
||||
# Assuming input 'time' is a string timestamp that can be converted to float
|
||||
current_time_float = float(time)
|
||||
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() # type: ignore
|
||||
except Exception:
|
||||
logger.exception("删除撤回消息失败")
|
||||
|
||||
# 如果需要其他存储相关的函数,可以在这里添加
|
||||
@staticmethod
|
||||
async def update_message(
|
||||
|
||||
@@ -367,6 +367,8 @@ class DefaultReplyer:
|
||||
if not global_config.memory.enable_memory:
|
||||
return ""
|
||||
|
||||
instant_memory = None
|
||||
|
||||
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||
target_message=target, chat_history_prompt=chat_history
|
||||
)
|
||||
@@ -384,7 +386,9 @@ class DefaultReplyer:
|
||||
for running_memory in running_memories:
|
||||
memory_str += f"- {running_memory['content']}\n"
|
||||
|
||||
memory_str += f"- {instant_memory}\n"
|
||||
if instant_memory:
|
||||
memory_str += f"- {instant_memory}\n"
|
||||
|
||||
return memory_str
|
||||
|
||||
async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True):
|
||||
|
||||
@@ -291,6 +291,20 @@ class Knowledges(BaseModel):
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "knowledges"
|
||||
|
||||
class Expression(BaseModel):
|
||||
"""
|
||||
用于存储表达风格的模型。
|
||||
"""
|
||||
|
||||
situation = TextField()
|
||||
style = TextField()
|
||||
count = FloatField()
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
type = TextField()
|
||||
|
||||
class Meta:
|
||||
table_name = "expression"
|
||||
|
||||
class ThinkingLog(BaseModel):
|
||||
chat_id = TextField(index=True)
|
||||
@@ -316,19 +330,6 @@ class ThinkingLog(BaseModel):
|
||||
table_name = "thinking_logs"
|
||||
|
||||
|
||||
class RecalledMessages(BaseModel):
|
||||
"""
|
||||
用于存储撤回消息记录的模型。
|
||||
"""
|
||||
|
||||
message_id = TextField(index=True) # 被撤回的消息 ID
|
||||
time = DoubleField() # 撤回操作发生的时间戳
|
||||
stream_id = TextField() # 对应的 ChatStreams stream_id
|
||||
|
||||
class Meta:
|
||||
table_name = "recalled_messages"
|
||||
|
||||
|
||||
class GraphNodes(BaseModel):
|
||||
"""
|
||||
用于存储记忆图节点的模型
|
||||
@@ -376,8 +377,8 @@ def create_tables():
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Knowledges,
|
||||
Expression,
|
||||
ThinkingLog,
|
||||
RecalledMessages, # 添加新模型
|
||||
GraphNodes, # 添加图节点表
|
||||
GraphEdges, # 添加图边表
|
||||
Memory,
|
||||
@@ -402,9 +403,9 @@ def initialize_database():
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Knowledges,
|
||||
Expression,
|
||||
Memory,
|
||||
ThinkingLog,
|
||||
RecalledMessages,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
ActionRecords, # 添加 ActionRecords 到初始化列表
|
||||
|
||||
@@ -1,10 +1,54 @@
|
||||
import shutil
|
||||
import tomlkit
|
||||
from tomlkit.items import Table
|
||||
from tomlkit.items import Table, KeyType
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
# 获取key的注释(如果有)
|
||||
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
|
||||
return toml_table.trivia.comment
|
||||
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
|
||||
item = toml_table.value.get(key)
|
||||
if item is not None and hasattr(item, 'trivia'):
|
||||
return item.trivia.comment
|
||||
if hasattr(toml_table, 'keys'):
|
||||
for k in toml_table.keys():
|
||||
if isinstance(k, KeyType) and k.key == key:
|
||||
return k.trivia.comment
|
||||
return None
|
||||
|
||||
|
||||
def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, logs=None):
|
||||
# 递归比较两个dict,找出新增和删减项,收集注释
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
if new_comments is None:
|
||||
new_comments = {}
|
||||
if old_comments is None:
|
||||
old_comments = {}
|
||||
# 新增项
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in old:
|
||||
comment = get_key_comment(new, key)
|
||||
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||
compare_dicts(new[key], old[key], path+[str(key)], new_comments, old_comments, logs)
|
||||
# 删减项
|
||||
for key in old:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in new:
|
||||
comment = get_key_comment(old, key)
|
||||
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
||||
return logs
|
||||
|
||||
|
||||
def update_config():
|
||||
print("开始更新配置文件...")
|
||||
# 获取根目录路径
|
||||
@@ -56,6 +100,16 @@ def update_config():
|
||||
else:
|
||||
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
|
||||
# 输出新增和删减项及注释
|
||||
if old_config:
|
||||
print("配置项变动如下:")
|
||||
logs = compare_dicts(new_config, old_config)
|
||||
if logs:
|
||||
for log in logs:
|
||||
print(log)
|
||||
else:
|
||||
print("无新增或删减项")
|
||||
|
||||
# 递归更新配置
|
||||
def update_dict(target, source):
|
||||
for key, value in source.items():
|
||||
|
||||
@@ -4,7 +4,7 @@ import shutil
|
||||
|
||||
from datetime import datetime
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from tomlkit.items import Table, KeyType
|
||||
from dataclasses import field, dataclass
|
||||
from rich.traceback import install
|
||||
|
||||
@@ -51,14 +51,158 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
MMC_VERSION = "0.9.0-snapshot.2"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
# 获取key的注释(如果有)
|
||||
if hasattr(toml_table, 'trivia') and hasattr(toml_table.trivia, 'comment'):
|
||||
return toml_table.trivia.comment
|
||||
if hasattr(toml_table, 'value') and isinstance(toml_table.value, dict):
|
||||
item = toml_table.value.get(key)
|
||||
if item is not None and hasattr(item, 'trivia'):
|
||||
return item.trivia.comment
|
||||
if hasattr(toml_table, 'keys'):
|
||||
for k in toml_table.keys():
|
||||
if isinstance(k, KeyType) and k.key == key:
|
||||
return k.trivia.comment
|
||||
return None
|
||||
|
||||
|
||||
def compare_dicts(new, old, path=None, logs=None):
|
||||
# 递归比较两个dict,找出新增和删减项,收集注释
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
# 新增项
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in old:
|
||||
comment = get_key_comment(new, key)
|
||||
logs.append(f"新增: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
||||
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
|
||||
compare_dicts(new[key], old[key], path+[str(key)], logs)
|
||||
# 删减项
|
||||
for key in old:
|
||||
if key == "version":
|
||||
continue
|
||||
if key not in new:
|
||||
comment = get_key_comment(old, key)
|
||||
logs.append(f"删减: {'.'.join(path+[str(key)])} 注释: {comment if comment else '无'}")
|
||||
return logs
|
||||
|
||||
|
||||
def get_value_by_path(d, path):
|
||||
for k in path:
|
||||
if isinstance(d, dict) and k in d:
|
||||
d = d[k]
|
||||
else:
|
||||
return None
|
||||
return d
|
||||
|
||||
def set_value_by_path(d, path, value):
|
||||
for k in path[:-1]:
|
||||
if k not in d or not isinstance(d[k], dict):
|
||||
d[k] = {}
|
||||
d = d[k]
|
||||
d[path[-1]] = value
|
||||
|
||||
def compare_default_values(new, old, path=None, logs=None, changes=None):
|
||||
# 递归比较两个dict,找出默认值变化项
|
||||
if path is None:
|
||||
path = []
|
||||
if logs is None:
|
||||
logs = []
|
||||
if changes is None:
|
||||
changes = []
|
||||
for key in new:
|
||||
if key == "version":
|
||||
continue
|
||||
if key in old:
|
||||
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
|
||||
compare_default_values(new[key], old[key], path+[str(key)], logs, changes)
|
||||
else:
|
||||
# 只要值发生变化就记录
|
||||
if new[key] != old[key]:
|
||||
logs.append(f"默认值变化: {'.'.join(path+[str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
|
||||
changes.append((path+[str(key)], old[key], new[key]))
|
||||
return logs, changes
|
||||
|
||||
|
||||
def update_config():
|
||||
# 获取根目录路径
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
|
||||
|
||||
# 定义文件路径
|
||||
template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml")
|
||||
old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
compare_path = os.path.join(compare_dir, "bot_config_template.toml")
|
||||
|
||||
# 创建compare目录(如果不存在)
|
||||
os.makedirs(compare_dir, exist_ok=True)
|
||||
|
||||
# 处理compare下的模板文件
|
||||
def get_version_from_toml(toml_path):
|
||||
if not os.path.exists(toml_path):
|
||||
return None
|
||||
with open(toml_path, "r", encoding="utf-8") as f:
|
||||
doc = tomlkit.load(f)
|
||||
if "inner" in doc and "version" in doc["inner"]:
|
||||
return doc["inner"]["version"]
|
||||
return None
|
||||
|
||||
template_version = get_version_from_toml(template_path)
|
||||
compare_version = get_version_from_toml(compare_path)
|
||||
|
||||
def version_tuple(v):
|
||||
if v is None:
|
||||
return (0,)
|
||||
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
|
||||
|
||||
# 先读取 compare 下的模板(如果有),用于默认值变动检测
|
||||
if os.path.exists(compare_path):
|
||||
with open(compare_path, "r", encoding="utf-8") as f:
|
||||
compare_config = tomlkit.load(f)
|
||||
else:
|
||||
compare_config = None
|
||||
|
||||
# 读取当前模板
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查默认值变化并处理(只有 compare_config 存在时才做)
|
||||
if compare_config is not None:
|
||||
# 读取旧配置
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
logs, changes = compare_default_values(new_config, compare_config)
|
||||
if logs:
|
||||
logger.info("检测到模板默认值变动如下:")
|
||||
for log in logs:
|
||||
logger.info(log)
|
||||
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
|
||||
for path, old_default, new_default in changes:
|
||||
old_value = get_value_by_path(old_config, path)
|
||||
if old_value == old_default:
|
||||
set_value_by_path(old_config, path, new_default)
|
||||
logger.info(f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}")
|
||||
else:
|
||||
logger.info("未检测到模板默认值变动")
|
||||
# 保存旧配置的变更(后续合并逻辑会用到 old_config)
|
||||
else:
|
||||
old_config = None
|
||||
|
||||
# 检查 compare 下没有模板,或新模板版本更高,则复制
|
||||
if not os.path.exists(compare_path):
|
||||
shutil.copy2(template_path, compare_path)
|
||||
logger.info(f"已将模板文件复制到: {compare_path}")
|
||||
else:
|
||||
if version_tuple(template_version) > version_tuple(compare_version):
|
||||
shutil.copy2(template_path, compare_path)
|
||||
logger.info(f"模板版本较新,已替换compare下的模板: {compare_path}")
|
||||
else:
|
||||
logger.debug(f"compare下的模板版本不低于当前模板,无需替换: {compare_path}")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(old_config_path):
|
||||
@@ -69,11 +213,13 @@ def update_config():
|
||||
# 如果是新创建的配置文件,直接返回
|
||||
quit()
|
||||
|
||||
# 读取旧配置文件和模板文件
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
# 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
|
||||
if old_config is None:
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
# new_config 已经读取
|
||||
|
||||
# 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
@@ -83,7 +229,7 @@ def update_config():
|
||||
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
logger.info(f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------")
|
||||
else:
|
||||
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
@@ -100,6 +246,16 @@ def update_config():
|
||||
shutil.copy2(template_path, new_config_path)
|
||||
logger.info(f"已创建新配置文件: {new_config_path}")
|
||||
|
||||
# 输出新增和删减项及注释
|
||||
if old_config:
|
||||
logger.info("配置项变动如下:\n----------------------------------------")
|
||||
logs = compare_dicts(new_config, old_config)
|
||||
if logs:
|
||||
for log in logs:
|
||||
logger.info(log)
|
||||
else:
|
||||
logger.info("无新增或删减项")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||
|
||||
18
src/main.py
18
src/main.py
@@ -131,7 +131,6 @@ class MainSystem:
|
||||
while True:
|
||||
tasks = [
|
||||
get_emoji_manager().start_periodic_check_register(),
|
||||
self.remove_recalled_message_task(),
|
||||
self.app.run(),
|
||||
self.server.run(),
|
||||
]
|
||||
@@ -184,23 +183,6 @@ class MainSystem:
|
||||
await expression_learner.learn_and_store_expression()
|
||||
logger.info("[表达方式学习] 表达方式学习完成")
|
||||
|
||||
# async def print_mood_task(self):
|
||||
# """打印情绪状态"""
|
||||
# while True:
|
||||
# self.mood_manager.print_mood_status()
|
||||
# await asyncio.sleep(60)
|
||||
|
||||
@staticmethod
|
||||
async def remove_recalled_message_task():
|
||||
"""删除撤回消息任务"""
|
||||
while True:
|
||||
try:
|
||||
storage = MessageStorage()
|
||||
await storage.remove_recalled_message(time.time())
|
||||
except Exception:
|
||||
logger.exception("删除撤回消息失败")
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
|
||||
@@ -109,9 +109,12 @@ class NoReplyAction(BaseAction):
|
||||
interest_value = msg_dict.get("interest_value", 0.0)
|
||||
if text:
|
||||
accumulated_interest += interest_value
|
||||
|
||||
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.chat_id)
|
||||
logger.info(f"{self.log_prefix} 当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
|
||||
# 只在兴趣值变化时输出log
|
||||
if not hasattr(self, "_last_accumulated_interest") or accumulated_interest != self._last_accumulated_interest:
|
||||
logger.info(f"{self.log_prefix} 当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
|
||||
self._last_accumulated_interest = accumulated_interest
|
||||
|
||||
if accumulated_interest >= self._interest_exit_threshold / talk_frequency:
|
||||
logger.info(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[inner]
|
||||
version = "4.3.0"
|
||||
version = "4.4.3"
|
||||
|
||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||
#如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
@@ -33,9 +33,9 @@ compress_identity = true # 是否压缩身份,压缩后会精简身份信息
|
||||
# 表达方式
|
||||
enable_expression = true # 是否启用表达方式
|
||||
# 描述麦麦说话的表达风格,表达习惯,例如:(请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景。)
|
||||
expression_style = "请回复的平淡一些,简短一些,说中文,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,不要刻意突出自身学科背景。"
|
||||
expression_style = "请回复的平淡些,简短一些,说中文,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,不要刻意突出自身学科背景。"
|
||||
enable_expression_learning = false # 是否启用表达学习,麦麦会学习不同群里人类说话风格(群之间不互通)
|
||||
learning_interval = 600 # 学习间隔 单位秒
|
||||
learning_interval = 350 # 学习间隔 单位秒
|
||||
|
||||
expression_groups = [
|
||||
["qq:1919810:private","qq:114514:private","qq:1111111:group"], # 在这里设置互通组,相同组的chat_id会共享学习到的表达方式
|
||||
@@ -124,21 +124,21 @@ filtration_prompt = "符合公序良俗" # 表情包过滤要求,只有符合
|
||||
|
||||
[memory]
|
||||
enable_memory = true # 是否启用记忆系统
|
||||
memory_build_interval = 1000 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
|
||||
memory_build_interval = 600 # 记忆构建间隔 单位秒 间隔越低,麦麦学习越多,但是冗余信息也会增多
|
||||
memory_build_distribution = [6.0, 3.0, 0.6, 32.0, 12.0, 0.4] # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重
|
||||
memory_build_sample_num = 4 # 采样数量,数值越高记忆采样次数越多
|
||||
memory_build_sample_num = 8 # 采样数量,数值越高记忆采样次数越多
|
||||
memory_build_sample_length = 30 # 采样长度,数值越高一段记忆内容越丰富
|
||||
memory_compress_rate = 0.1 # 记忆压缩率 控制记忆精简程度 建议保持默认,调高可以获得更多信息,但是冗余信息也会增多
|
||||
|
||||
forget_memory_interval = 1500 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习
|
||||
memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时
|
||||
memory_forget_percentage = 0.01 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认
|
||||
forget_memory_interval = 3000 # 记忆遗忘间隔 单位秒 间隔越低,麦麦遗忘越频繁,记忆更精简,但更难学习
|
||||
memory_forget_time = 48 #多长时间后的记忆会被遗忘 单位小时
|
||||
memory_forget_percentage = 0.008 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认
|
||||
|
||||
consolidate_memory_interval = 1000 # 记忆整合间隔 单位秒 间隔越低,麦麦整合越频繁,记忆更精简
|
||||
consolidation_similarity_threshold = 0.7 # 相似度阈值
|
||||
consolidation_check_percentage = 0.05 # 检查节点比例
|
||||
|
||||
enable_instant_memory = true # 是否启用即时记忆
|
||||
enable_instant_memory = false # 是否启用即时记忆,测试功能,可能存在未知问题
|
||||
|
||||
#不希望记忆的词,已经记忆的不会受到影响,需要手动清理
|
||||
memory_ban_words = [ "表情包", "图片", "回复", "聊天记录" ]
|
||||
|
||||
Reference in New Issue
Block a user