ruff reformatted

This commit is contained in:
春河晴
2025-04-08 15:31:13 +09:00
parent 0d7068acab
commit 7840a6080d
40 changed files with 1227 additions and 1336 deletions

View File

@@ -42,7 +42,6 @@ class Heartflow:
self._subheartflows = {} self._subheartflows = {}
self.active_subheartflows_nums = 0 self.active_subheartflows_nums = 0
async def _cleanup_inactive_subheartflows(self): async def _cleanup_inactive_subheartflows(self):
"""定期清理不活跃的子心流""" """定期清理不活跃的子心流"""
while True: while True:
@@ -98,11 +97,8 @@ class Heartflow:
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}" prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality personality_info = prompt_personality
current_thinking_info = self.current_mind current_thinking_info = self.current_mind
mood_info = self.current_state.mood mood_info = self.current_state.mood
related_memory_info = "memory" related_memory_info = "memory"
@@ -160,8 +156,6 @@ class Heartflow:
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}" prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality personality_info = prompt_personality
mood_info = self.current_state.mood mood_info = self.current_state.mood

View File

@@ -7,6 +7,7 @@ from src.common.database import db
from src.individuality.individuality import Individuality from src.individuality.individuality import Individuality
import random import random
# 所有观察的基类 # 所有观察的基类
class Observation: class Observation:
def __init__(self, observe_type, observe_id): def __init__(self, observe_type, observe_id):
@@ -131,12 +132,8 @@ class ChattingObservation(Observation):
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}" prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality personality_info = prompt_personality
prompt = "" prompt = ""
prompt += f"{personality_info},请注意识别你自己的聊天发言" prompt += f"{personality_info},请注意识别你自己的聊天发言"
prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n" prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n"
@@ -149,7 +146,6 @@ class ChattingObservation(Observation):
print(f"prompt{prompt}") print(f"prompt{prompt}")
print(f"self.observe_info{self.observe_info}") print(f"self.observe_info{self.observe_info}")
def translate_message_list_to_str(self): def translate_message_list_to_str(self):
self.talking_message_str = "" self.talking_message_str = ""
for message in self.talking_message: for message in self.talking_message:

View File

@@ -53,7 +53,6 @@ class SubHeartflow:
if not self.current_mind: if not self.current_mind:
self.current_mind = "你什么也没想" self.current_mind = "你什么也没想"
self.is_active = False self.is_active = False
self.observations: list[Observation] = [] self.observations: list[Observation] = []
@@ -86,7 +85,9 @@ class SubHeartflow:
async def subheartflow_start_working(self): async def subheartflow_start_working(self):
while True: while True:
current_time = time.time() current_time = time.time()
if current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time: # 120秒无回复/不在场,冻结 if (
current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time
): # 120秒无回复/不在场,冻结
self.is_active = False self.is_active = False
await asyncio.sleep(global_config.sub_heart_flow_update_interval) # 每60秒检查一次 await asyncio.sleep(global_config.sub_heart_flow_update_interval) # 每60秒检查一次
else: else:
@@ -100,7 +101,9 @@ class SubHeartflow:
await asyncio.sleep(global_config.sub_heart_flow_update_interval) await asyncio.sleep(global_config.sub_heart_flow_update_interval)
# 检查是否超过10分钟没有激活 # 检查是否超过10分钟没有激活
if current_time - self.last_active_time > global_config.sub_heart_flow_stop_time: # 5分钟无回复/不在场,销毁 if (
current_time - self.last_active_time > global_config.sub_heart_flow_stop_time
): # 5分钟无回复/不在场,销毁
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...") logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...")
break # 退出循环,销毁自己 break # 退出循环,销毁自己
@@ -176,9 +179,6 @@ class SubHeartflow:
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}" prompt_personality += f",{identity_detail[0]}"
# 调取记忆 # 调取记忆
related_memory = await HippocampusManager.get_instance().get_memory_from_text( related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
@@ -244,8 +244,6 @@ class SubHeartflow:
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}" prompt_personality += f",{identity_detail[0]}"
current_thinking_info = self.current_mind current_thinking_info = self.current_mind
mood_info = self.current_state.mood mood_info = self.current_state.mood
@@ -293,8 +291,6 @@ class SubHeartflow:
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}" prompt_personality += f",{identity_detail[0]}"
# print("麦麦闹情绪了1") # print("麦麦闹情绪了1")
current_thinking_info = self.current_mind current_thinking_info = self.current_mind
mood_info = self.current_state.mood mood_info = self.current_state.mood
@@ -321,7 +317,6 @@ class SubHeartflow:
self.past_mind.append(self.current_mind) self.past_mind.append(self.current_mind)
self.current_mind = reponse self.current_mind = reponse
async def get_prompt_info(self, message: str, threshold: float): async def get_prompt_info(self, message: str, threshold: float):
start_time = time.time() start_time = time.time()
related_info = "" related_info = ""
@@ -442,7 +437,9 @@ class SubHeartflow:
# 6. 限制总数量最多10条 # 6. 限制总数量最多10条
filtered_results = filtered_results[:10] filtered_results = filtered_results[:10]
logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果") logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. 格式化输出 # 7. 格式化输出
if filtered_results: if filtered_results:
@@ -470,7 +467,9 @@ class SubHeartflow:
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}") logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info, grouped_results return related_info, grouped_results
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]: def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算 # 使用余弦相似度计算

View File

@@ -2,9 +2,11 @@ from dataclasses import dataclass
from typing import List from typing import List
import random import random
@dataclass @dataclass
class Identity: class Identity:
"""身份特征类""" """身份特征类"""
identity_detail: List[str] # 身份细节描述 identity_detail: List[str] # 身份细节描述
height: int # 身高(厘米) height: int # 身高(厘米)
weight: int # 体重(千克) weight: int # 体重(千克)
@@ -19,8 +21,15 @@ class Identity:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def __init__(self, identity_detail: List[str] = None, height: int = 0, weight: int = 0, def __init__(
age: int = 0, gender: str = "", appearance: str = ""): self,
identity_detail: List[str] = None,
height: int = 0,
weight: int = 0,
age: int = 0,
gender: str = "",
appearance: str = "",
):
"""初始化身份特征 """初始化身份特征
Args: Args:
@@ -41,7 +50,7 @@ class Identity:
self.appearance = appearance self.appearance = appearance
@classmethod @classmethod
def get_instance(cls) -> 'Identity': def get_instance(cls) -> "Identity":
"""获取Identity单例实例 """获取Identity单例实例
Returns: Returns:
@@ -52,8 +61,9 @@ class Identity:
return cls._instance return cls._instance
@classmethod @classmethod
def initialize(cls, identity_detail: List[str], height: int, weight: int, def initialize(
age: int, gender: str, appearance: str) -> 'Identity': cls, identity_detail: List[str], height: int, weight: int, age: int, gender: str, appearance: str
) -> "Identity":
"""初始化身份特征 """初始化身份特征
Args: Args:
@@ -105,11 +115,11 @@ class Identity:
"weight": self.weight, "weight": self.weight,
"age": self.age, "age": self.age,
"gender": self.gender, "gender": self.gender,
"appearance": self.appearance "appearance": self.appearance,
} }
@classmethod @classmethod
def from_dict(cls, data: dict) -> 'Identity': def from_dict(cls, data: dict) -> "Identity":
"""从字典创建身份特征实例""" """从字典创建身份特征实例"""
instance = cls.get_instance() instance = cls.get_instance()
for key, value in data.items(): for key, value in data.items():

View File

