初始化

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
parent ff7d1177fa
commit 2d4745cd58
257 changed files with 69069 additions and 0 deletions

View File

@@ -0,0 +1,648 @@
import time
import random
import json
import os
from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.sqlalchemy_database_api import get_session
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 30 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
logger = get_logger("expressor")
session = get_session()
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def init_prompt() -> None:
learn_style_prompt = """
{chat_str}
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
1. 只考虑文字,不要考虑表情包和图片
2. 不要涉及具体的人名,只考虑语言风格
3. 语言风格包含特殊内容和情感
4. 思考有没有特殊的梗,一并总结成语言风格
5. 例子仅供参考,请严格根据群聊内容总结!!!
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景不超过20个字。BBBBB代表对应的语言风格特定句式或表达方式不超过20个字。
例如:
"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx"
"表示讽刺的赞同,不想讲道理"时,使用"对对对"
"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
请注意不要总结你自己SELF的发言
现在请你概括
"""
Prompt(learn_style_prompt, "learn_style_prompt")
learn_grammar_prompt = """
{chat_str}
请从上面这段群聊中概括除了人名为"SELF"之外的人的语法和句法特点,只考虑纯文字,不要考虑表情包和图片
1.不要总结【图片】,【动画表情】,[图片][动画表情],不总结 表情符号 at @ 回复 和[回复]
2.不要涉及具体的人名,只考虑语法和句法特点,
3.语法和句法特点要包括,句子长短(具体字数),有何种语病,如何拆分句子。
4. 例子仅供参考,请严格根据群聊内容总结!!!
总结成如下格式的规律,总结的内容要简洁,不浮夸:
"xxx"时,可以"xxx"
例如:
"表达观点较复杂"时,使用"省略主语(3-6个字)"的句法
"不用详细说明的一般表达"时,使用"非常简洁的句子"的句法
"需要单纯简单的确认"时,使用"单字或几个字的肯定(1-2个字)"的句法
注意不要总结你自己SELF的发言
现在请你概括
"""
Prompt(learn_grammar_prompt, "learn_grammar_prompt")
class ExpressionLearner:
def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.replyer_1, request_type="expressor.learner"
)
self.chat_id = chat_id
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次学习时间
self.last_learning_time: float = time.time()
# 学习参数
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
self.min_learning_interval = 300 # 最短学习时间间隔(秒)
def can_learn_for_chat(self) -> bool:
"""
检查指定聊天流是否允许学习表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许学习
"""
try:
use_expression, enable_learning, _ = global_config.expression.get_expression_config_for_chat(self.chat_id)
return enable_learning
except Exception as e:
logger.error(f"检查学习权限失败: {e}")
return False
def should_trigger_learning(self) -> bool:
"""
检查是否应该触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否应该触发学习
"""
current_time = time.time()
# 获取该聊天流的学习强度
try:
use_expression, enable_learning, learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id)
except Exception as e:
logger.error(f"获取聊天流 {self.chat_id} 的学习配置失败: {e}")
return False
# 检查是否允许学习
if not enable_learning:
return False
# 根据学习强度计算最短学习时间间隔
min_interval = self.min_learning_interval / learning_intensity
# 检查时间间隔
time_diff = current_time - self.last_learning_time
if time_diff < min_interval:
return False
# 检查消息数量(只检查指定聊天流的消息)
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=time.time(),
)
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
return False
return True
async def trigger_learning_for_chat(self) -> bool:
"""
为指定聊天流触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否成功触发学习
"""
if not self.should_trigger_learning():
return False
try:
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
# 学习语言风格
learnt_style = await self.learn_and_store(type="style", num=25)
# 学习句法特点
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
# 更新学习时间
self.last_learning_time = time.time()
if learnt_style or learnt_grammar:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
return True
else:
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
return False
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False
def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
"""
获取指定chat_id的style和grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
"""
learnt_style_expressions = []
learnt_grammar_expressions = []
# 直接从数据库查询
style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
for expr in style_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_style_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "style",
"create_date": create_date,
}
)
grammar_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")))
for expr in grammar_query.scalars():
# 确保create_date存在如果不存在则使用last_active_time
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": self.chat_id,
"type": "grammar",
"create_date": create_date,
}
)
return learnt_style_expressions, learnt_grammar_expressions
def _apply_global_decay_to_database(self, current_time: float) -> None:
"""
对数据库中的所有表达方式应用全局衰减
"""
try:
# 获取所有表达方式
all_expressions = session.execute(select(Expression)).scalars()
updated_count = 0
deleted_count = 0
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
# 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01:
# 如果count太小删除这个表达方式
session.delete(expr)
deleted_count += 1
else:
# 更新count
expr.count = new_count
updated_count += 1
session.commit()
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
session.rollback()
logger.error(f"数据库全局衰减失败: {e}")
def calculate_decay_factor(self, time_diff_days: float) -> float:
"""
计算衰减值
当时间差为0天时衰减值为0最近活跃的不衰减
当时间差为7天时衰减值为0.002(中等衰减)
当时间差为30天或更长时衰减值为0.01(高衰减)
使用二次函数进行曲线插值
"""
if time_diff_days <= 0:
return 0.0 # 刚激活的表达式不衰减
if time_diff_days >= DECAY_DAYS:
return 0.01 # 长时间未活跃的表达式大幅衰减
# 使用二次函数插值在0-30天之间从0衰减到0.01
# 使用简单的二次函数y = a * x^2
# 当x=30时y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900
a = 0.01 / (DECAY_DAYS**2)
decay = a * (time_diff_days**2)
return min(0.01, decay)
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
# sourcery skip: use-join
"""
学习并存储表达方式
type: "style" or "grammar"
"""
if type == "style":
type_str = "语言风格"
elif type == "grammar":
type_str = "句法特点"
else:
raise ValueError(f"Invalid type: {type}")
# 检查是否允许在此聊天流中学习(在函数最前面检查)
if not self.can_learn_for_chat():
logger.debug(f"聊天流 {self.chat_name} 不允许学习表达,跳过学习")
return []
res = await self.learn_expression(type, num)
if res is None:
return []
learnt_expressions, chat_id = res
chat_stream = get_chat_manager().get_stream(chat_id)
if chat_stream is None:
group_name = f"聊天流 {chat_id}"
elif chat_stream.group_info:
group_name = chat_stream.group_info.group_name
else:
group_name = f"{chat_stream.user_info.user_nickname}的私聊"
learnt_expressions_str = ""
for _chat_id, situation, style in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{group_name} 学习到{type_str}:\n{learnt_expressions_str}")
if not learnt_expressions:
logger.info(f"没有学习到{type_str}")
return []
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for chat_id, situation, style in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({"situation": situation, "style": style})
current_time = time.time()
# 存储到数据库 Expression 表
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)).scalar()
if query:
expr_obj = query
# 50%概率替换内容
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
else:
new_expression = Expression(
situation=new_expr["situation"],
style=new_expr["style"],
count=1,
last_active_time=current_time,
chat_id=chat_id,
type=type,
create_date=current_time, # 手动设置创建日期
)
session.add(new_expression)
# 限制最大数量
exprs = list(
session.execute(select(Expression)
.where((Expression.chat_id == chat_id) & (Expression.type == type))
.order_by(Expression.count.asc())).scalars()
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
session.delete(expr)
session.commit()
return learnt_expressions
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
"""从指定聊天流学习表达方式
Args:
type: "style" or "grammar"
"""
if type == "style":
type_str = "语言风格"
prompt = "learn_style_prompt"
elif type == "grammar":
type_str = "句法特点"
prompt = "learn_grammar_prompt"
else:
raise ValueError(f"Invalid type: {type}")
current_time = time.time()
# 获取上次学习时间
random_msg: Optional[List[Dict[str, Any]]] = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=current_time,
limit=num,
)
# print(random_msg)
if not random_msg or random_msg == []:
return None
# 转化成str
chat_id: str = random_msg[0]["chat_id"]
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
random_msg_str: str = await build_anonymous_messages(random_msg)
# print(f"random_msg_str:{random_msg_str}")
prompt: str = await global_prompt_manager.format_prompt(
prompt,
chat_str=random_msg_str,
)
logger.debug(f"学习{type_str}的prompt: {prompt}")
try:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习{type_str}失败: {e}")
return None
logger.debug(f"学习{type_str}的response: {response}")
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
return expressions, chat_id
def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]:
"""
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容,存储为(situation, style)元组
"""
expressions: List[Tuple[str, str, str]] = []
for line in response.splitlines():
line = line.strip()
if not line:
continue
# 查找"当"和下一个引号
idx_when = line.find('"')
if idx_when == -1:
continue
idx_quote1 = idx_when + 1
idx_quote2 = line.find('"', idx_quote1 + 1)
if idx_quote2 == -1:
continue
situation = line[idx_quote1 + 1 : idx_quote2]
# 查找"使用"
idx_use = line.find('使用"', idx_quote2)
if idx_use == -1:
continue
idx_quote3 = idx_use + 2
idx_quote4 = line.find('"', idx_quote3 + 1)
if idx_quote4 == -1:
continue
style = line[idx_quote3 + 1 : idx_quote4]
expressions.append((chat_id, situation, style))
return expressions
init_prompt()
class ExpressionLearnerManager:
def __init__(self):
self.expression_learners = {}
self._ensure_expression_directories()
self._auto_migrate_json_to_db()
self._migrate_old_data_create_date()
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
if chat_id not in self.expression_learners:
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
return self.expression_learners[chat_id]
def _ensure_expression_directories(self):
"""
确保表达方式相关的目录结构存在
"""
base_dir = os.path.join("data", "expression")
directories_to_create = [
base_dir,
os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"),
]
for directory in directories_to_create:
try:
os.makedirs(directory, exist_ok=True)
logger.debug(f"确保目录存在: {directory}")
except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}")
def _auto_migrate_json_to_db(self):
"""
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
迁移完成后在/data/expression/done.done写入标记文件存在则跳过。
"""
base_dir = os.path.join("data", "expression")
done_flag = os.path.join(base_dir, "done.done")
# 确保基础目录存在
try:
os.makedirs(base_dir, exist_ok=True)
logger.debug(f"确保目录存在: {base_dir}")
except Exception as e:
logger.error(f"创建表达方式目录失败: {e}")
return
if os.path.exists(done_flag):
logger.info("表达方式JSON已迁移无需重复迁移。")
return
logger.info("开始迁移表达方式JSON到数据库...")
migrated_count = 0
for type in ["learnt_style", "learnt_grammar"]:
type_str = "style" if type == "learnt_style" else "grammar"
type_dir = os.path.join(base_dir, type)
if not os.path.exists(type_dir):
logger.debug(f"目录不存在,跳过: {type_dir}")
continue
try:
chat_ids = os.listdir(type_dir)
logger.debug(f"{type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
except Exception as e:
logger.error(f"读取目录失败 {type_dir}: {e}")
continue
for chat_id in chat_ids:
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
if not os.path.exists(expr_file):
continue
try:
with open(expr_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
continue
for expr in expressions:
if not isinstance(expr, dict):
continue
situation = expr.get("situation")
style_val = expr.get("style")
count = expr.get("count", 1)
last_active_time = expr.get("last_active_time", time.time())
if not situation or not style_val:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
# 查重同chat_id+type+situation+style
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)).scalar()
if query:
expr_obj = query
expr_obj.count = max(expr_obj.count, count)
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
else:
new_expression = Expression(
situation=situation,
style=style_val,
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str,
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
)
session.add(new_expression)
migrated_count += 1
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
done_parent_dir = os.path.dirname(done_flag)
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
except PermissionError as e:
logger.error(f"权限不足无法写入done.done标记文件: {e}")
except OSError as e:
logger.error(f"文件系统错误无法写入done.done标记文件: {e}")
except Exception as e:
logger.error(f"写入done.done标记文件失败: {e}")
def _migrate_old_data_create_date(self):
"""
为没有create_date的老数据设置创建日期
使用last_active_time作为create_date的默认值
"""
try:
# 查找所有create_date为空的表达方式
old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
updated_count = 0
for expr in old_expressions:
# 使用last_active_time作为create_date
expr.create_date = expr.last_active_time
updated_count += 1
session.commit()
if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e:
session.rollback()
logger.error(f"迁移老数据创建日期失败: {e}")
expression_learner_manager = ExpressionLearnerManager()

