🤖 Auto-format code [skip ci]
This commit is contained in:
bot.py (4 changed lines)
@@ -2,6 +2,7 @@ import asyncio
 import hashlib
 import os
 from dotenv import load_dotenv
+
 if os.path.exists(".env"):
     load_dotenv(".env", override=True)
     print("成功加载环境变量配置")
@@ -13,9 +14,11 @@ import platform
 import traceback
 from pathlib import Path
 from rich.traceback import install
+
 # maim_message imports for console input
 from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
 from src.chat.message_receive.bot import chat_bot
+
 # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
 from src.common.logger import initialize_logging, get_logger, shutdown_logging
 from src.main import MainSystem
@@ -26,7 +29,6 @@ initialize_logging()
 logger = get_logger("main")


-
 install(extra_lines=3)

 # 设置工作目录为脚本所在目录
@@ -12,6 +12,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager

 logger = get_logger("expression_selector")

+
 def init_prompt():
     expression_evaluation_prompt = """
 你的名字是{bot_name}
@@ -42,30 +43,32 @@ def init_prompt():
 """
     Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")

+
 def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
     """按权重随机抽样"""
     if not population or not weights or k <= 0:
         return []

     if len(population) <= k:
         return population.copy()

     # 使用累积权重的方法进行加权抽样
     selected = []
     population_copy = population.copy()
     weights_copy = weights.copy()

     for _ in range(k):
         if not population_copy:
             break

         # 选择一个元素
         chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
         selected.append(population_copy.pop(chosen_idx))
         weights_copy.pop(chosen_idx)

     return selected

+
 class ExpressionSelector:
     def __init__(self):
         self.expression_learner = get_expression_learner()
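For reference, weighted_sample above performs weighted sampling without replacement: each expression's count is used as its weight, and a drawn item is removed from the pool so it cannot repeat. A minimal standalone sketch of the same behavior, with made-up data (only the situation/style/count keys mirror the code above):

    import random

    # Illustrative entries only; the real ones are read from expressions.json files under data/expression/
    population = [
        {"situation": "greeting", "style": "casual", "count": 1},
        {"situation": "joke", "style": "play along", "count": 4},
        {"situation": "thanks", "style": "brief", "count": 1},
    ]
    weights = [expr.get("count", 1) for expr in population]

    selected, pool, w = [], population.copy(), weights.copy()
    for _ in range(2):  # k = 2
        if not pool:
            break
        idx = random.choices(range(len(pool)), weights=w)[0]
        selected.append(pool.pop(idx))  # remove the drawn item so it cannot be picked twice
        w.pop(idx)

    print([e["situation"] for e in selected])  # the count=4 entry is the most likely pick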
@@ -75,7 +78,9 @@ class ExpressionSelector:
             request_type="expression.selector",
         )

-    def get_random_expressions(self, chat_id: str, style_num: int, grammar_num: int, personality_num: int) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
+    def get_random_expressions(
+        self, chat_id: str, style_num: int, grammar_num: int, personality_num: int
+    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
         (
             learnt_style_expressions,
             learnt_grammar_expressions,
@@ -88,13 +93,13 @@ class ExpressionSelector:
             selected_style = weighted_sample(learnt_style_expressions, style_weights, style_num)
         else:
             selected_style = []

         if learnt_grammar_expressions:
             grammar_weights = [expr.get("count", 1) for expr in learnt_grammar_expressions]
             selected_grammar = weighted_sample(learnt_grammar_expressions, grammar_weights, grammar_num)
         else:
             selected_grammar = []

         if personality_expressions:
             personality_weights = [expr.get("count", 1) for expr in personality_expressions]
             selected_personality = weighted_sample(personality_expressions, personality_weights, personality_num)
@@ -102,7 +107,7 @@ class ExpressionSelector:
             selected_personality = []

         return selected_style, selected_grammar, selected_personality

     def update_expression_count(self, chat_id: str, expression: Dict[str, str], multiplier: float = 1.5):
         """更新表达方式的count值"""
         if expression.get("type") == "style_personality":
@@ -117,29 +122,30 @@ class ExpressionSelector:
             file_path = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
         else:
             return

         if not os.path.exists(file_path):
             return

         try:
             with open(file_path, "r", encoding="utf-8") as f:
                 expressions = json.load(f)

             # 找到匹配的表达方式并更新count
             for expr in expressions:
-                if (expr.get("situation") == expression.get("situation") and
-                    expr.get("style") == expression.get("style")):
+                if expr.get("situation") == expression.get("situation") and expr.get("style") == expression.get(
+                    "style"
+                ):
                     expr["count"] = expr.get("count", 1) * multiplier
                     expr["last_active_time"] = time.time()
                     break

             # 保存更新后的文件
             with open(file_path, "w", encoding="utf-8") as f:
                 json.dump(expressions, f, ensure_ascii=False, indent=2)

         except Exception as e:
             logger.error(f"更新表达方式count失败: {e}")

     async def select_suitable_expressions_llm(self, chat_id: str, chat_info: str) -> List[Dict[str, str]]:
         """使用LLM选择适合的表达方式"""
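Note that update_expression_count above multiplies an expression's count by 1.5 each time the LLM picks it, and that same count is what get_random_expressions later feeds into weighted_sample as a weight. A small numeric sketch (illustrative numbers only) of how that feedback compounds:

    count = 1.0
    for picks in range(1, 6):
        count *= 1.5  # same multiplier as update_expression_count's default
        print(f"after {picks} selections: count = {count:.2f}")
    # After 5 selections count is about 7.59, so on the next draw this expression is
    # roughly 7.6x more likely to be chosen than one still sitting at count = 1.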
@@ -188,7 +194,7 @@ class ExpressionSelector:
         )

         print(prompt)

         # 4. 调用LLM
         try:
             content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
@@ -216,7 +222,7 @@ class ExpressionSelector:
                 if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                     expression = all_expressions[idx - 1]  # 索引从1开始
                     valid_expressions.append(expression)

                     # 对选中的表达方式count数*1.5
                     self.update_expression_count(chat_id, expression, 1.5)
@@ -226,7 +232,7 @@ class ExpressionSelector:
         except Exception as e:
             logger.error(f"LLM处理表达方式选择时出错: {e}")
             return []


 init_prompt()
@@ -234,10 +240,3 @@ try:
     expression_selector = ExpressionSelector()
 except Exception as e:
     print(f"ExpressionSelector初始化失败: {e}")
-
-
-
-
-
-
-
@@ -24,7 +24,6 @@ class ExpressionSelectorProcessor(BaseProcessor):
         self.selection_interval = 10  # 40秒间隔
         self.cached_expressions = []  # 缓存上一次选择的表达方式

-
         name = get_chat_manager().get_stream_name(self.subheartflow_id)
         self.log_prefix = f"[{name}] 表达选择器"
@@ -72,7 +71,9 @@ class ExpressionSelectorProcessor(BaseProcessor):

         try:
             # LLM模式:调用LLM选择5-10个,然后随机选5个
-            selected_expressions = await expression_selector.select_suitable_expressions_llm(self.subheartflow_id, chat_info)
+            selected_expressions = await expression_selector.select_suitable_expressions_llm(
+                self.subheartflow_id, chat_info
+            )
             cache_size = len(selected_expressions) if selected_expressions else 0
             mode_desc = f"LLM模式(已缓存{cache_size}个)"
@@ -93,4 +94,3 @@ class ExpressionSelectorProcessor(BaseProcessor):
         except Exception as e:
             logger.error(f"{self.log_prefix} 处理表达方式选择时出错: {e}")
             return []
-
@@ -148,7 +148,7 @@ class PromptBuilder:
             read_mark=0.0,
             show_actions=True,
         )

         message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
             chat_id=chat_stream.stream_id,
             timestamp=time.time(),
@@ -162,8 +162,10 @@ class PromptBuilder:
             read_mark=0.0,
             show_actions=True,
         )

-        expressions = expression_selector.select_suitable_expressions_llm(chat_stream.stream_id, chat_talking_prompt_half)
+        expressions = expression_selector.select_suitable_expressions_llm(
+            chat_stream.stream_id, chat_talking_prompt_half
+        )
         style_habbits = []
         grammar_habbits = []
         if expressions:
@@ -31,7 +31,7 @@ class Individuality:
         self.name = ""
         self.bot_person_id = ""
         self.meta_info_file_path = "data/personality/meta.json"

         self.model = LLMRequest(
             model=global_config.model.utils,
             request_type="individuality.compress",
@@ -99,9 +99,7 @@ class Individuality:
         logger.info("已将完整人设更新到bot的impression中")

         # 创建压缩版本的short_impression
-        asyncio.create_task(self._create_compressed_impression(
-            personality_core, personality_sides, identity_detail
-        ))
+        asyncio.create_task(self._create_compressed_impression(personality_core, personality_sides, identity_detail))

         asyncio.create_task(self.express_style.extract_and_store_personality_expressions())
@@ -374,12 +372,12 @@ class Individuality:
         self, personality_core: str, personality_sides: list, identity_detail: list
     ) -> str:
         """使用LLM创建压缩版本的impression

         Args:
             personality_core: 核心人格
             personality_sides: 人格侧面列表
             identity_detail: 身份细节列表

         Returns:
             str: 压缩后的impression文本
         """
@@ -387,23 +385,23 @@ class Individuality:
         compressed_parts = []
         if personality_core:
             compressed_parts.append(f"{personality_core}")

         # 准备需要压缩的内容
         content_to_compress = []
         if personality_sides:
             content_to_compress.append(f"人格特质: {'、'.join(personality_sides)}")
         if identity_detail:
             content_to_compress.append(f"身份背景: {'、'.join(identity_detail)}")

         if not content_to_compress:
             # 如果没有需要压缩的内容,直接返回核心人格
             result = "。".join(compressed_parts)
             return result + "。" if result else ""

         # 使用LLM压缩其他内容
         try:
             compress_content = "、".join(content_to_compress)

             prompt = f"""请将以下人设信息进行简洁压缩,保留主要内容,用简练的中文表达:

 {compress_content}
@@ -413,10 +411,10 @@ class Individuality:
 2. 尽量简洁,不超过30字
 3. 直接输出压缩后的内容,不要解释"""

-            response,(_,_) = await self.model.generate_response_async(
+            response, (_, _) = await self.model.generate_response_async(
                 prompt=prompt,
             )

             if response.strip():
                 compressed_parts.append(response.strip())
                 logger.info(f"精简人格侧面: {response.strip()}")
@@ -424,15 +422,13 @@ class Individuality:
                 logger.error(f"使用LLM压缩人设时出错: {response}")
         except Exception as e:
             logger.error(f"使用LLM压缩人设时出错: {e}")

         result = "。".join(compressed_parts)

         # 更新short_impression字段
         if result:
             person_info_manager = get_person_info_manager()
-            await person_info_manager.update_one_field(
-                self.bot_person_id, "short_impression", result
-            )
+            await person_info_manager.update_one_field(self.bot_person_id, "short_impression", result)
             logger.info("已将压缩人设更新到bot的short_impression中")
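The response, (_, _) = await self.model.generate_response_async(...) lines in this diff (here and in the expression selector) unpack a return value shaped like (text, (extra, extra)) and discard the second part. A self-contained sketch of that unpacking pattern; the stand-in coroutine below is hypothetical and only mimics the shape, it is not the real LLMRequest API:

    import asyncio

    async def fake_generate_response_async(prompt: str):
        # Hypothetical stand-in returning (text, (metadata, metadata)), like the call sites above
        return f"compressed: {prompt[:20]}", ("reasoning", "model-name")

    async def main():
        response, (_, _) = await fake_generate_response_async("personality sides: gentle, talkative")
        print(response.strip())

    asyncio.run(main())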