@@ -2,8 +2,10 @@ from typing import Optional
from .personality import Personality from .personality import Personality
from .identity import Identity from .identity import Identity
class Individuality: class Individuality:
"""个体特征管理类""" """个体特征管理类"""
_instance = None _instance = None
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
@@ -16,7 +18,7 @@ class Individuality:
self.identity: Optional[Identity] = None self.identity: Optional[Identity] = None
@classmethod @classmethod
def get_instance(cls) -> 'Individuality': def get_instance(cls) -> "Individuality":
"""获取Individuality单例实例 """获取Individuality单例实例
Returns: Returns:
@@ -26,9 +28,18 @@ class Individuality:
cls._instance = cls() cls._instance = cls()
return cls._instance return cls._instance
def initialize(self, bot_nickname: str, personality_core: str, personality_sides: list, def initialize(
identity_detail: list, height: int, weight: int, age: int, self,
gender: str, appearance: str) -> None: bot_nickname: str,
personality_core: str,
personality_sides: list,
identity_detail: list,
height: int,
weight: int,
age: int,
gender: str,
appearance: str,
) -> None:
"""初始化个体特征 """初始化个体特征
Args: Args:
@@ -44,30 +55,23 @@ class Individuality:
""" """
# 初始化人格 # 初始化人格
self.personality = Personality.initialize( self.personality = Personality.initialize(
bot_nickname=bot_nickname, bot_nickname=bot_nickname, personality_core=personality_core, personality_sides=personality_sides
personality_core=personality_core,
personality_sides=personality_sides
) )
# 初始化身份 # 初始化身份
self.identity = Identity.initialize( self.identity = Identity.initialize(
identity_detail=identity_detail, identity_detail=identity_detail, height=height, weight=weight, age=age, gender=gender, appearance=appearance
height=height,
weight=weight,
age=age,
gender=gender,
appearance=appearance
) )
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""将个体特征转换为字典格式""" """将个体特征转换为字典格式"""
return { return {
"personality": self.personality.to_dict() if self.personality else None, "personality": self.personality.to_dict() if self.personality else None,
"identity": self.identity.to_dict() if self.identity else None "identity": self.identity.to_dict() if self.identity else None,
} }
@classmethod @classmethod
def from_dict(cls, data: dict) -> 'Individuality': def from_dict(cls, data: dict) -> "Individuality":
"""从字典创建个体特征实例""" """从字典创建个体特征实例"""
instance = cls.get_instance() instance = cls.get_instance()
if data.get("personality"): if data.get("personality"):
@@ -101,5 +105,3 @@ class Individuality:
return self.personality.agreeableness return self.personality.agreeableness
elif factor == "neuroticism": elif factor == "neuroticism":
return self.personality.neuroticism return self.personality.neuroticism

View File

@@ -32,11 +32,10 @@ else:
def adapt_scene(scene: str) -> str: def adapt_scene(scene: str) -> str:
personality_core = config["personality"]["personality_core"]
personality_core = config['personality']['personality_core'] personality_sides = config["personality"]["personality_sides"]
personality_sides = config['personality']['personality_sides']
personality_side = random.choice(personality_sides) personality_side = random.choice(personality_sides)
identity_details = config['identity']['identity_detail'] identity_details = config["identity"]["identity_detail"]
identity_detail = random.choice(identity_details) identity_detail = random.choice(identity_details)
""" """
@@ -51,10 +50,10 @@ def adapt_scene(scene: str) -> str:
try: try:
prompt = f""" prompt = f"""
这是一个参与人格测评的角色形象: 这是一个参与人格测评的角色形象:
- 昵称: {config['bot']['nickname']} - 昵称: {config["bot"]["nickname"]}
- 性别: {config['identity']['gender']} - 性别: {config["identity"]["gender"]}
- 年龄: {config['identity']['age']} - 年龄: {config["identity"]["age"]}
- 外貌: {config['identity']['appearance']} - 外貌: {config["identity"]["appearance"]}
- 性格核心: {personality_core} - 性格核心: {personality_core}
- 性格侧面: {personality_side} - 性格侧面: {personality_side}
- 身份细节: {identity_detail} - 身份细节: {identity_detail}
@@ -62,11 +61,11 @@ def adapt_scene(scene: str) -> str:
请根据上述形象,改编以下场景,在测评中,用户将根据该场景给出上述角色形象的反应: 请根据上述形象,改编以下场景,在测评中,用户将根据该场景给出上述角色形象的反应:
{scene} {scene}
保持场景的本质不变,但最好贴近生活且具体,并且让它更适合这个角色。 保持场景的本质不变,但最好贴近生活且具体,并且让它更适合这个角色。
改编后的场景应该自然、连贯,并考虑角色的年龄、身份和性格特点。只返回改编后的场景描述,不要包含其他说明。注意{config['bot']['nickname']}是面对这个场景的人,而不是场景的其他人。场景中不会有其描述, 改编后的场景应该自然、连贯,并考虑角色的年龄、身份和性格特点。只返回改编后的场景描述,不要包含其他说明。注意{config["bot"]["nickname"]}是面对这个场景的人,而不是场景的其他人。场景中不会有其描述,
现在,请你给出改编后的场景描述 现在,请你给出改编后的场景描述
""" """
llm = LLM_request_off(model_name=config['model']['llm_normal']['name']) llm = LLM_request_off(model_name=config["model"]["llm_normal"]["name"])
adapted_scene, _ = llm.generate_response(prompt) adapted_scene, _ = llm.generate_response(prompt)
# 检查返回的场景是否为空或错误信息 # 检查返回的场景是否为空或错误信息
@@ -187,7 +186,12 @@ class PersonalityEvaluator_direct:
input() input()
total_scenarios = len(self.scenarios) total_scenarios = len(self.scenarios)
progress_bar = tqdm(total=total_scenarios, desc="场景进度", ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') progress_bar = tqdm(
total=total_scenarios,
desc="场景进度",
ncols=100,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
)
for _i, scenario_data in enumerate(self.scenarios, 1): for _i, scenario_data in enumerate(self.scenarios, 1):
# print(f"\n{'-' * 20} 场景 {i}/{total_scenarios} - {scenario_data['场景编号']} {'-' * 20}") # print(f"\n{'-' * 20} 场景 {i}/{total_scenarios} - {scenario_data['场景编号']} {'-' * 20}")
@@ -251,16 +255,16 @@ class PersonalityEvaluator_direct:
"dimension_counts": self.dimension_counts, "dimension_counts": self.dimension_counts,
"scenarios": self.scenarios, "scenarios": self.scenarios,
"bot_info": { "bot_info": {
"nickname": config['bot']['nickname'], "nickname": config["bot"]["nickname"],
"gender": config['identity']['gender'], "gender": config["identity"]["gender"],
"age": config['identity']['age'], "age": config["identity"]["age"],
"height": config['identity']['height'], "height": config["identity"]["height"],
"weight": config['identity']['weight'], "weight": config["identity"]["weight"],
"appearance": config['identity']['appearance'], "appearance": config["identity"]["appearance"],
"personality_core": config['personality']['personality_core'], "personality_core": config["personality"]["personality_core"],
"personality_sides": config['personality']['personality_sides'], "personality_sides": config["personality"]["personality_sides"],
"identity_detail": config['identity']['identity_detail'] "identity_detail": config["identity"]["identity_detail"],
} },
} }
@@ -275,7 +279,7 @@ def main():
"extraversion": round(result["final_scores"]["外向性"] / 6, 1), "extraversion": round(result["final_scores"]["外向性"] / 6, 1),
"agreeableness": round(result["final_scores"]["宜人性"] / 6, 1), "agreeableness": round(result["final_scores"]["宜人性"] / 6, 1),
"neuroticism": round(result["final_scores"]["神经质"] / 6, 1), "neuroticism": round(result["final_scores"]["神经质"] / 6, 1),
"bot_nickname": config['bot']['nickname'] "bot_nickname": config["bot"]["nickname"],
} }
# 确保目录存在 # 确保目录存在
@@ -283,10 +287,10 @@ def main():
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
# 创建文件名,替换可能的非法字符 # 创建文件名,替换可能的非法字符
bot_name = config['bot']['nickname'] bot_name = config["bot"]["nickname"]
# 替换Windows文件名中不允许的字符 # 替换Windows文件名中不允许的字符
for char in ['\\', '/', ':', '*', '?', '"', '<', '>', '|']: for char in ["\\", "/", ":", "*", "?", '"', "<", ">", "|"]:
bot_name = bot_name.replace(char, '_') bot_name = bot_name.replace(char, "_")
file_name = f"{bot_name}_personality.per" file_name = f"{bot_name}_personality.per"
save_path = os.path.join(save_dir, file_name) save_path = os.path.join(save_dir, file_name)

View File

@@ -4,9 +4,11 @@ import json
from pathlib import Path from pathlib import Path
import random import random
@dataclass @dataclass
class Personality: class Personality:
"""人格特质类""" """人格特质类"""
openness: float # 开放性 openness: float # 开放性
conscientiousness: float # 尽责性 conscientiousness: float # 尽责性
extraversion: float # 外向性 extraversion: float # 外向性
@@ -30,7 +32,7 @@ class Personality:
self.personality_sides = personality_sides self.personality_sides = personality_sides
@classmethod @classmethod
def get_instance(cls) -> 'Personality': def get_instance(cls) -> "Personality":
"""获取Personality单例实例 """获取Personality单例实例
Returns: Returns:
@@ -47,13 +49,13 @@ class Personality:
# 如果文件存在,读取文件 # 如果文件存在,读取文件
if personality_file.exists(): if personality_file.exists():
with open(personality_file, 'r', encoding='utf-8') as f: with open(personality_file, "r", encoding="utf-8") as f:
personality_data = json.load(f) personality_data = json.load(f)
self.openness = personality_data.get('openness', 0.5) self.openness = personality_data.get("openness", 0.5)
self.conscientiousness = personality_data.get('conscientiousness', 0.5) self.conscientiousness = personality_data.get("conscientiousness", 0.5)
self.extraversion = personality_data.get('extraversion', 0.5) self.extraversion = personality_data.get("extraversion", 0.5)
self.agreeableness = personality_data.get('agreeableness', 0.5) self.agreeableness = personality_data.get("agreeableness", 0.5)
self.neuroticism = personality_data.get('neuroticism', 0.5) self.neuroticism = personality_data.get("neuroticism", 0.5)
else: else:
# 如果文件不存在根据personality_core和personality_core来设置大五人格特质 # 如果文件不存在根据personality_core和personality_core来设置大五人格特质
if "活泼" in self.personality_core or "开朗" in self.personality_sides: if "活泼" in self.personality_core or "开朗" in self.personality_sides:
@@ -79,7 +81,7 @@ class Personality:
self.openness = 0.5 self.openness = 0.5
@classmethod @classmethod
def initialize(cls, bot_nickname: str, personality_core: str, personality_sides: List[str]) -> 'Personality': def initialize(cls, bot_nickname: str, personality_core: str, personality_sides: List[str]) -> "Personality":
"""初始化人格特质 """初始化人格特质
Args: Args:
@@ -107,11 +109,11 @@ class Personality:
"neuroticism": self.neuroticism, "neuroticism": self.neuroticism,
"bot_nickname": self.bot_nickname, "bot_nickname": self.bot_nickname,
"personality_core": self.personality_core, "personality_core": self.personality_core,
"personality_sides": self.personality_sides "personality_sides": self.personality_sides,
} }
@classmethod @classmethod
def from_dict(cls, data: Dict) -> 'Personality': def from_dict(cls, data: Dict) -> "Personality":
"""从字典创建人格特质实例""" """从字典创建人格特质实例"""
instance = cls.get_instance() instance = cls.get_instance()
for key, value in data.items(): for key, value in data.items():

View File

@@ -2,6 +2,7 @@ import json
from typing import Dict from typing import Dict
import os import os
def load_scenes() -> Dict: def load_scenes() -> Dict:
""" """
从JSON文件加载场景数据 从JSON文件加载场景数据
@@ -10,13 +11,15 @@ def load_scenes() -> Dict:
Dict: 包含所有场景的字典 Dict: 包含所有场景的字典
""" """
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
json_path = os.path.join(current_dir, 'template_scene.json') json_path = os.path.join(current_dir, "template_scene.json")
with open(json_path, 'r', encoding='utf-8') as f: with open(json_path, "r", encoding="utf-8") as f:
return json.load(f) return json.load(f)
PERSONALITY_SCENES = load_scenes() PERSONALITY_SCENES = load_scenes()
def get_scene_by_factor(factor: str) -> Dict: def get_scene_by_factor(factor: str) -> Dict:
""" """
根据人格因子获取对应的情景测试 根据人格因子获取对应的情景测试

View File

@@ -100,7 +100,7 @@ class MainSystem:
weight=global_config.weight, weight=global_config.weight,
age=global_config.age, age=global_config.age,
gender=global_config.gender, gender=global_config.gender,
appearance=global_config.appearance appearance=global_config.appearance,
) )
logger.success("个体特征初始化成功") logger.success("个体特征初始化成功")
@@ -136,7 +136,6 @@ class MainSystem:
logger.info("正在进行记忆构建") logger.info("正在进行记忆构建")
await HippocampusManager.get_instance().build_memory() await HippocampusManager.get_instance().build_memory()
async def forget_memory_task(self): async def forget_memory_task(self):
"""记忆遗忘任务""" """记忆遗忘任务"""
while True: while True:
@@ -145,7 +144,6 @@ class MainSystem:
await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage) await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成") print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
async def print_mood_task(self): async def print_mood_task(self):
"""打印情绪状态""" """打印情绪状态"""
while True: while True:

View File

@@ -8,14 +8,15 @@ from ..config.config import global_config
logger = get_module_logger("chat_observer") logger = get_module_logger("chat_observer")
class ChatObserver: class ChatObserver:
"""聊天状态观察器""" """聊天状态观察器"""
# 类级别的实例管理 # 类级别的实例管理
_instances: Dict[str, 'ChatObserver'] = {} _instances: Dict[str, "ChatObserver"] = {}
@classmethod @classmethod
def get_instance(cls, stream_id: str) -> 'ChatObserver': def get_instance(cls, stream_id: str) -> "ChatObserver":
"""获取或创建观察器实例 """获取或创建观察器实例
Args: Args:
@@ -65,10 +66,7 @@ class ChatObserver:
""" """
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}") logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
query = { query = {"chat_id": self.stream_id, "time": {"$gt": self.last_check_time}}
"chat_id": self.stream_id,
"time": {"$gt": self.last_check_time}
}
# 只需要查询是否存在,不需要获取具体消息 # 只需要查询是否存在,不需要获取具体消息
new_message_exists = db.messages.find_one(query) is not None new_message_exists = db.messages.find_one(query) is not None
@@ -121,7 +119,7 @@ class ChatObserver:
start_time: Optional[float] = None, start_time: Optional[float] = None,
end_time: Optional[float] = None, end_time: Optional[float] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
user_id: Optional[str] = None user_id: Optional[str] = None,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""获取消息历史 """获取消息历史
@@ -144,8 +142,7 @@ class ChatObserver:
if user_id is not None: if user_id is not None:
filtered_messages = [ filtered_messages = [
m for m in filtered_messages m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
] ]
if limit is not None: if limit is not None:
@@ -166,9 +163,7 @@ class ChatObserver:
if last_message: if last_message:
query["time"] = {"$gt": last_message["time"]} query["time"] = {"$gt": last_message["time"]}
new_messages = list( new_messages = list(db.messages.find(query).sort("time", 1))
db.messages.find(query).sort("time", 1)
)
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
@@ -184,10 +179,7 @@ class ChatObserver:
Returns: Returns:
List[Dict[str, Any]]: 最多5条消息 List[Dict[str, Any]]: 最多5条消息
""" """
query = { query = {"chat_id": self.stream_id, "time": {"$lt": time_point}}
"chat_id": self.stream_id,
"time": {"$lt": time_point}
}
new_messages = list( new_messages = list(
db.messages.find(query).sort("time", -1).limit(5) # 倒序获取5条 db.messages.find(query).sort("time", -1).limit(5) # 倒序获取5条

View File

@@ -26,6 +26,7 @@ logger = get_module_logger("pfc")
class ConversationState(Enum): class ConversationState(Enum):
"""对话状态""" """对话状态"""
INIT = "初始化" INIT = "初始化"
RETHINKING = "重新思考" RETHINKING = "重新思考"
ANALYZING = "分析历史" ANALYZING = "分析历史"
@@ -47,10 +48,7 @@ class ActionPlanner:
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.llm = LLM_request( self.llm = LLM_request(
model=global_config.llm_normal, model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning"
temperature=0.7,
max_tokens=1000,
request_type="action_planning"
) )
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME self.name = global_config.BOT_NICKNAME
@@ -93,7 +91,7 @@ class ActionPlanner:
# 构建action历史文本 # 构建action历史文本
action_history_text = "" action_history_text = ""
if action_history: if action_history:
if action_history[-1]['action'] == "direct_reply": if action_history[-1]["action"] == "direct_reply":
action_history_text = "你刚刚发言回复了对方" action_history_text = "你刚刚发言回复了对方"
# 获取时间信息 # 获取时间信息
@@ -130,9 +128,7 @@ judge_conversation: 判断对话是否结束,当发现对话目标已经达到
# 使用简化函数提取JSON内容 # 使用简化函数提取JSON内容
success, result = get_items_from_json( success, result = get_items_from_json(
content, content, "action", "reason", default_values={"action": "direct_reply", "reason": "默认原因"}
"action", "reason",
default_values={"action": "direct_reply", "reason": "默认原因"}
) )
if not success: if not success:
@@ -142,7 +138,14 @@ judge_conversation: 判断对话是否结束,当发现对话目标已经达到
reason = result["reason"] reason = result["reason"]
# 验证action类型 # 验证action类型
if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal", "judge_conversation"]: if action not in [
"direct_reply",
"fetch_knowledge",
"wait",
"listening",
"rethink_goal",
"judge_conversation",
]:
logger.warning(f"未知的行动类型: {action}默认使用listening") logger.warning(f"未知的行动类型: {action}默认使用listening")
action = "listening" action = "listening"
@@ -160,10 +163,7 @@ class GoalAnalyzer:
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.llm = LLM_request( self.llm = LLM_request(
model=global_config.llm_normal, model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
temperature=0.7,
max_tokens=1000,
request_type="conversation_goal"
) )
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
@@ -238,9 +238,7 @@ class GoalAnalyzer:
# 使用简化函数提取JSON内容 # 使用简化函数提取JSON内容
success, result = get_items_from_json( success, result = get_items_from_json(
content, content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
"goal", "reasoning",
required_types={"goal": str, "reasoning": str}
) )
if not success: if not success:
@@ -372,12 +370,10 @@ class GoalAnalyzer:
# 使用简化函数提取JSON内容 # 使用简化函数提取JSON内容
success, result = get_items_from_json( success, result = get_items_from_json(
content, content,
"goal_achieved", "stop_conversation", "reason", "goal_achieved",
required_types={ "stop_conversation",
"goal_achieved": bool, "reason",
"stop_conversation": bool, required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
"reason": str
}
) )
if not success: if not success:
@@ -402,6 +398,7 @@ class GoalAnalyzer:
class Waiter: class Waiter:
"""快 速 等 待""" """快 速 等 待"""
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
@@ -430,10 +427,7 @@ class ReplyGenerator:
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.llm = LLM_request( self.llm = LLM_request(
model=global_config.llm_normal, model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation"
temperature=0.7,
max_tokens=300,
request_type="reply_generation"
) )
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME self.name = global_config.BOT_NICKNAME
@@ -446,7 +440,7 @@ class ReplyGenerator:
chat_history: List[Message], chat_history: List[Message],
knowledge_cache: Dict[str, str], knowledge_cache: Dict[str, str],
previous_reply: Optional[str] = None, previous_reply: Optional[str] = None,
retry_count: int = 0 retry_count: int = 0,
) -> str: ) -> str:
"""生成回复 """生成回复
@@ -507,7 +501,7 @@ class ReplyGenerator:
2. 体现你的性格特征 2. 体现你的性格特征
3. 自然流畅,像正常聊天一样,简短 3. 自然流畅,像正常聊天一样,简短
4. 适当利用相关知识,但不要生硬引用 4. 适当利用相关知识,但不要生硬引用
{'5. 改进上一次回复中的问题' if previous_reply else ''} {"5. 改进上一次回复中的问题" if previous_reply else ""}
请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清""和对方说的话,不要把""说的话当做对方说的话,这是你自己说的话。 请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清""和对方说的话,不要把""说的话当做对方说的话,这是你自己说的话。
请你回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 请你回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
@@ -525,10 +519,7 @@ class ReplyGenerator:
# 如果有新消息,重新生成回复 # 如果有新消息,重新生成回复
if is_new: if is_new:
logger.info("检测到新消息,重新生成回复") logger.info("检测到新消息,重新生成回复")
return await self.generate( return await self.generate(goal, chat_history, knowledge_cache, None, retry_count)
goal, chat_history, knowledge_cache,
None, retry_count
)
return content return content
@@ -536,12 +527,7 @@ class ReplyGenerator:
logger.error(f"生成回复时出错: {e}") logger.error(f"生成回复时出错: {e}")
return "抱歉,我现在有点混乱,让我重新思考一下..." return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply( async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
"""检查回复是否合适 """检查回复是否合适
Args: Args:
@@ -557,13 +543,13 @@ class ReplyGenerator:
class Conversation: class Conversation:
# 类级别的实例管理 # 类级别的实例管理
_instances: Dict[str, 'Conversation'] = {} _instances: Dict[str, "Conversation"] = {}
_instance_lock = asyncio.Lock() # 类级别的全局锁 _instance_lock = asyncio.Lock() # 类级别的全局锁
_init_events: Dict[str, asyncio.Event] = {} # 初始化完成事件 _init_events: Dict[str, asyncio.Event] = {} # 初始化完成事件
_initializing: Dict[str, bool] = {} # 标记是否正在初始化 _initializing: Dict[str, bool] = {} # 标记是否正在初始化
@classmethod @classmethod
async def get_instance(cls, stream_id: str) -> Optional['Conversation']: async def get_instance(cls, stream_id: str) -> Optional["Conversation"]:
"""获取或创建对话实例 """获取或创建对话实例
Args: Args:
@@ -686,7 +672,7 @@ class Conversation:
self.current_method, self.current_method,
self.goal_reasoning, self.goal_reasoning,
self.action_history, # 传入action历史 self.action_history, # 传入action历史
self.chat_observer # 传入chat_observer self.chat_observer, # 传入chat_observer
) )
# 执行行动 # 执行行动
@@ -705,7 +691,7 @@ class Conversation:
time=msg_dict["time"], time=msg_dict["time"],
user_info=user_info, user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""), processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", "") detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
) )
except Exception as e: except Exception as e:
logger.warning(f"转换消息时出错: {e}") logger.warning(f"转换消息时出错: {e}")
@@ -716,11 +702,9 @@ class Conversation:
logger.info(f"执行行动: {action}, 原因: {reason}") logger.info(f"执行行动: {action}, 原因: {reason}")
# 记录action历史 # 记录action历史
self.action_history.append({ self.action_history.append(
"action": action, {"action": action, "reason": reason, "time": datetime.datetime.now().strftime("%H:%M:%S")}
"reason": reason, )
"time": datetime.datetime.now().strftime("%H:%M:%S")
})
# 只保留最近的10条记录 # 只保留最近的10条记录
if len(self.action_history) > 10: if len(self.action_history) > 10:
@@ -733,13 +717,12 @@ class Conversation:
self.current_goal, self.current_goal,
self.current_method, self.current_method,
[self._convert_to_message(msg) for msg in messages], [self._convert_to_message(msg) for msg in messages],
self.knowledge_cache self.knowledge_cache,
) )
# 检查回复是否合适 # 检查回复是否合适
is_suitable, reason, need_replan = await self.reply_generator.check_reply( is_suitable, reason, need_replan = await self.reply_generator.check_reply(
self.generated_reply, self.generated_reply, self.current_goal
self.current_goal
) )
if not is_suitable: if not is_suitable:
@@ -756,29 +739,34 @@ class Conversation:
self.current_goal, self.current_goal,
self.current_method, self.current_method,
[self._convert_to_message(msg) for msg in messages], [self._convert_to_message(msg) for msg in messages],
self.knowledge_cache self.knowledge_cache,
) )
# 检查使用新目标生成的回复是否合适 # 检查使用新目标生成的回复是否合适
is_suitable, reason, _ = await self.reply_generator.check_reply( is_suitable, reason, _ = await self.reply_generator.check_reply(
self.generated_reply, self.generated_reply, self.current_goal
self.current_goal
) )
if is_suitable: if is_suitable:
# 如果新目标的回复合适,调整目标优先级 # 如果新目标的回复合适,调整目标优先级
await self.goal_analyzer._update_goals( await self.goal_analyzer._update_goals(
self.current_goal, self.current_goal, self.current_method, self.goal_reasoning
self.current_method,
self.goal_reasoning
) )
else: else:
# 如果新目标还是不合适,重新思考目标 # 如果新目标还是不合适,重新思考目标
self.state = ConversationState.RETHINKING self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal() (
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return return
else: else:
# 没有备选目标,重新分析 # 没有备选目标,重新分析
self.state = ConversationState.RETHINKING self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal() (
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return return
else: else:
# 重新生成回复 # 重新生成回复
@@ -787,7 +775,7 @@ class Conversation:
self.current_method, self.current_method,
[self._convert_to_message(msg) for msg in messages], [self._convert_to_message(msg) for msg in messages],
self.knowledge_cache, self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入 self.generated_reply, # 将不合适的回复作为previous_reply传入
) )
while self.chat_observer.check(): while self.chat_observer.check():
@@ -805,13 +793,17 @@ class Conversation:
self.current_goal, self.current_goal,
self.current_method, self.current_method,
[self._convert_to_message(msg) for msg in messages], [self._convert_to_message(msg) for msg in messages],
self.knowledge_cache self.knowledge_cache,
) )
is_suitable = True # 假设使用新目标后回复是合适的 is_suitable = True # 假设使用新目标后回复是合适的
else: else:
# 没有备选目标,重新分析 # 没有备选目标,重新分析
self.state = ConversationState.RETHINKING self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal() (
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return return
else: else:
# 重新生成回复 # 重新生成回复
@@ -820,7 +812,7 @@ class Conversation:
self.current_method, self.current_method,
[self._convert_to_message(msg) for msg in messages], [self._convert_to_message(msg) for msg in messages],
self.knowledge_cache, self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入 self.generated_reply, # 将不合适的回复作为previous_reply传入
) )
await self._send_reply() await self._send_reply()
@@ -829,8 +821,7 @@ class Conversation:
self.state = ConversationState.GENERATING self.state = ConversationState.GENERATING
messages = self.chat_observer.get_message_history(limit=30) messages = self.chat_observer.get_message_history(limit=30)
knowledge, sources = await self.knowledge_fetcher.fetch( knowledge, sources = await self.knowledge_fetcher.fetch(
self.current_goal, self.current_goal, [self._convert_to_message(msg) for msg in messages]
[self._convert_to_message(msg) for msg in messages]
) )
logger.info(f"获取到知识,来源: {sources}") logger.info(f"获取到知识,来源: {sources}")
@@ -841,13 +832,12 @@ class Conversation:
self.current_goal, self.current_goal,
self.current_method, self.current_method,
[self._convert_to_message(msg) for msg in messages], [self._convert_to_message(msg) for msg in messages],
self.knowledge_cache self.knowledge_cache,
) )
# 检查回复是否合适 # 检查回复是否合适
is_suitable, reason, need_replan = await self.reply_generator.check_reply( is_suitable, reason, need_replan = await self.reply_generator.check_reply(
self.generated_reply, self.generated_reply, self.current_goal
self.current_goal
) )
if not is_suitable: if not is_suitable:
@@ -861,8 +851,7 @@ class Conversation:
logger.info(f"切换到备选目标: {self.current_goal}") logger.info(f"切换到备选目标: {self.current_goal}")
# 使用新目标获取知识并生成回复 # 使用新目标获取知识并生成回复
knowledge, sources = await self.knowledge_fetcher.fetch( knowledge, sources = await self.knowledge_fetcher.fetch(
self.current_goal, self.current_goal, [self._convert_to_message(msg) for msg in messages]
[self._convert_to_message(msg) for msg in messages]
) )
if knowledge != "未找到相关知识": if knowledge != "未找到相关知识":
self.knowledge_cache[sources] = knowledge self.knowledge_cache[sources] = knowledge
@@ -871,12 +860,16 @@ class Conversation:
self.current_goal, self.current_goal,
self.current_method, self.current_method,
[self._convert_to_message(msg) for msg in messages], [self._convert_to_message(msg) for msg in messages],
self.knowledge_cache self.knowledge_cache,
) )
else: else:
# 没有备选目标,重新分析 # 没有备选目标,重新分析
self.state = ConversationState.RETHINKING self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal() (
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return return
else: else:
# 重新生成回复 # 重新生成回复
@@ -885,7 +878,7 @@ class Conversation:
self.current_method, self.current_method,
[self._convert_to_message(msg) for msg in messages], [self._convert_to_message(msg) for msg in messages],
self.knowledge_cache, self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入 self.generated_reply, # 将不合适的回复作为previous_reply传入
) )
await self._send_reply() await self._send_reply()
@@ -896,7 +889,9 @@ class Conversation:
elif action == "judge_conversation": elif action == "judge_conversation":
self.state = ConversationState.JUDGING self.state = ConversationState.JUDGING
self.goal_achieved, self.stop_conversation, self.reason = await self.goal_analyzer.analyze_conversation(self.current_goal, self.goal_reasoning) self.goal_achieved, self.stop_conversation, self.reason = await self.goal_analyzer.analyze_conversation(
self.current_goal, self.goal_reasoning
)
# 如果当前目标达成但还有其他目标 # 如果当前目标达成但还有其他目标
if self.goal_achieved and not self.stop_conversation: if self.goal_achieved and not self.stop_conversation:
@@ -943,7 +938,7 @@ class Conversation:
await self.direct_sender.send_message( await self.direct_sender.send_message(
chat_stream=self.chat_stream, chat_stream=self.chat_stream,
content="抱歉,由于等待时间过长,我需要先去忙别的了。下次再聊吧~", content="抱歉,由于等待时间过长,我需要先去忙别的了。下次再聊吧~",
reply_to_message=latest_message reply_to_message=latest_message,
) )
except Exception as e: except Exception as e:
logger.error(f"发送超时消息失败: {str(e)}") logger.error(f"发送超时消息失败: {str(e)}")
@@ -962,9 +957,7 @@ class Conversation:
latest_message = self._convert_to_message(messages[0]) latest_message = self._convert_to_message(messages[0])
try: try:
await self.direct_sender.send_message( await self.direct_sender.send_message(
chat_stream=self.chat_stream, chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
content=self.generated_reply,
reply_to_message=latest_message
) )
self.chat_observer.trigger_update() # 触发立即更新 self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update(): if not await self.chat_observer.wait_for_update():
@@ -1037,4 +1030,3 @@ class DirectMessageSender:
except Exception as e: except Exception as e:
self.logger.error(f"直接发送消息失败: {str(e)}") self.logger.error(f"直接发送消息失败: {str(e)}")
raise raise