View File

@@ -0,0 +1,339 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Tuple, Optional, Any
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_session
session = get_session()
logger = get_logger("expression_selector")
def init_prompt():
expression_evaluation_prompt = """
以下是正在进行的聊天内容:
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的{min_num}-{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如:
{{
"selected_situations": [2, 3, 5, 7, 19, 22, 25, 38, 39, 45, 48, 64]
}}
请严格按照JSON格式输出不要包含其他内容
"""
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
"""按权重随机抽样"""
if not population or not weights or k <= 0:
return []
if len(population) <= k:
return population.copy()
# 使用累积权重的方法进行加权抽样
selected = []
population_copy = population.copy()
weights_copy = weights.copy()
for _ in range(k):
if not population_copy:
break
# 选择一个元素
chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
selected.append(population_copy.pop(chosen_idx))
weights_copy.pop(chosen_idx)
return selected
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
)
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许使用表达
"""
try:
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
return use_expression
except Exception as e:
logger.error(f"检查表达使用权限失败: {e}")
return False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
"""根据expression_groups配置获取与当前chat_id相关的所有chat_id包括自身"""
groups = global_config.expression.expression_groups
for group in groups:
group_chat_ids = []
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids
return [chat_id]
def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
))
grammar_query = session.execute(select(Expression).where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
))
style_exprs = [
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query.scalars()
]
grammar_exprs = [
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in grammar_query.scalars()
]
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
# 按权重抽样使用count作为权重
if style_exprs:
style_weights = [expr.get("count", 1) for expr in style_exprs]
selected_style = weighted_sample(style_exprs, style_weights, style_num)
else:
selected_style = []
if grammar_exprs:
grammar_weights = [expr.get("count", 1) for expr in grammar_exprs]
selected_grammar = weighted_sample(grammar_exprs, grammar_weights, grammar_num)
else:
selected_grammar = []
return selected_style, selected_grammar
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库"""
if not expressions_to_update:
return
updates_by_key = {}
for expr in expressions_to_update:
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
situation: str = expr.get("situation") # type: ignore
style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, expr_type, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
query = session.execute(select(Expression).where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)).scalar()
if query:
expr_obj = query
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
session.commit()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
async def select_suitable_expressions_llm(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
min_num: int = 5,
target_message: Optional[str] = None,
) -> List[Dict[str, Any]]:
# sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式"""
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return []
# 1. 获取35个随机表达方式现在按权重抽取
style_exprs, grammar_exprs = self.get_random_expressions(chat_id, 30, 0.5, 0.5)
# 2. 构建所有表达方式的索引和情境列表
all_expressions = []
all_situations = []
# 添加style表达方式
for expr in style_exprs:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_with_type = expr.copy()
expr_with_type["type"] = "style"
all_expressions.append(expr_with_type)
all_situations.append(f"{len(all_expressions)}.{expr['situation']}")
# 添加grammar表达方式
for expr in grammar_exprs:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
expr_with_type = expr.copy()
expr_with_type["type"] = "grammar"
all_expressions.append(expr_with_type)
all_situations.append(f"{len(all_expressions)}.{expr['situation']}")
if not all_expressions:
logger.warning("没有找到可用的表达方式")
return []
all_situations_str = "\n".join(all_situations)
if target_message:
target_message_str = f",现在你想要回复消息:{target_message}"
target_message_extra_block = "4.考虑你要回复的目标消息"
else:
target_message_str = ""
target_message_extra_block = ""
# 3. 构建prompt只包含情境不包含完整的表达方式
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
bot_name=global_config.bot.nickname,
chat_observe_info=chat_info,
all_situations=all_situations_str,
min_num=min_num,
max_num=max_num,
target_message=target_message_str,
target_message_extra_block=target_message_extra_block,
)
# print(prompt)
# 4. 调用LLM
try:
# start_time = time.time()
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
# logger.info(f"模型名称: {model_name}")
# logger.info(f"LLM返回结果: {content}")
# if reasoning_content:
# logger.info(f"LLM推理: {reasoning_content}")
# else:
# logger.info(f"LLM推理: 无")
if not content:
logger.warning("LLM返回空结果")
return []
# 5. 解析结果
result = repair_json(content)
if isinstance(result, str):
result = json.loads(result)
if not isinstance(result, dict) or "selected_situations" not in result:
logger.error("LLM返回格式错误")
logger.info(f"LLM返回结果: \n{content}")
return []
selected_indices = result["selected_situations"]
# 根据索引获取完整的表达方式
valid_expressions = []
for idx in selected_indices:
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
expression = all_expressions[idx - 1] # 索引从1开始
valid_expressions.append(expression)
# 对选中的所有表达方式一次性更新count数
if valid_expressions:
self.update_expressions_count_batch(valid_expressions, 0.006)
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
return valid_expressions
except Exception as e:
logger.error(f"LLM处理表达方式选择时出错: {e}")
return []
init_prompt()
try:
expression_selector = ExpressionSelector()
except Exception as e:
print(f"ExpressionSelector初始化失败: {e}")