View File

@@ -7,15 +7,13 @@ from ..chat.message import Message
logger = get_module_logger("knowledge_fetcher") logger = get_module_logger("knowledge_fetcher")
class KnowledgeFetcher: class KnowledgeFetcher:
"""知识调取器""" """知识调取器"""
def __init__(self): def __init__(self):
self.llm = LLM_request( self.llm = LLM_request(
model=global_config.llm_normal, model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="knowledge_fetch"
temperature=0.7,
max_tokens=1000,
request_type="knowledge_fetch"
) )
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]: async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
@@ -40,7 +38,7 @@ class KnowledgeFetcher:
max_memory_num=3, max_memory_num=3,
max_memory_length=2, max_memory_length=2,
max_depth=3, max_depth=3,
fast_retrieval=False fast_retrieval=False,
) )
if related_memory: if related_memory:

View File

@@ -5,11 +5,12 @@ from src.common.logger import get_module_logger
logger = get_module_logger("pfc_utils") logger = get_module_logger("pfc_utils")
def get_items_from_json( def get_items_from_json(
content: str, content: str,
*items: str, *items: str,
default_values: Optional[Dict[str, Any]] = None, default_values: Optional[Dict[str, Any]] = None,
required_types: Optional[Dict[str, type]] = None required_types: Optional[Dict[str, type]] = None,
) -> Tuple[bool, Dict[str, Any]]: ) -> Tuple[bool, Dict[str, Any]]:
"""从文本中提取JSON内容并获取指定字段 """从文本中提取JSON内容并获取指定字段
@@ -34,7 +35,7 @@ def get_items_from_json(
json_data = json.loads(content) json_data = json.loads(content)
except json.JSONDecodeError: except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分 # 如果直接解析失败尝试查找和提取JSON部分
json_pattern = r'\{[^{}]*\}' json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content) json_match = re.search(json_pattern, content)
if json_match: if json_match:
try: try:

View File

@@ -9,26 +9,19 @@ from ..message.message_base import UserInfo
logger = get_module_logger("reply_checker") logger = get_module_logger("reply_checker")
class ReplyChecker: class ReplyChecker:
"""回复检查器""" """回复检查器"""
def __init__(self, stream_id: str): def __init__(self, stream_id: str):
self.llm = LLM_request( self.llm = LLM_request(
model=global_config.llm_normal, model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="reply_check"
temperature=0.7,
max_tokens=1000,
request_type="reply_check"
) )
self.name = global_config.BOT_NICKNAME self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
self.max_retries = 2 # 最大重试次数 self.max_retries = 2 # 最大重试次数
async def check( async def check(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
"""检查生成的回复是否合适 """检查生成的回复是否合适
Args: Args:
@@ -92,7 +85,8 @@ class ReplyChecker:
except json.JSONDecodeError: except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分 # 如果直接解析失败尝试查找和提取JSON部分
import re import re
json_pattern = r'\{[^{}]*\}'
json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content) json_match = re.search(json_pattern, content)
if json_match: if json_match:
try: try:

View File

@@ -12,5 +12,5 @@ __all__ = [
"chat_manager", "chat_manager",
"message_manager", "message_manager",
"MessageStorage", "MessageStorage",
"auto_speak_manager" "auto_speak_manager",
] ]

View File

@@ -377,7 +377,6 @@ class EmojiManager:
except Exception: except Exception:
logger.exception("[错误] 扫描表情包失败") logger.exception("[错误] 扫描表情包失败")
def check_emoji_file_integrity(self): def check_emoji_file_integrity(self):
"""检查表情包文件完整性 """检查表情包文件完整性
如果文件已被删除,则从数据库中移除对应记录 如果文件已被删除,则从数据库中移除对应记录
@@ -542,7 +541,7 @@ class EmojiManager:
logger.info("[扫描] 开始扫描新表情包...") logger.info("[扫描] 开始扫描新表情包...")
if self.emoji_num < self.emoji_num_max: if self.emoji_num < self.emoji_num_max:
await self.scan_new_emojis() await self.scan_new_emojis()
if (self.emoji_num > self.emoji_num_max): if self.emoji_num > self.emoji_num_max:
logger.warning(f"[警告] 表情包数量超过最大限制: {self.emoji_num} > {self.emoji_num_max},跳过注册") logger.warning(f"[警告] 表情包数量超过最大限制: {self.emoji_num} > {self.emoji_num_max},跳过注册")
if not global_config.max_reach_deletion: if not global_config.max_reach_deletion:
logger.warning("表情包数量超过最大限制,终止注册") logger.warning("表情包数量超过最大限制,终止注册")
@@ -580,5 +579,6 @@ class EmojiManager:
except Exception as e: except Exception as e:
logger.error(f"[错误] 删除图片目录失败: {str(e)}") logger.error(f"[错误] 删除图片目录失败: {str(e)}")
# 创建全局单例 # 创建全局单例
emoji_manager = EmojiManager() emoji_manager = EmojiManager()

View File

@@ -13,6 +13,7 @@ from ..config.config import global_config
logger = get_module_logger("message_buffer") logger = get_module_logger("message_buffer")
@dataclass @dataclass
class CacheMessages: class CacheMessages:
message: MessageRecv message: MessageRecv
@@ -37,13 +38,14 @@ class MessageBuffer:
async def start_caching_messages(self, message: MessageRecv): async def start_caching_messages(self, message: MessageRecv):
"""添加消息,启动缓冲""" """添加消息,启动缓冲"""
if not global_config.message_buffer: if not global_config.message_buffer:
person_id = person_info_manager.get_person_id(message.message_info.user_info.platform, person_id = person_info_manager.get_person_id(
message.message_info.user_info.user_id) message.message_info.user_info.platform, message.message_info.user_info.user_id
)
asyncio.create_task(self.save_message_interval(person_id, message.message_info)) asyncio.create_task(self.save_message_interval(person_id, message.message_info))
return return
person_id_ = self.get_person_id_(message.message_info.platform, person_id_ = self.get_person_id_(
message.message_info.user_info.user_id, message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
message.message_info.group_info) )
async with self.lock: async with self.lock:
if person_id_ not in self.buffer_pool: if person_id_ not in self.buffer_pool:
@@ -66,7 +68,7 @@ class MessageBuffer:
recent_F_count += 1 recent_F_count += 1
# 判断条件最近T之后有超过3-5条F # 判断条件最近T之后有超过3-5条F
if (recent_F_count >= random.randint(3, 5)): if recent_F_count >= random.randint(3, 5):
new_msg = CacheMessages(message=message, result="T") new_msg = CacheMessages(message=message, result="T")
new_msg.cache_determination.set() new_msg.cache_determination.set()
self.buffer_pool[person_id_][message.message_info.message_id] = new_msg self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
@@ -77,12 +79,11 @@ class MessageBuffer:
self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message) self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message)
# 启动3秒缓冲计时器 # 启动3秒缓冲计时器
person_id = person_info_manager.get_person_id(message.message_info.user_info.platform, person_id = person_info_manager.get_person_id(
message.message_info.user_info.user_id) message.message_info.user_info.platform, message.message_info.user_info.user_id
)
asyncio.create_task(self.save_message_interval(person_id, message.message_info)) asyncio.create_task(self.save_message_interval(person_id, message.message_info))
asyncio.create_task(self._debounce_processor(person_id_, asyncio.create_task(self._debounce_processor(person_id_, message.message_info.message_id, person_id))
message.message_info.message_id,
person_id))
async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str): async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str):
"""等待3秒无新消息""" """等待3秒无新消息"""
@@ -94,8 +95,7 @@ class MessageBuffer:
await asyncio.sleep(interval_time) await asyncio.sleep(interval_time)
async with self.lock: async with self.lock:
if (person_id_ not in self.buffer_pool or if person_id_ not in self.buffer_pool or message_id not in self.buffer_pool[person_id_]:
message_id not in self.buffer_pool[person_id_]):
logger.debug(f"消息已被清理msgid: {message_id}") logger.debug(f"消息已被清理msgid: {message_id}")
return return
@@ -104,15 +104,13 @@ class MessageBuffer:
cache_msg.result = "T" cache_msg.result = "T"
cache_msg.cache_determination.set() cache_msg.cache_determination.set()
async def query_buffer_result(self, message: MessageRecv) -> bool: async def query_buffer_result(self, message: MessageRecv) -> bool:
"""查询缓冲结果,并清理""" """查询缓冲结果,并清理"""
if not global_config.message_buffer: if not global_config.message_buffer:
return True return True
person_id_ = self.get_person_id_(message.message_info.platform, person_id_ = self.get_person_id_(
message.message_info.user_info.user_id, message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
message.message_info.group_info) )
async with self.lock: async with self.lock:
user_msgs = self.buffer_pool.get(person_id_, {}) user_msgs = self.buffer_pool.get(person_id_, {})
@@ -144,8 +142,7 @@ class MessageBuffer:
keep_msgs[msg_id] = msg keep_msgs[msg_id] = msg
elif msg.result == "F": elif msg.result == "F":
# 收集F消息的文本内容 # 收集F消息的文本内容
if (hasattr(msg.message, 'processed_plain_text') if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
and msg.message.processed_plain_text):
if msg.message.message_segment.type == "text": if msg.message.message_segment.type == "text":
combined_text.append(msg.message.processed_plain_text) combined_text.append(msg.message.processed_plain_text)
elif msg.message.message_segment.type != "text": elif msg.message.message_segment.type != "text":
@@ -182,7 +179,7 @@ class MessageBuffer:
"platform": message.platform, "platform": message.platform,
"user_id": message.user_info.user_id, "user_id": message.user_info.user_id,
"nickname": message.user_info.user_nickname, "nickname": message.user_info.user_nickname,
"konw_time" : int(time.time()) "konw_time": int(time.time()),
} }
await person_info_manager.update_one_field(person_id, "msg_interval_list", message_interval_list, data) await person_info_manager.update_one_field(person_id, "msg_interval_list", message_interval_list, data)

View File

@@ -68,7 +68,8 @@ class Message_Sender:
typing_time = calculate_typing_time( typing_time = calculate_typing_time(
input_string=message.processed_plain_text, input_string=message.processed_plain_text,
thinking_start_time=message.thinking_start_time, thinking_start_time=message.thinking_start_time,
is_emoji=message.is_emoji) is_emoji=message.is_emoji,
)
logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束") logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束")
await asyncio.sleep(typing_time) await asyncio.sleep(typing_time)
logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束") logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束")

View File

@@ -56,14 +56,13 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
logger.info("被@回复概率设置为100%") logger.info("被@回复概率设置为100%")
else: else:
if not is_mentioned: if not is_mentioned:
# 判断是否被回复 # 判断是否被回复
if re.match(f"回复[\s\S]*?\({global_config.BOT_QQ}\)的消息,说:", message.processed_plain_text): if re.match(f"回复[\s\S]*?\({global_config.BOT_QQ}\)的消息,说:", message.processed_plain_text):
is_mentioned = True is_mentioned = True
# 判断内容中是否被提及 # 判断内容中是否被提及
message_content = re.sub(r'\@[\s\S]*?(\d+)','', message.processed_plain_text) message_content = re.sub(r"\@[\s\S]*?(\d+)", "", message.processed_plain_text)
message_content = re.sub(r'回复[\s\S]*?\((\d+)\)的消息,说: ','', message_content) message_content = re.sub(r"回复[\s\S]*?\((\d+)\)的消息,说: ", "", message_content)
for keyword in keywords: for keyword in keywords:
if keyword in message_content: if keyword in message_content:
is_mentioned = True is_mentioned = True
@@ -359,7 +358,13 @@ def process_llm_response(text: str) -> List[str]:
return sentences return sentences
def calculate_typing_time(input_string: str, thinking_start_time: float, chinese_time: float = 0.2, english_time: float = 0.1, is_emoji: bool = False) -> float: def calculate_typing_time(
input_string: str,
thinking_start_time: float,
chinese_time: float = 0.2,
english_time: float = 0.1,
is_emoji: bool = False,
) -> float:
""" """
计算输入字符串所需的时间,中文和英文字符有不同的输入时间 计算输入字符串所需的时间,中文和英文字符有不同的输入时间
input_string (str): 输入的字符串 input_string (str): 输入的字符串
@@ -394,7 +399,6 @@ def calculate_typing_time(input_string: str, thinking_start_time: float, chinese
else: # 其他字符(如英文) else: # 其他字符(如英文)
total_time += english_time total_time += english_time
if is_emoji: if is_emoji:
total_time = 1 total_time = 1
@@ -535,22 +539,18 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
try: try:
# 获取开始时间之前最新的一条消息 # 获取开始时间之前最新的一条消息
start_message = db.messages.find_one( start_message = db.messages.find_one(
{ {"chat_id": stream_id, "time": {"$lte": start_time}},
"chat_id": stream_id, sort=[("time", -1), ("_id", -1)], # 按时间倒序_id倒序最后插入的在前
"time": {"$lte": start_time}
},
sort=[("time", -1), ("_id", -1)] # 按时间倒序_id倒序最后插入的在前
) )
# 获取结束时间最近的一条消息 # 获取结束时间最近的一条消息
# 先找到结束时间点的所有消息 # 先找到结束时间点的所有消息
end_time_messages = list(db.messages.find( end_time_messages = list(
{ db.messages.find(
"chat_id": stream_id, {"chat_id": stream_id, "time": {"$lte": end_time}},
"time": {"$lte": end_time} sort=[("time", -1)], # 先按时间倒序
}, ).limit(10)
sort=[("time", -1)] # 先按时间倒序 ) # 限制查询数量,避免性能问题
).limit(10)) # 限制查询数量,避免性能问题
if not end_time_messages: if not end_time_messages:
logger.warning(f"未找到结束时间 {end_time} 之前的消息") logger.warning(f"未找到结束时间 {end_time} 之前的消息")
@@ -559,10 +559,7 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
# 找到最大时间 # 找到最大时间
max_time = end_time_messages[0]["time"] max_time = end_time_messages[0]["time"]
# 在最大时间的消息中找最后插入的_id最大的 # 在最大时间的消息中找最后插入的_id最大的
end_message = max( end_message = max([msg for msg in end_time_messages if msg["time"] == max_time], key=lambda x: x["_id"])
[msg for msg in end_time_messages if msg["time"] == max_time],
key=lambda x: x["_id"]
)
if not start_message: if not start_message:
logger.warning(f"未找到开始时间 {start_time} 之前的消息") logger.warning(f"未找到开始时间 {start_time} 之前的消息")
@@ -590,16 +587,12 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
# 获取并打印这个时间范围内的所有消息 # 获取并打印这个时间范围内的所有消息
# print("\n=== 时间范围内的所有消息 ===") # print("\n=== 时间范围内的所有消息 ===")
all_messages = list(db.messages.find( all_messages = list(
{ db.messages.find(
"chat_id": stream_id, {"chat_id": stream_id, "time": {"$gte": start_message["time"], "$lte": end_message["time"]}},
"time": { sort=[("time", 1), ("_id", 1)], # 按时间正序_id正序
"$gte": start_message["time"], )
"$lte": end_message["time"] )
}
},
sort=[("time", 1), ("_id", 1)] # 按时间正序_id正序
))
count = 0 count = 0
total_length = 0 total_length = 0

View File

@@ -245,7 +245,7 @@ class ImageManager:
try: try:
while True: while True:
gif.seek(len(frames)) gif.seek(len(frames))
frame = gif.convert('RGB') frame = gif.convert("RGB")
frames.append(frame.copy()) frames.append(frame.copy())
except EOFError: except EOFError:
pass pass
@@ -270,12 +270,13 @@ class ImageManager:
target_width = int((target_height / frame_height) * frame_width) target_width = int((target_height / frame_height) * frame_width)
# 调整所有帧的大小 # 调整所有帧的大小
resized_frames = [frame.resize((target_width, target_height), Image.Resampling.LANCZOS) resized_frames = [
for frame in selected_frames] frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames
]
# 创建拼接图像 # 创建拼接图像
total_width = target_width * len(resized_frames) total_width = target_width * len(resized_frames)
combined_image = Image.new('RGB', (total_width, target_height)) combined_image = Image.new("RGB", (total_width, target_height))
# 水平拼接图像 # 水平拼接图像
for idx, frame in enumerate(resized_frames): for idx, frame in enumerate(resized_frames):
@@ -283,8 +284,8 @@ class ImageManager:
# 转换为base64 # 转换为base64
buffer = io.BytesIO() buffer = io.BytesIO()
combined_image.save(buffer, format='JPEG', quality=85) combined_image.save(buffer, format="JPEG", quality=85)
result_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return result_base64 return result_base64

View File

@@ -7,6 +7,7 @@ from datetime import datetime
logger = get_module_logger("pfc_message_processor") logger = get_module_logger("pfc_message_processor")
class MessageProcessor: class MessageProcessor:
"""消息处理器,负责处理接收到的消息并存储""" """消息处理器,负责处理接收到的消息并存储"""
@@ -60,7 +61,4 @@ class MessageProcessor:
mes_name = chat.group_info.group_name if chat.group_info else "私聊" mes_name = chat.group_info.group_name if chat.group_info else "私聊"
# 将时间戳转换为datetime对象 # 将时间戳转换为datetime对象
current_time = datetime.fromtimestamp(message.message_info.time).strftime("%H:%M:%S") current_time = datetime.fromtimestamp(message.message_info.time).strftime("%H:%M:%S")
logger.info( logger.info(f"[{current_time}][{mes_name}]{chat.user_info.user_nickname}: {message.processed_plain_text}")
f"[{current_time}][{mes_name}]"
f"{chat.user_info.user_nickname}: {message.processed_plain_text}"
)

View File

@@ -27,6 +27,7 @@ chat_config = LogConfig(
logger = get_module_logger("reasoning_chat", config=chat_config) logger = get_module_logger("reasoning_chat", config=chat_config)
class ReasoningChat: class ReasoningChat:
def __init__(self): def __init__(self):
self.storage = MessageStorage() self.storage = MessageStorage()

View File

@@ -52,7 +52,6 @@ class ResponseGenerator:
f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}" f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
) # noqa: E501 ) # noqa: E501
model_response = await self._generate_response_with_model(message, current_model) model_response = await self._generate_response_with_model(message, current_model)
# print(f"raw_content: {model_response}") # print(f"raw_content: {model_response}")

View File

@@ -24,7 +24,6 @@ class PromptBuilder:
async def _build_prompt( async def _build_prompt(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]: ) -> tuple[str, str]:
# 开始构建prompt # 开始构建prompt
prompt_personality = "" prompt_personality = ""
# person # person
@@ -41,12 +40,10 @@ class PromptBuilder:
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}" prompt_personality += f",{identity_detail[0]}"
# 关系 # 关系
who_chat_in_group = [(chat_stream.user_info.platform, who_chat_in_group = [
chat_stream.user_info.user_id, (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
chat_stream.user_info.user_nickname)] ]
who_chat_in_group += get_recent_group_speaker( who_chat_in_group += get_recent_group_speaker(
stream_id, stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id), (chat_stream.user_info.platform, chat_stream.user_info.user_id),
@@ -84,7 +81,7 @@ class PromptBuilder:
# print(f"相关记忆:{related_memory_info}") # print(f"相关记忆:{related_memory_info}")
# 日程构建 # 日程构建
schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}''' schedule_prompt = f"""你现在正在做的事情是:{bot_schedule.get_current_num_task(num=1, time_info=False)}"""
# 获取聊天上下文 # 获取聊天上下文
chat_in_group = True chat_in_group = True
@@ -281,7 +278,9 @@ class PromptBuilder:
# 6. 限制总数量最多10条 # 6. 限制总数量最多10条
filtered_results = filtered_results[:10] filtered_results = filtered_results[:10]
logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果") logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. 格式化输出 # 7. 格式化输出
if filtered_results: if filtered_results:
@@ -309,7 +308,9 @@ class PromptBuilder:
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}") logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info return related_info
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]: def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算 # 使用余弦相似度计算

View File

@@ -28,6 +28,7 @@ chat_config = LogConfig(
logger = get_module_logger("think_flow_chat", config=chat_config) logger = get_module_logger("think_flow_chat", config=chat_config)
class ThinkFlowChat: class ThinkFlowChat:
def __init__(self): def __init__(self):
self.storage = MessageStorage() self.storage = MessageStorage()
@@ -214,7 +215,6 @@ class ThinkFlowChat:
# 处理提及 # 处理提及
is_mentioned, reply_probability = is_mentioned_bot_in_message(message) is_mentioned, reply_probability = is_mentioned_bot_in_message(message)
# 计算回复意愿 # 计算回复意愿
current_willing_old = willing_manager.get_willing(chat_stream=chat) current_willing_old = willing_manager.get_willing(chat_stream=chat)
# current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4 # current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4
@@ -222,7 +222,6 @@ class ThinkFlowChat:
# 有点bug # 有点bug
current_willing = current_willing_old current_willing = current_willing_old
willing_manager.set_willing(chat.stream_id, current_willing) willing_manager.set_willing(chat.stream_id, current_willing)
# 意愿激活 # 意愿激活
@@ -280,7 +279,9 @@ class ThinkFlowChat:
# 思考前脑内状态 # 思考前脑内状态
try: try:
timer1 = time.time() timer1 = time.time()
await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(message.processed_plain_text) await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(
message.processed_plain_text
)
timer2 = time.time() timer2 = time.time()
timing_results["思考前脑内状态"] = timer2 - timer1 timing_results["思考前脑内状态"] = timer2 - timer1
except Exception as e: except Exception as e:

View File

@@ -35,7 +35,6 @@ class ResponseGenerator:
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]: async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数""" """根据当前模型类型选择对应的生成函数"""
logger.info( logger.info(
f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}" f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
) )
@@ -178,4 +177,3 @@ class ResponseGenerator:
# print(f"得到了处理后的llm返回{processed_response}") # print(f"得到了处理后的llm返回{processed_response}")
return processed_response return processed_response

View File

@@ -21,16 +21,15 @@ class PromptBuilder:
async def _build_prompt( async def _build_prompt(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]: ) -> tuple[str, str]:
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
individuality = Individuality.get_instance() individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1) prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1) prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1)
# 关系 # 关系
who_chat_in_group = [(chat_stream.user_info.platform, who_chat_in_group = [
chat_stream.user_info.user_id, (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
chat_stream.user_info.user_nickname)] ]
who_chat_in_group += get_recent_group_speaker( who_chat_in_group += get_recent_group_speaker(
stream_id, stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id), (chat_stream.user_info.platform, chat_stream.user_info.user_id),

View File

@@ -3,6 +3,7 @@ import tomlkit
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
def update_config(): def update_config():
print("开始更新配置文件...") print("开始更新配置文件...")
# 获取根目录路径 # 获取根目录路径

View File

@@ -39,6 +39,7 @@ else:
else: else:
mai_version = mai_version_main mai_version = mai_version_main
def update_config(): def update_config():
# 获取根目录路径 # 获取根目录路径
root_dir = Path(__file__).parent.parent.parent.parent root_dir = Path(__file__).parent.parent.parent.parent
@@ -127,6 +128,7 @@ def update_config():
f.write(tomlkit.dumps(new_config)) f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成") logger.info("配置文件更新完成")
logger = get_module_logger("config") logger = get_module_logger("config")
@@ -149,16 +151,20 @@ class BotConfig:
# personality # personality
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内谁再写3000字小作文敲谁脑袋 personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内谁再写3000字小作文敲谁脑袋
personality_sides: List[str] = field(default_factory=lambda: [ personality_sides: List[str] = field(
default_factory=lambda: [
"用一句话或几句话描述人格的一些侧面", "用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面", "用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面" "用一句话或几句话描述人格的一些侧面",
]) ]
)
# identity # identity
identity_detail: List[str] = field(default_factory=lambda: [ identity_detail: List[str] = field(
default_factory=lambda: [
"身份特点", "身份特点",
"身份特点", "身份特点",
]) ]
)
height: int = 170 # 身高 单位厘米 height: int = 170 # 身高 单位厘米
weight: int = 50 # 体重 单位千克 weight: int = 50 # 体重 单位千克
age: int = 20 # 年龄 单位岁 age: int = 20 # 年龄 单位岁
@@ -354,7 +360,6 @@ class BotConfig:
"""从TOML配置文件加载配置""" """从TOML配置文件加载配置"""
config = cls() config = cls()
def personality(parent: dict): def personality(parent: dict):
personality_config = parent["personality"] personality_config = parent["personality"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"): if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
@@ -421,10 +426,18 @@ class BotConfig:
def heartflow(parent: dict): def heartflow(parent: dict):
heartflow_config = parent["heartflow"] heartflow_config = parent["heartflow"]
config.sub_heart_flow_update_interval = heartflow_config.get("sub_heart_flow_update_interval", config.sub_heart_flow_update_interval) config.sub_heart_flow_update_interval = heartflow_config.get(
config.sub_heart_flow_freeze_time = heartflow_config.get("sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time) "sub_heart_flow_update_interval", config.sub_heart_flow_update_interval
config.sub_heart_flow_stop_time = heartflow_config.get("sub_heart_flow_stop_time", config.sub_heart_flow_stop_time) )
config.heart_flow_update_interval = heartflow_config.get("heart_flow_update_interval", config.heart_flow_update_interval) config.sub_heart_flow_freeze_time = heartflow_config.get(
"sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time
)
config.sub_heart_flow_stop_time = heartflow_config.get(
"sub_heart_flow_stop_time", config.sub_heart_flow_stop_time
)
config.heart_flow_update_interval = heartflow_config.get(
"heart_flow_update_interval", config.heart_flow_update_interval
)
def willing(parent: dict): def willing(parent: dict):
willing_config = parent["willing"] willing_config = parent["willing"]

View File

@@ -14,6 +14,7 @@ from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from .memory_config import MemoryConfig from .memory_config import MemoryConfig
def get_closest_chat_from_db(length: int, timestamp: str): def get_closest_chat_from_db(length: int, timestamp: str):
# print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}") # print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}")
# print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}") # print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")

View File

@@ -179,7 +179,6 @@ class LLM_request:
# logger.debug(f"{logger_msg}发送请求到URL: {api_url}") # logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
# logger.info(f"使用模型: {self.model_name}") # logger.info(f"使用模型: {self.model_name}")
# 构建请求体 # 构建请求体
if image_base64: if image_base64:
payload = await self._build_payload(prompt, image_base64, image_format) payload = await self._build_payload(prompt, image_base64, image_format)
@@ -205,13 +204,17 @@ class LLM_request:
# 处理需要重试的状态码 # 处理需要重试的状态码
if response.status in policy["retry_codes"]: if response.status in policy["retry_codes"]:
wait_time = policy["base_wait"] * (2**retry) wait_time = policy["base_wait"] * (2**retry)
logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试") logger.warning(
f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试"
)
if response.status == 413: if response.status == 413:
logger.warning("请求体过大,尝试压缩...") logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64) image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64, image_format) payload = await self._build_payload(prompt, image_base64, image_format)
elif response.status in [500, 503]: elif response.status in [500, 503]:
logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}") logger.error(
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
)
raise RuntimeError("服务器负载过高模型恢复失败QAQ") raise RuntimeError("服务器负载过高模型恢复失败QAQ")
else: else:
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...") logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
@@ -219,7 +222,9 @@ class LLM_request:
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
continue continue
elif response.status in policy["abort_codes"]: elif response.status in policy["abort_codes"]:
logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}") logger.error(
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
)
# 尝试获取并记录服务器返回的详细错误信息 # 尝试获取并记录服务器返回的详细错误信息
try: try:
error_json = await response.json() error_json = await response.json()
@@ -257,7 +262,9 @@ class LLM_request:
): ):
old_model_name = self.model_name old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀 self.model_name = self.model_name[4:] # 移除"Pro/"前缀
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}") logger.warning(
f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}"
)
# 对全局配置进行更新 # 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name: if global_config.llm_normal.get("name") == old_model_name:
@@ -266,7 +273,9 @@ class LLM_request:
if global_config.llm_reasoning.get("name") == old_model_name: if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name global_config.llm_reasoning["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}") logger.warning(
f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}"
)
# 更新payload中的模型名 # 更新payload中的模型名
if payload and "model" in payload: if payload and "model" in payload:
@@ -328,7 +337,14 @@ class LLM_request:
await response.release() await response.release()
# 返回已经累积的内容 # 返回已经累积的内容
result = { result = {
"choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}], "choices": [
{
"message": {
"content": accumulated_content,
"reasoning_content": reasoning_content,
}
}
],
"usage": usage, "usage": usage,
} }
return ( return (
@@ -345,7 +361,14 @@ class LLM_request:
logger.error(f"清理资源时发生错误: {cleanup_error}") logger.error(f"清理资源时发生错误: {cleanup_error}")
# 返回已经累积的内容 # 返回已经累积的内容
result = { result = {
"choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}], "choices": [
{
"message": {
"content": accumulated_content,
"reasoning_content": reasoning_content,
}
}
],
"usage": usage, "usage": usage,
} }
return ( return (
@@ -360,7 +383,9 @@ class LLM_request:
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip() content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
# 构造一个伪result以便调用自定义响应处理器或默认处理器 # 构造一个伪result以便调用自定义响应处理器或默认处理器
result = { result = {
"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}], "choices": [
{"message": {"content": content, "reasoning_content": reasoning_content}}
],
"usage": usage, "usage": usage,
} }
return ( return (
@@ -394,7 +419,9 @@ class LLM_request:
# 处理aiohttp抛出的响应错误 # 处理aiohttp抛出的响应错误
if retry < policy["max_retries"] - 1: if retry < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry) wait_time = policy["base_wait"] * (2**retry)
logger.error(f"模型 {self.model_name} HTTP响应错误等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}") logger.error(
f"模型 {self.model_name} HTTP响应错误等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}"
)
try: try:
if hasattr(e, "response") and e.response and hasattr(e.response, "text"): if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
error_text = await e.response.text() error_text = await e.response.text()
@@ -419,13 +446,17 @@ class LLM_request:
else: else:
logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}") logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
except (json.JSONDecodeError, TypeError) as json_err: except (json.JSONDecodeError, TypeError) as json_err:
logger.warning(f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}") logger.warning(
f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
)
except (AttributeError, TypeError, ValueError) as parse_err: except (AttributeError, TypeError, ValueError) as parse_err:
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}") logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
else: else:
logger.critical(f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}") logger.critical(
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}"
)
# 安全地检查和记录请求详情 # 安全地检查和记录请求详情
if ( if (
image_base64 image_base64

View File

@@ -282,5 +282,6 @@ class MoodManager:
self._update_mood_text() self._update_mood_text()
logger.info(f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}") logger.info(
f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}"
)

View File

@@ -8,7 +8,8 @@ import asyncio
import numpy as np import numpy as np
import matplotlib import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
@@ -41,9 +42,10 @@ person_info_default = {
# "gender" : Unkown, # "gender" : Unkown,
"konw_time": 0, "konw_time": 0,
"msg_interval": 3000, "msg_interval": 3000,
"msg_interval_list": [] "msg_interval_list": [],
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项 } # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
class PersonInfoManager: class PersonInfoManager:
def __init__(self): def __init__(self):
if "person_info" not in db.list_collection_names(): if "person_info" not in db.list_collection_names():
@@ -81,10 +83,7 @@ class PersonInfoManager:
document = db.person_info.find_one({"person_id": person_id}) document = db.person_info.find_one({"person_id": person_id})
if document: if document:
db.person_info.update_one( db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
{"person_id": person_id},
{"$set": {field_name: value}}
)
else: else:
Data[field_name] = value Data[field_name] = value
logger.debug(f"更新时{person_id}不存在,已新建") logger.debug(f"更新时{person_id}不存在,已新建")
@@ -112,10 +111,7 @@ class PersonInfoManager:
logger.debug(f"get_value获取失败字段'{field_name}'未定义") logger.debug(f"get_value获取失败字段'{field_name}'未定义")
return None return None
document = db.person_info.find_one( document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
{"person_id": person_id},
{field_name: 1}
)
if document and field_name in document: if document and field_name in document:
return document[field_name] return document[field_name]
@@ -139,16 +135,12 @@ class PersonInfoManager:
# 构建查询投影(所有字段都有效才会执行到这里) # 构建查询投影(所有字段都有效才会执行到这里)
projection = {field: 1 for field in field_names} projection = {field: 1 for field in field_names}
document = db.person_info.find_one( document = db.person_info.find_one({"person_id": person_id}, projection)
{"person_id": person_id},
projection
)
result = {} result = {}
for field in field_names: for field in field_names:
result[field] = copy.deepcopy( result[field] = copy.deepcopy(
document.get(field, person_info_default[field]) document.get(field, person_info_default[field]) if document else person_info_default[field]
if document else person_info_default[field]
) )
return result return result
@@ -162,13 +154,12 @@ class PersonInfoManager:
# 遍历集合中的所有文档 # 遍历集合中的所有文档
for document in db.person_info.find({}): for document in db.person_info.find({}):
# 找出文档中未定义的字段 # 找出文档中未定义的字段
undefined_fields = set(document.keys()) - defined_fields - {'_id'} undefined_fields = set(document.keys()) - defined_fields - {"_id"}
if undefined_fields: if undefined_fields:
# 构建更新操作,使用$unset删除未定义字段 # 构建更新操作,使用$unset删除未定义字段
update_result = db.person_info.update_one( update_result = db.person_info.update_one(
{'_id': document['_id']}, {"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
{'$unset': {field: 1 for field in undefined_fields}}
) )
if update_result.modified_count > 0: if update_result.modified_count > 0:
@@ -208,10 +199,7 @@ class PersonInfoManager:
try: try:
result = {} result = {}
for doc in db.person_info.find( for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
{field_name: {"$exists": True}},
{"person_id": 1, field_name: 1, "_id": 0}
):
try: try:
value = doc[field_name] value = doc[field_name]
if way(value): if way(value):
@@ -229,7 +217,7 @@ class PersonInfoManager:
async def personal_habit_deduction(self): async def personal_habit_deduction(self):
"""启动个人信息推断,每天根据一定条件推断一次""" """启动个人信息推断,每天根据一定条件推断一次"""
try: try:
while(1): while 1:
await asyncio.sleep(60) await asyncio.sleep(60)
current_time = datetime.datetime.now() current_time = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
@@ -237,8 +225,7 @@ class PersonInfoManager:
# "msg_interval"推断 # "msg_interval"推断
msg_interval_map = False msg_interval_map = False
msg_interval_lists = await self.get_specific_value_list( msg_interval_lists = await self.get_specific_value_list(
"msg_interval_list", "msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
lambda x: isinstance(x, list) and len(x) >= 100
) )
for person_id, msg_interval_list_ in msg_interval_lists.items(): for person_id, msg_interval_list_ in msg_interval_lists.items():
try: try:
@@ -258,14 +245,14 @@ class PersonInfoManager:
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
time_series = pd.Series(time_interval) time_series = pd.Series(time_interval)
plt.hist(time_series, bins=50, density=True, alpha=0.4, color='pink', label='Histogram') plt.hist(time_series, bins=50, density=True, alpha=0.4, color="pink", label="Histogram")
time_series.plot(kind='kde', color='mediumpurple', linewidth=1, label='Density') time_series.plot(kind="kde", color="mediumpurple", linewidth=1, label="Density")
plt.grid(True, alpha=0.2) plt.grid(True, alpha=0.2)
plt.xlim(0, 8000) plt.xlim(0, 8000)
plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)") plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)")
plt.xlabel("Interval (ms)") plt.xlabel("Interval (ms)")
plt.ylabel("Density") plt.ylabel("Density")
plt.legend(framealpha=0.9, facecolor='white') plt.legend(framealpha=0.9, facecolor="white")
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png" img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path) plt.savefig(img_path)
plt.close() plt.close()

View File

@@ -12,6 +12,7 @@ relationship_config = LogConfig(
) )
logger = get_module_logger("rel_manager", config=relationship_config) logger = get_module_logger("rel_manager", config=relationship_config)
class RelationshipManager: class RelationshipManager:
def __init__(self): def __init__(self):
self.positive_feedback_value = 0 # 正反馈系统 self.positive_feedback_value = 0 # 正反馈系统
@@ -22,6 +23,7 @@ class RelationshipManager:
def mood_manager(self): def mood_manager(self):
if self._mood_manager is None: if self._mood_manager is None:
from ..moods.moods import MoodManager # 延迟导入 from ..moods.moods import MoodManager # 延迟导入
self._mood_manager = MoodManager.get_instance() self._mood_manager = MoodManager.get_instance()
return self._mood_manager return self._mood_manager
@@ -58,8 +60,9 @@ class RelationshipManager:
def mood_feedback(self, value): def mood_feedback(self, value):
"""情绪反馈""" """情绪反馈"""
mood_manager = self.mood_manager mood_manager = self.mood_manager
mood_gain = (mood_manager.get_current_mood().valence) ** 2 \ mood_gain = (mood_manager.get_current_mood().valence) ** 2 * math.copysign(
* math.copysign(1, value * mood_manager.get_current_mood().valence) 1, value * mood_manager.get_current_mood().valence
)
value += value * mood_gain value += value * mood_gain
logger.info(f"当前relationship增益系数{mood_gain:.3f}") logger.info(f"当前relationship增益系数{mood_gain:.3f}")
return value return value
@@ -67,8 +70,7 @@ class RelationshipManager:
def feedback_to_mood(self, mood_value): def feedback_to_mood(self, mood_value):
"""对情绪的反馈""" """对情绪的反馈"""
coefficient = self.gain_coefficient[abs(self.positive_feedback_value)] coefficient = self.gain_coefficient[abs(self.positive_feedback_value)]
if (mood_value > 0 and self.positive_feedback_value > 0 if mood_value > 0 and self.positive_feedback_value > 0 or mood_value < 0 and self.positive_feedback_value < 0:
or mood_value < 0 and self.positive_feedback_value < 0):
return mood_value * coefficient return mood_value * coefficient
else: else:
return mood_value / coefficient return mood_value / coefficient
@@ -106,7 +108,7 @@ class RelationshipManager:
"platform": chat_stream.user_info.platform, "platform": chat_stream.user_info.platform,
"user_id": chat_stream.user_info.user_id, "user_id": chat_stream.user_info.user_id,
"nickname": chat_stream.user_info.user_nickname, "nickname": chat_stream.user_info.user_nickname,
"konw_time" : int(time.time()) "konw_time": int(time.time()),
} }
old_value = await person_info_manager.get_value(person_id, "relationship_value") old_value = await person_info_manager.get_value(person_id, "relationship_value")
old_value = self.ensure_float(old_value, person_id) old_value = self.ensure_float(old_value, person_id)
@@ -200,4 +202,5 @@ class RelationshipManager:
logger.warning(f"[关系管理] {person_id}值转换失败(原始值:{value}已重置为0") logger.warning(f"[关系管理] {person_id}值转换失败(原始值:{value}已重置为0")
return 0.0 return 0.0
relationship_manager = RelationshipManager() relationship_manager = RelationshipManager()

View File

@@ -31,10 +31,16 @@ class ScheduleGenerator:
def __init__(self): def __init__(self):
# 使用离线LLM模型 # 使用离线LLM模型
self.llm_scheduler_all = LLM_request( self.llm_scheduler_all = LLM_request(
model=global_config.llm_reasoning, temperature=global_config.SCHEDULE_TEMPERATURE, max_tokens=7000, request_type="schedule" model=global_config.llm_reasoning,
temperature=global_config.SCHEDULE_TEMPERATURE,
max_tokens=7000,
request_type="schedule",
) )
self.llm_scheduler_doing = LLM_request( self.llm_scheduler_doing = LLM_request(
model=global_config.llm_normal, temperature=global_config.SCHEDULE_TEMPERATURE, max_tokens=2048, request_type="schedule" model=global_config.llm_normal,
temperature=global_config.SCHEDULE_TEMPERATURE,
max_tokens=2048,
request_type="schedule",
) )
self.today_schedule_text = "" self.today_schedule_text = ""

View File

@@ -32,7 +32,7 @@ SECTION_TRANSLATIONS = {
"response_spliter": "回复分割器", "response_spliter": "回复分割器",
"remote": "远程设置", "remote": "远程设置",
"experimental": "实验功能", "experimental": "实验功能",
"model": "模型设置" "model": "模型设置",
} }
# 配置项的中文描述 # 配置项的中文描述
@@ -41,16 +41,13 @@ CONFIG_DESCRIPTIONS = {
"bot.qq": "机器人的QQ号码", "bot.qq": "机器人的QQ号码",
"bot.nickname": "机器人的昵称", "bot.nickname": "机器人的昵称",
"bot.alias_names": "机器人的别名列表", "bot.alias_names": "机器人的别名列表",
# 群组设置 # 群组设置
"groups.talk_allowed": "允许机器人回复消息的群号列表", "groups.talk_allowed": "允许机器人回复消息的群号列表",
"groups.talk_frequency_down": "降低回复频率的群号列表", "groups.talk_frequency_down": "降低回复频率的群号列表",
"groups.ban_user_id": "禁止回复和读取消息的QQ号列表", "groups.ban_user_id": "禁止回复和读取消息的QQ号列表",
# 人格设置 # 人格设置
"personality.personality_core": "人格核心描述建议20字以内", "personality.personality_core": "人格核心描述建议20字以内",
"personality.personality_sides": "人格特点列表", "personality.personality_sides": "人格特点列表",
# 身份设置 # 身份设置
"identity.identity_detail": "身份细节描述列表", "identity.identity_detail": "身份细节描述列表",
"identity.height": "身高(厘米)", "identity.height": "身高(厘米)",
@@ -58,28 +55,23 @@ CONFIG_DESCRIPTIONS = {
"identity.age": "年龄", "identity.age": "年龄",
"identity.gender": "性别", "identity.gender": "性别",
"identity.appearance": "外貌特征", "identity.appearance": "外貌特征",
# 日程设置 # 日程设置
"schedule.enable_schedule_gen": "是否启用日程表生成", "schedule.enable_schedule_gen": "是否启用日程表生成",
"schedule.prompt_schedule_gen": "日程表生成提示词", "schedule.prompt_schedule_gen": "日程表生成提示词",
"schedule.schedule_doing_update_interval": "日程表更新间隔(秒)", "schedule.schedule_doing_update_interval": "日程表更新间隔(秒)",
"schedule.schedule_temperature": "日程表温度建议0.3-0.6", "schedule.schedule_temperature": "日程表温度建议0.3-0.6",
"schedule.time_zone": "时区设置", "schedule.time_zone": "时区设置",
# 平台设置 # 平台设置
"platforms.nonebot-qq": "QQ平台适配器链接", "platforms.nonebot-qq": "QQ平台适配器链接",
# 回复设置 # 回复设置
"response.response_mode": "回复策略heart_flow心流reasoning推理", "response.response_mode": "回复策略heart_flow心流reasoning推理",
"response.model_r1_probability": "主要回复模型使用概率", "response.model_r1_probability": "主要回复模型使用概率",
"response.model_v3_probability": "次要回复模型使用概率", "response.model_v3_probability": "次要回复模型使用概率",
# 心流设置 # 心流设置
"heartflow.sub_heart_flow_update_interval": "子心流更新频率(秒)", "heartflow.sub_heart_flow_update_interval": "子心流更新频率(秒)",
"heartflow.sub_heart_flow_freeze_time": "子心流冻结时间(秒)", "heartflow.sub_heart_flow_freeze_time": "子心流冻结时间(秒)",
"heartflow.sub_heart_flow_stop_time": "子心流停止时间(秒)", "heartflow.sub_heart_flow_stop_time": "子心流停止时间(秒)",
"heartflow.heart_flow_update_interval": "心流更新频率(秒)", "heartflow.heart_flow_update_interval": "心流更新频率(秒)",
# 消息设置 # 消息设置
"message.max_context_size": "获取的上下文数量", "message.max_context_size": "获取的上下文数量",
"message.emoji_chance": "使用表情包的概率", "message.emoji_chance": "使用表情包的概率",
@@ -88,14 +80,12 @@ CONFIG_DESCRIPTIONS = {
"message.message_buffer": "是否启用消息缓冲器", "message.message_buffer": "是否启用消息缓冲器",
"message.ban_words": "禁用词列表", "message.ban_words": "禁用词列表",
"message.ban_msgs_regex": "禁用消息正则表达式列表", "message.ban_msgs_regex": "禁用消息正则表达式列表",
# 意愿设置 # 意愿设置
"willing.willing_mode": "回复意愿模式", "willing.willing_mode": "回复意愿模式",
"willing.response_willing_amplifier": "回复意愿放大系数", "willing.response_willing_amplifier": "回复意愿放大系数",
"willing.response_interested_rate_amplifier": "回复兴趣度放大系数", "willing.response_interested_rate_amplifier": "回复兴趣度放大系数",
"willing.down_frequency_rate": "降低回复频率的群组回复意愿降低系数", "willing.down_frequency_rate": "降低回复频率的群组回复意愿降低系数",
"willing.emoji_response_penalty": "表情包回复惩罚系数", "willing.emoji_response_penalty": "表情包回复惩罚系数",
# 表情设置 # 表情设置
"emoji.max_emoji_num": "表情包最大数量", "emoji.max_emoji_num": "表情包最大数量",
"emoji.max_reach_deletion": "达到最大数量时是否删除表情包", "emoji.max_reach_deletion": "达到最大数量时是否删除表情包",
@@ -103,7 +93,6 @@ CONFIG_DESCRIPTIONS = {
"emoji.auto_save": "是否保存表情包和图片", "emoji.auto_save": "是否保存表情包和图片",
"emoji.enable_check": "是否启用表情包过滤", "emoji.enable_check": "是否启用表情包过滤",
"emoji.check_prompt": "表情包过滤要求", "emoji.check_prompt": "表情包过滤要求",
# 记忆设置 # 记忆设置
"memory.build_memory_interval": "记忆构建间隔(秒)", "memory.build_memory_interval": "记忆构建间隔(秒)",
"memory.build_memory_distribution": "记忆构建分布参数", "memory.build_memory_distribution": "记忆构建分布参数",
@@ -114,104 +103,90 @@ CONFIG_DESCRIPTIONS = {
"memory.memory_forget_time": "记忆遗忘时间(小时)", "memory.memory_forget_time": "记忆遗忘时间(小时)",
"memory.memory_forget_percentage": "记忆遗忘比例", "memory.memory_forget_percentage": "记忆遗忘比例",
"memory.memory_ban_words": "记忆禁用词列表", "memory.memory_ban_words": "记忆禁用词列表",
# 情绪设置 # 情绪设置
"mood.mood_update_interval": "情绪更新间隔(秒)", "mood.mood_update_interval": "情绪更新间隔(秒)",
"mood.mood_decay_rate": "情绪衰减率", "mood.mood_decay_rate": "情绪衰减率",
"mood.mood_intensity_factor": "情绪强度因子", "mood.mood_intensity_factor": "情绪强度因子",
# 关键词反应 # 关键词反应
"keywords_reaction.enable": "是否启用关键词反应功能", "keywords_reaction.enable": "是否启用关键词反应功能",
# 中文错别字 # 中文错别字
"chinese_typo.enable": "是否启用中文错别字生成器", "chinese_typo.enable": "是否启用中文错别字生成器",
"chinese_typo.error_rate": "单字替换概率", "chinese_typo.error_rate": "单字替换概率",
"chinese_typo.min_freq": "最小字频阈值", "chinese_typo.min_freq": "最小字频阈值",
"chinese_typo.tone_error_rate": "声调错误概率", "chinese_typo.tone_error_rate": "声调错误概率",
"chinese_typo.word_replace_rate": "整词替换概率", "chinese_typo.word_replace_rate": "整词替换概率",
# 回复分割器 # 回复分割器
"response_spliter.enable_response_spliter": "是否启用回复分割器", "response_spliter.enable_response_spliter": "是否启用回复分割器",
"response_spliter.response_max_length": "回复允许的最大长度", "response_spliter.response_max_length": "回复允许的最大长度",
"response_spliter.response_max_sentence_num": "回复允许的最大句子数", "response_spliter.response_max_sentence_num": "回复允许的最大句子数",
# 远程设置 # 远程设置
"remote.enable": "是否启用远程统计", "remote.enable": "是否启用远程统计",
# 实验功能 # 实验功能
"experimental.enable_friend_chat": "是否启用好友聊天", "experimental.enable_friend_chat": "是否启用好友聊天",
"experimental.pfc_chatting": "是否启用PFC聊天", "experimental.pfc_chatting": "是否启用PFC聊天",
# 模型设置 # 模型设置
"model.llm_reasoning.name": "推理模型名称", "model.llm_reasoning.name": "推理模型名称",
"model.llm_reasoning.provider": "推理模型提供商", "model.llm_reasoning.provider": "推理模型提供商",
"model.llm_reasoning.pri_in": "推理模型输入价格", "model.llm_reasoning.pri_in": "推理模型输入价格",
"model.llm_reasoning.pri_out": "推理模型输出价格", "model.llm_reasoning.pri_out": "推理模型输出价格",
"model.llm_normal.name": "回复模型名称", "model.llm_normal.name": "回复模型名称",
"model.llm_normal.provider": "回复模型提供商", "model.llm_normal.provider": "回复模型提供商",
"model.llm_normal.pri_in": "回复模型输入价格", "model.llm_normal.pri_in": "回复模型输入价格",
"model.llm_normal.pri_out": "回复模型输出价格", "model.llm_normal.pri_out": "回复模型输出价格",
"model.llm_emotion_judge.name": "表情判断模型名称", "model.llm_emotion_judge.name": "表情判断模型名称",
"model.llm_emotion_judge.provider": "表情判断模型提供商", "model.llm_emotion_judge.provider": "表情判断模型提供商",
"model.llm_emotion_judge.pri_in": "表情判断模型输入价格", "model.llm_emotion_judge.pri_in": "表情判断模型输入价格",
"model.llm_emotion_judge.pri_out": "表情判断模型输出价格", "model.llm_emotion_judge.pri_out": "表情判断模型输出价格",
"model.llm_topic_judge.name": "主题判断模型名称", "model.llm_topic_judge.name": "主题判断模型名称",
"model.llm_topic_judge.provider": "主题判断模型提供商", "model.llm_topic_judge.provider": "主题判断模型提供商",
"model.llm_topic_judge.pri_in": "主题判断模型输入价格", "model.llm_topic_judge.pri_in": "主题判断模型输入价格",
"model.llm_topic_judge.pri_out": "主题判断模型输出价格", "model.llm_topic_judge.pri_out": "主题判断模型输出价格",
"model.llm_summary_by_topic.name": "概括模型名称", "model.llm_summary_by_topic.name": "概括模型名称",
"model.llm_summary_by_topic.provider": "概括模型提供商", "model.llm_summary_by_topic.provider": "概括模型提供商",
"model.llm_summary_by_topic.pri_in": "概括模型输入价格", "model.llm_summary_by_topic.pri_in": "概括模型输入价格",
"model.llm_summary_by_topic.pri_out": "概括模型输出价格", "model.llm_summary_by_topic.pri_out": "概括模型输出价格",
"model.moderation.name": "内容审核模型名称", "model.moderation.name": "内容审核模型名称",
"model.moderation.provider": "内容审核模型提供商", "model.moderation.provider": "内容审核模型提供商",
"model.moderation.pri_in": "内容审核模型输入价格", "model.moderation.pri_in": "内容审核模型输入价格",
"model.moderation.pri_out": "内容审核模型输出价格", "model.moderation.pri_out": "内容审核模型输出价格",
"model.vlm.name": "图像识别模型名称", "model.vlm.name": "图像识别模型名称",
"model.vlm.provider": "图像识别模型提供商", "model.vlm.provider": "图像识别模型提供商",
"model.vlm.pri_in": "图像识别模型输入价格", "model.vlm.pri_in": "图像识别模型输入价格",
"model.vlm.pri_out": "图像识别模型输出价格", "model.vlm.pri_out": "图像识别模型输出价格",
"model.embedding.name": "嵌入模型名称", "model.embedding.name": "嵌入模型名称",
"model.embedding.provider": "嵌入模型提供商", "model.embedding.provider": "嵌入模型提供商",
"model.embedding.pri_in": "嵌入模型输入价格", "model.embedding.pri_in": "嵌入模型输入价格",
"model.embedding.pri_out": "嵌入模型输出价格", "model.embedding.pri_out": "嵌入模型输出价格",
"model.llm_observation.name": "观察模型名称", "model.llm_observation.name": "观察模型名称",
"model.llm_observation.provider": "观察模型提供商", "model.llm_observation.provider": "观察模型提供商",
"model.llm_observation.pri_in": "观察模型输入价格", "model.llm_observation.pri_in": "观察模型输入价格",
"model.llm_observation.pri_out": "观察模型输出价格", "model.llm_observation.pri_out": "观察模型输出价格",
"model.llm_sub_heartflow.name": "子心流模型名称", "model.llm_sub_heartflow.name": "子心流模型名称",
"model.llm_sub_heartflow.provider": "子心流模型提供商", "model.llm_sub_heartflow.provider": "子心流模型提供商",
"model.llm_sub_heartflow.pri_in": "子心流模型输入价格", "model.llm_sub_heartflow.pri_in": "子心流模型输入价格",
"model.llm_sub_heartflow.pri_out": "子心流模型输出价格", "model.llm_sub_heartflow.pri_out": "子心流模型输出价格",
"model.llm_heartflow.name": "心流模型名称", "model.llm_heartflow.name": "心流模型名称",
"model.llm_heartflow.provider": "心流模型提供商", "model.llm_heartflow.provider": "心流模型提供商",
"model.llm_heartflow.pri_in": "心流模型输入价格", "model.llm_heartflow.pri_in": "心流模型输入价格",
"model.llm_heartflow.pri_out": "心流模型输出价格", "model.llm_heartflow.pri_out": "心流模型输出价格",
} }
# 获取翻译 # 获取翻译
def get_translation(key): def get_translation(key):
return SECTION_TRANSLATIONS.get(key, key) return SECTION_TRANSLATIONS.get(key, key)
# 获取配置项描述 # 获取配置项描述
def get_description(key): def get_description(key):
return CONFIG_DESCRIPTIONS.get(key, "") return CONFIG_DESCRIPTIONS.get(key, "")
# 获取根目录路径 # 获取根目录路径
def get_root_dir(): def get_root_dir():
try: try:
# 获取当前脚本所在目录 # 获取当前脚本所在目录
if getattr(sys, 'frozen', False): if getattr(sys, "frozen", False):
# 如果是打包后的应用 # 如果是打包后的应用
current_dir = os.path.dirname(sys.executable) current_dir = os.path.dirname(sys.executable)
else: else:
@@ -235,9 +210,11 @@ def get_root_dir():
# 返回当前目录作为备选 # 返回当前目录作为备选
return os.getcwd() return os.getcwd()
# 配置文件路径 # 配置文件路径
CONFIG_PATH = os.path.join(get_root_dir(), "config", "bot_config.toml") CONFIG_PATH = os.path.join(get_root_dir(), "config", "bot_config.toml")
# 保存配置 # 保存配置
def save_config(config_data): def save_config(config_data):
try: try:
@@ -266,6 +243,7 @@ def save_config(config_data):
print(f"保存配置失败: {e}") print(f"保存配置失败: {e}")
return False return False
# 加载配置 # 加载配置
def load_config(): def load_config():
try: try:
@@ -279,6 +257,7 @@ def load_config():
print(f"加载配置失败: {e}") print(f"加载配置失败: {e}")
return {} return {}
# 多行文本输入框 # 多行文本输入框
class ScrollableTextFrame(ctk.CTkFrame): class ScrollableTextFrame(ctk.CTkFrame):
def __init__(self, master, initial_text="", height=100, width=400, **kwargs): def __init__(self, master, initial_text="", height=100, width=400, **kwargs):
@@ -305,6 +284,7 @@ class ScrollableTextFrame(ctk.CTkFrame):
self.text_box.insert("1.0", text) self.text_box.insert("1.0", text)
self.update_var() self.update_var()
# 配置UI # 配置UI
class ConfigUI(ctk.CTk): class ConfigUI(ctk.CTk):
def __init__(self): def __init__(self):
@@ -430,7 +410,7 @@ class ConfigUI(ctk.CTk):
width=30, width=30,
command=self.show_search_dialog, command=self.show_search_dialog,
fg_color="transparent", fg_color="transparent",
hover_color=("gray80", "gray30") hover_color=("gray80", "gray30"),
) )
search_btn.pack(side="right", padx=5, pady=5) search_btn.pack(side="right", padx=5, pady=5)
@@ -457,7 +437,7 @@ class ConfigUI(ctk.CTk):
text_color=("gray10", "gray90"), text_color=("gray10", "gray90"),
anchor="w", anchor="w",
height=35, height=35,
command=lambda s=section: self.show_category(s) command=lambda s=section: self.show_category(s),
) )
btn.pack(fill="x", padx=5, pady=2) btn.pack(fill="x", padx=5, pady=2)
self.category_buttons[section] = btn self.category_buttons[section] = btn
@@ -484,18 +464,12 @@ class ConfigUI(ctk.CTk):
category_name = f"{category} ({get_translation(category)})" category_name = f"{category} ({get_translation(category)})"
# 添加标题 # 添加标题
ctk.CTkLabel( ctk.CTkLabel(self.content_frame, text=f"{category_name} 配置", font=("Arial", 16, "bold")).pack(
self.content_frame, anchor="w", padx=10, pady=(5, 15)
text=f"{category_name} 配置", )
font=("Arial", 16, "bold")
).pack(anchor="w", padx=10, pady=(5, 15))
# 添加配置项 # 添加配置项
self.add_config_section( self.add_config_section(self.content_frame, category, self.config_data[category])
self.content_frame,
category,
self.config_data[category]
)
def add_config_section(self, parent, section_path, section_data, indent=0): def add_config_section(self, parent, section_path, section_data, indent=0):
# 递归添加配置项 # 递归添加配置项
@@ -514,12 +488,7 @@ class ConfigUI(ctk.CTk):
header_frame = ctk.CTkFrame(group_frame, fg_color=("gray85", "gray25")) header_frame = ctk.CTkFrame(group_frame, fg_color=("gray85", "gray25"))
header_frame.pack(fill="x", padx=0, pady=0) header_frame.pack(fill="x", padx=0, pady=0)
label = ctk.CTkLabel( label = ctk.CTkLabel(header_frame, text=f"{key}", font=("Arial", 13, "bold"), anchor="w")
header_frame,
text=f"{key}",
font=("Arial", 13, "bold"),
anchor="w"
)
label.pack(anchor="w", padx=10, pady=5) label.pack(anchor="w", padx=10, pady=5)
# 如果有描述,添加提示图标 # 如果有描述,添加提示图标
@@ -536,12 +505,7 @@ class ConfigUI(ctk.CTk):
tipwindow.wm_geometry(f"+{x}+{y}") tipwindow.wm_geometry(f"+{x}+{y}")
tipwindow.lift() tipwindow.lift()
label = ctk.CTkLabel( label = ctk.CTkLabel(tipwindow, text=text, justify="left", wraplength=300)
tipwindow,
text=text,
justify="left",
wraplength=300
)
label.pack(padx=5, pady=5) label.pack(padx=5, pady=5)
# 自动关闭 # 自动关闭
@@ -553,11 +517,7 @@ class ConfigUI(ctk.CTk):
# 在标题后添加提示图标 # 在标题后添加提示图标
tip_label = ctk.CTkLabel( tip_label = ctk.CTkLabel(
header_frame, header_frame, text="", font=("Arial", 12), text_color="light blue", width=20
text="",
font=("Arial", 12),
text_color="light blue",
width=20
) )
tip_label.pack(side="right", padx=5) tip_label.pack(side="right", padx=5)
@@ -584,21 +544,11 @@ class ConfigUI(ctk.CTk):
if description: if description:
label_text = f"{key}: ({description})" label_text = f"{key}: ({description})"
label = ctk.CTkLabel( label = ctk.CTkLabel(label_frame, text=label_text, font=("Arial", 12), anchor="w")
label_frame,
text=label_text,
font=("Arial", 12),
anchor="w"
)
label.pack(anchor="w", padx=5 + indent * 10, pady=0) label.pack(anchor="w", padx=5 + indent * 10, pady=0)
# 添加提示信息 # 添加提示信息
info_label = ctk.CTkLabel( info_label = ctk.CTkLabel(label_frame, text="(列表格式: JSON)", font=("Arial", 9), text_color="gray50")
label_frame,
text="(列表格式: JSON)",
font=("Arial", 9),
text_color="gray50"
)
info_label.pack(anchor="w", padx=5 + indent * 10, pady=(0, 5)) info_label.pack(anchor="w", padx=5 + indent * 10, pady=(0, 5))
# 确定文本框高度,根据列表项数量决定 # 确定文本框高度,根据列表项数量决定
@@ -608,12 +558,7 @@ class ConfigUI(ctk.CTk):
json_str = json.dumps(value, ensure_ascii=False, indent=2) json_str = json.dumps(value, ensure_ascii=False, indent=2)
# 使用多行文本框 # 使用多行文本框
text_frame = ScrollableTextFrame( text_frame = ScrollableTextFrame(frame, initial_text=json_str, height=list_height, width=550)
frame,
initial_text=json_str,
height=list_height,
width=550
)
text_frame.pack(fill="x", padx=10 + indent * 10, pady=5) text_frame.pack(fill="x", padx=10 + indent * 10, pady=5)
self.config_vars[full_path] = (text_frame.text_var, "list") self.config_vars[full_path] = (text_frame.text_var, "list")
@@ -635,10 +580,7 @@ class ConfigUI(ctk.CTk):
checkbox_text = f"{key} ({description})" checkbox_text = f"{key} ({description})"
checkbox = ctk.CTkCheckBox( checkbox = ctk.CTkCheckBox(
frame, frame, text=checkbox_text, variable=var, command=lambda path=full_path: self.on_field_change(path)
text=checkbox_text,
variable=var,
command=lambda path=full_path: self.on_field_change(path)
) )
checkbox.pack(anchor="w", padx=10 + indent * 10, pady=5) checkbox.pack(anchor="w", padx=10 + indent * 10, pady=5)
@@ -652,12 +594,7 @@ class ConfigUI(ctk.CTk):
if description: if description:
label_text = f"{key}: ({description})" label_text = f"{key}: ({description})"
label = ctk.CTkLabel( label = ctk.CTkLabel(frame, text=label_text, font=("Arial", 12), anchor="w")
frame,
text=label_text,
font=("Arial", 12),
anchor="w"
)
label.pack(anchor="w", padx=10 + indent * 10, pady=(5, 0)) label.pack(anchor="w", padx=10 + indent * 10, pady=(5, 0))
var = StringVar(value=str(value)) var = StringVar(value=str(value))
@@ -682,12 +619,7 @@ class ConfigUI(ctk.CTk):
if description: if description:
label_text = f"{key}: ({description})" label_text = f"{key}: ({description})"
label = ctk.CTkLabel( label = ctk.CTkLabel(frame, text=label_text, font=("Arial", 12), anchor="w")
frame,
text=label_text,
font=("Arial", 12),
anchor="w"
)
label.pack(anchor="w", padx=10 + indent * 10, pady=(5, 0)) label.pack(anchor="w", padx=10 + indent * 10, pady=(5, 0))
var = StringVar(value=str(value)) var = StringVar(value=str(value))
@@ -696,16 +628,11 @@ class ConfigUI(ctk.CTk):
# 判断文本长度,决定输入框的类型和大小 # 判断文本长度,决定输入框的类型和大小
text_len = len(str(value)) text_len = len(str(value))
if text_len > 80 or '\n' in str(value): if text_len > 80 or "\n" in str(value):
# 对于长文本或多行文本,使用多行文本框 # 对于长文本或多行文本,使用多行文本框
text_height = max(80, min(str(value).count('\n') * 20 + 40, 150)) text_height = max(80, min(str(value).count("\n") * 20 + 40, 150))
text_frame = ScrollableTextFrame( text_frame = ScrollableTextFrame(frame, initial_text=str(value), height=text_height, width=550)
frame,
initial_text=str(value),
height=text_height,
width=550
)
text_frame.pack(fill="x", padx=10 + indent * 10, pady=5) text_frame.pack(fill="x", padx=10 + indent * 10, pady=5)
self.config_vars[full_path] = (text_frame.text_var, "string") self.config_vars[full_path] = (text_frame.text_var, "string")
@@ -751,7 +678,6 @@ class ConfigUI(ctk.CTk):
target[parts[-1]] = var.get() target[parts[-1]] = var.get()
updated = True updated = True
elif var_type == "number": elif var_type == "number":
# 获取原始类型int或float # 获取原始类型int或float
num_type = args[0] if args else int num_type = args[0] if args else int
new_value = num_type(var.get()) new_value = num_type(var.get())
@@ -760,7 +686,6 @@ class ConfigUI(ctk.CTk):
updated = True updated = True
elif var_type == "list": elif var_type == "list":
# 解析JSON字符串为列表 # 解析JSON字符串为列表
new_value = json.loads(var.get()) new_value = json.loads(var.get())
if json.dumps(target[parts[-1]], sort_keys=True) != json.dumps(new_value, sort_keys=True): if json.dumps(target[parts[-1]], sort_keys=True) != json.dumps(new_value, sort_keys=True):
@@ -841,11 +766,7 @@ class ConfigUI(ctk.CTk):
current_config = json.dumps(temp_config, sort_keys=True) current_config = json.dumps(temp_config, sort_keys=True)
if current_config != self.original_config: if current_config != self.original_config:
result = messagebox.askyesnocancel( result = messagebox.askyesnocancel("未保存的更改", "有未保存的更改,是否保存?", icon="warning")
"未保存的更改",
"有未保存的更改,是否保存?",
icon="warning"
)
if result is None: # 取消 if result is None: # 取消
return False return False
@@ -868,29 +789,17 @@ class ConfigUI(ctk.CTk):
about_window.geometry(f"+{x}+{y}") about_window.geometry(f"+{x}+{y}")
# 内容 # 内容
ctk.CTkLabel( ctk.CTkLabel(about_window, text="麦麦配置修改器", font=("Arial", 16, "bold")).pack(pady=(20, 10))
about_window,
text="麦麦配置修改器",
font=("Arial", 16, "bold")
).pack(pady=(20, 10))
ctk.CTkLabel( ctk.CTkLabel(about_window, text="用于修改MaiBot-Core的配置文件\n配置文件路径: config/bot_config.toml").pack(
about_window, pady=5
text="用于修改MaiBot-Core的配置文件\n配置文件路径: config/bot_config.toml" )
).pack(pady=5)
ctk.CTkLabel( ctk.CTkLabel(about_window, text="注意: 修改配置前请备份原始配置文件", text_color=("red", "light coral")).pack(
about_window, pady=5
text="注意: 修改配置前请备份原始配置文件", )
text_color=("red", "light coral")
).pack(pady=5)
ctk.CTkButton( ctk.CTkButton(about_window, text="确定", command=about_window.destroy, width=100).pack(pady=15)
about_window,
text="确定",
command=about_window.destroy,
width=100
).pack(pady=15)
def on_closing(self): def on_closing(self):
"""关闭窗口前检查未保存更改""" """关闭窗口前检查未保存更改"""
@@ -961,11 +870,9 @@ class ConfigUI(ctk.CTk):
backup_window.geometry(f"+{x}+{y}") backup_window.geometry(f"+{x}+{y}")
# 创建说明标签 # 创建说明标签
ctk.CTkLabel( ctk.CTkLabel(backup_window, text="备份文件列表 (双击可恢复)", font=("Arial", 14, "bold")).pack(
backup_window, pady=(10, 5), padx=10, anchor="w"
text="备份文件列表 (双击可恢复)", )
font=("Arial", 14, "bold")
).pack(pady=(10, 5), padx=10, anchor="w")
# 创建列表框 # 创建列表框
backup_frame = ctk.CTkScrollableFrame(backup_window, width=580, height=300) backup_frame = ctk.CTkScrollableFrame(backup_window, width=580, height=300)
@@ -981,27 +888,17 @@ class ConfigUI(ctk.CTk):
item_frame.pack(fill="x", padx=5, pady=5) item_frame.pack(fill="x", padx=5, pady=5)
# 显示备份文件信息 # 显示备份文件信息
ctk.CTkLabel( ctk.CTkLabel(item_frame, text=f"{time_str}", font=("Arial", 12, "bold"), width=200).pack(
item_frame, side="left", padx=10, pady=10
text=f"{time_str}", )
font=("Arial", 12, "bold"),
width=200
).pack(side="left", padx=10, pady=10)
# 文件名 # 文件名
name_label = ctk.CTkLabel( name_label = ctk.CTkLabel(item_frame, text=filename, font=("Arial", 11))
item_frame,
text=filename,
font=("Arial", 11)
)
name_label.pack(side="left", fill="x", expand=True, padx=5, pady=10) name_label.pack(side="left", fill="x", expand=True, padx=5, pady=10)
# 恢复按钮 # 恢复按钮
restore_btn = ctk.CTkButton( restore_btn = ctk.CTkButton(
item_frame, item_frame, text="恢复", width=80, command=lambda path=filepath: self.restore_backup(path)
text="恢复",
width=80,
command=lambda path=filepath: self.restore_backup(path)
) )
restore_btn.pack(side="right", padx=10, pady=10) restore_btn.pack(side="right", padx=10, pady=10)
@@ -1010,12 +907,7 @@ class ConfigUI(ctk.CTk):
widget.bind("<Double-1>", lambda e, path=filepath: self.restore_backup(path)) widget.bind("<Double-1>", lambda e, path=filepath: self.restore_backup(path))
# 关闭按钮 # 关闭按钮
ctk.CTkButton( ctk.CTkButton(backup_window, text="关闭", command=backup_window.destroy, width=100).pack(pady=10)
backup_window,
text="关闭",
command=backup_window.destroy,
width=100
).pack(pady=10)
def restore_backup(self, backup_path): def restore_backup(self, backup_path):
"""从备份文件恢复配置""" """从备份文件恢复配置"""
@@ -1027,7 +919,7 @@ class ConfigUI(ctk.CTk):
confirm = messagebox.askyesno( confirm = messagebox.askyesno(
"确认", "确认",
f"确定要从以下备份文件恢复配置吗?\n{os.path.basename(backup_path)}\n\n这将覆盖当前的配置!", f"确定要从以下备份文件恢复配置吗?\n{os.path.basename(backup_path)}\n\n这将覆盖当前的配置!",
icon="warning" icon="warning",
) )
if not confirm: if not confirm:
@@ -1069,7 +961,9 @@ class ConfigUI(ctk.CTk):
search_frame.pack(fill="x", padx=10, pady=10) search_frame.pack(fill="x", padx=10, pady=10)
search_var = StringVar() search_var = StringVar()
search_entry = ctk.CTkEntry(search_frame, placeholder_text="输入关键词搜索...", width=380, textvariable=search_var) search_entry = ctk.CTkEntry(
search_frame, placeholder_text="输入关键词搜索...", width=380, textvariable=search_var
)
search_entry.pack(side="left", padx=5, pady=5, fill="x", expand=True) search_entry.pack(side="left", padx=5, pady=5, fill="x", expand=True)
# 结果列表框 # 结果列表框
@@ -1150,7 +1044,7 @@ class ConfigUI(ctk.CTk):
text=f"{full_path}{desc_text}", text=f"{full_path}{desc_text}",
font=("Arial", 11, "bold"), font=("Arial", 11, "bold"),
anchor="w", anchor="w",
wraplength=450 wraplength=450,
) )
path_label.pack(anchor="w", padx=10, pady=(5, 0), fill="x") path_label.pack(anchor="w", padx=10, pady=(5, 0), fill="x")
@@ -1160,11 +1054,7 @@ class ConfigUI(ctk.CTk):
value_str = value_str[:50] + "..." value_str = value_str[:50] + "..."
value_label = ctk.CTkLabel( value_label = ctk.CTkLabel(
item_frame, item_frame, text=f"值: {value_str}", font=("Arial", 10), anchor="w", wraplength=450
text=f"值: {value_str}",
font=("Arial", 10),
anchor="w",
wraplength=450
) )
value_label.pack(anchor="w", padx=10, pady=(0, 5), fill="x") value_label.pack(anchor="w", padx=10, pady=(0, 5), fill="x")
@@ -1174,7 +1064,7 @@ class ConfigUI(ctk.CTk):
text="转到", text="转到",
width=60, width=60,
height=25, height=25,
command=lambda s=section: self.goto_config_item(s, search_window) command=lambda s=section: self.goto_config_item(s, search_window),
) )
goto_btn.pack(side="right", padx=10, pady=5) goto_btn.pack(side="right", padx=10, pady=5)
@@ -1227,37 +1117,22 @@ class ConfigUI(ctk.CTk):
menu_window.geometry(f"+{x}+{y}") menu_window.geometry(f"+{x}+{y}")
# 创建按钮 # 创建按钮
ctk.CTkLabel( ctk.CTkLabel(menu_window, text="配置导入导出", font=("Arial", 16, "bold")).pack(pady=(20, 10))
menu_window,
text="配置导入导出",
font=("Arial", 16, "bold")
).pack(pady=(20, 10))
# 导出按钮 # 导出按钮
export_btn = ctk.CTkButton( export_btn = ctk.CTkButton(
menu_window, menu_window, text="导出配置到文件", command=lambda: self.export_config(menu_window), width=200
text="导出配置到文件",
command=lambda: self.export_config(menu_window),
width=200
) )
export_btn.pack(pady=10) export_btn.pack(pady=10)
# 导入按钮 # 导入按钮
import_btn = ctk.CTkButton( import_btn = ctk.CTkButton(
menu_window, menu_window, text="从文件导入配置", command=lambda: self.import_config(menu_window), width=200
text="从文件导入配置",
command=lambda: self.import_config(menu_window),
width=200
) )
import_btn.pack(pady=10) import_btn.pack(pady=10)
# 取消按钮 # 取消按钮
cancel_btn = ctk.CTkButton( cancel_btn = ctk.CTkButton(menu_window, text="取消", command=menu_window.destroy, width=100)
menu_window,
text="取消",
command=menu_window.destroy,
width=100
)
cancel_btn.pack(pady=10) cancel_btn.pack(pady=10)
def export_config(self, parent_window=None): def export_config(self, parent_window=None):
@@ -1277,7 +1152,7 @@ class ConfigUI(ctk.CTk):
title="导出配置", title="导出配置",
filetypes=[("TOML 文件", "*.toml"), ("所有文件", "*.*")], filetypes=[("TOML 文件", "*.toml"), ("所有文件", "*.*")],
defaultextension=".toml", defaultextension=".toml",
initialfile=default_filename initialfile=default_filename,
) )
if not file_path: if not file_path:
@@ -1306,8 +1181,7 @@ class ConfigUI(ctk.CTk):
# 选择要导入的文件 # 选择要导入的文件
file_path = filedialog.askopenfilename( file_path = filedialog.askopenfilename(
title="导入配置", title="导入配置", filetypes=[("TOML 文件", "*.toml"), ("所有文件", "*.*")]
filetypes=[("TOML 文件", "*.toml"), ("所有文件", "*.*")]
) )
if not file_path: if not file_path:
@@ -1327,9 +1201,7 @@ class ConfigUI(ctk.CTk):
# 确认导入 # 确认导入
confirm = messagebox.askyesno( confirm = messagebox.askyesno(
"确认导入", "确认导入", f"确定要导入此配置文件吗?\n{file_path}\n\n这将替换当前的配置!", icon="warning"
f"确定要导入此配置文件吗?\n{file_path}\n\n这将替换当前的配置!",
icon="warning"
) )
if not confirm: if not confirm:
@@ -1354,6 +1226,7 @@ class ConfigUI(ctk.CTk):
messagebox.showerror("导入失败", f"导入配置失败: {e}") messagebox.showerror("导入失败", f"导入配置失败: {e}")
return False return False
# 主函数 # 主函数
def main(): def main():
try: try:
@@ -1365,6 +1238,7 @@ def main():
import tkinter as tk import tkinter as tk
from tkinter import messagebox from tkinter import messagebox
root = tk.Tk() root = tk.Tk()
root.withdraw() root.withdraw()
messagebox.showerror("程序错误", f"程序运行时发生错误:\n{e}") messagebox.showerror("程序错误", f"程序运行时发生错误:\n{e}")