feat:记录使用的表达方式
This commit is contained in:
@@ -233,9 +233,9 @@ class HeartFChatting:
|
|||||||
modified_exit_interest_threshold = 1.5 / talk_frequency
|
modified_exit_interest_threshold = 1.5 / talk_frequency
|
||||||
total_interest = 0.0
|
total_interest = 0.0
|
||||||
for msg_dict in new_message:
|
for msg_dict in new_message:
|
||||||
interest_value = msg_dict.get("interest_value", 0.0)
|
interest_value = msg_dict.get("interest_value")
|
||||||
if msg_dict.get("processed_plain_text", ""):
|
if interest_value is not None and msg_dict.get("processed_plain_text", ""):
|
||||||
total_interest += interest_value
|
total_interest += float(interest_value)
|
||||||
|
|
||||||
if new_message_count >= modified_exit_count_threshold:
|
if new_message_count >= modified_exit_count_threshold:
|
||||||
self.recent_interest_records.append(total_interest)
|
self.recent_interest_records.append(total_interest)
|
||||||
@@ -244,7 +244,7 @@ class HeartFChatting:
|
|||||||
)
|
)
|
||||||
# logger.info(self.last_read_time)
|
# logger.info(self.last_read_time)
|
||||||
# logger.info(new_message)
|
# logger.info(new_message)
|
||||||
return True,total_interest/new_message_count
|
return True, total_interest / new_message_count if new_message_count > 0 else 0.0
|
||||||
|
|
||||||
# 检查累计兴趣值
|
# 检查累计兴趣值
|
||||||
if new_message_count > 0:
|
if new_message_count > 0:
|
||||||
@@ -259,7 +259,7 @@ class HeartFChatting:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 累计兴趣值达到{total_interest:.2f}(>{modified_exit_interest_threshold:.1f}),结束等待"
|
f"{self.log_prefix} 累计兴趣值达到{total_interest:.2f}(>{modified_exit_interest_threshold:.1f}),结束等待"
|
||||||
)
|
)
|
||||||
return True,total_interest/new_message_count
|
return True, total_interest / new_message_count if new_message_count > 0 else 0.0
|
||||||
|
|
||||||
# 每10秒输出一次等待状态
|
# 每10秒输出一次等待状态
|
||||||
if int(time.time() - self.last_read_time) > 0 and int(time.time() - self.last_read_time) % 15 == 0:
|
if int(time.time() - self.last_read_time) > 0 and int(time.time() - self.last_read_time) % 15 == 0:
|
||||||
@@ -302,10 +302,15 @@ class HeartFChatting:
|
|||||||
cycle_timers: Dict[str, float],
|
cycle_timers: Dict[str, float],
|
||||||
thinking_id,
|
thinking_id,
|
||||||
actions,
|
actions,
|
||||||
|
selected_expressions:List[int] = None,
|
||||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||||
|
|
||||||
with Timer("回复发送", cycle_timers):
|
with Timer("回复发送", cycle_timers):
|
||||||
reply_text = await self._send_response(response_set, action_message)
|
reply_text = await self._send_response(
|
||||||
|
reply_set=response_set,
|
||||||
|
message_data=action_message,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
|
)
|
||||||
|
|
||||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||||
platform = action_message.get("chat_info_platform")
|
platform = action_message.get("chat_info_platform")
|
||||||
@@ -474,7 +479,7 @@ class HeartFChatting:
|
|||||||
else:
|
else:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
success, response_set, _ = await generator_api.generate_reply(
|
success, response_set, prompt_selected_expressions = await generator_api.generate_reply(
|
||||||
chat_stream=self.chat_stream,
|
chat_stream=self.chat_stream,
|
||||||
reply_message = action_info["action_message"],
|
reply_message = action_info["action_message"],
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
@@ -483,7 +488,13 @@ class HeartFChatting:
|
|||||||
enable_tool=global_config.tool.enable_tool,
|
enable_tool=global_config.tool.enable_tool,
|
||||||
request_type="replyer",
|
request_type="replyer",
|
||||||
from_plugin=False,
|
from_plugin=False,
|
||||||
|
return_expressions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if prompt_selected_expressions and len(prompt_selected_expressions) > 1:
|
||||||
|
_,selected_expressions = prompt_selected_expressions
|
||||||
|
else:
|
||||||
|
selected_expressions = []
|
||||||
|
|
||||||
if not success or not response_set:
|
if not success or not response_set:
|
||||||
logger.info(f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败")
|
logger.info(f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败")
|
||||||
@@ -504,11 +515,12 @@ class HeartFChatting:
|
|||||||
}
|
}
|
||||||
|
|
||||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||||
response_set,
|
response_set=response_set,
|
||||||
action_info["action_message"],
|
action_message=action_info["action_message"],
|
||||||
cycle_timers,
|
cycle_timers=cycle_timers,
|
||||||
thinking_id,
|
thinking_id=thinking_id,
|
||||||
actions,
|
actions=actions,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"action_type": "reply",
|
"action_type": "reply",
|
||||||
@@ -685,7 +697,11 @@ class HeartFChatting:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, "", ""
|
return False, "", ""
|
||||||
|
|
||||||
async def _send_response(self, reply_set, message_data) -> str:
|
async def _send_response(self,
|
||||||
|
reply_set,
|
||||||
|
message_data,
|
||||||
|
selected_expressions:List[int] = None,
|
||||||
|
) -> str:
|
||||||
new_message_count = message_api.count_new_messages(
|
new_message_count = message_api.count_new_messages(
|
||||||
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
||||||
)
|
)
|
||||||
@@ -706,6 +722,7 @@ class HeartFChatting:
|
|||||||
reply_message = message_data,
|
reply_message = message_data,
|
||||||
set_reply=need_reply,
|
set_reply=need_reply,
|
||||||
typing=False,
|
typing=False,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
)
|
)
|
||||||
first_replied = True
|
first_replied = True
|
||||||
else:
|
else:
|
||||||
@@ -715,6 +732,7 @@ class HeartFChatting:
|
|||||||
reply_message = message_data,
|
reply_message = message_data,
|
||||||
set_reply=False,
|
set_reply=False,
|
||||||
typing=True,
|
typing=True,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
)
|
)
|
||||||
reply_text += data
|
reply_text += data
|
||||||
|
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
style_exprs = [
|
style_exprs = [
|
||||||
{
|
{
|
||||||
|
"id": expr.id,
|
||||||
"situation": expr.situation,
|
"situation": expr.situation,
|
||||||
"style": expr.style,
|
"style": expr.style,
|
||||||
"count": expr.count,
|
"count": expr.count,
|
||||||
@@ -203,14 +204,14 @@ class ExpressionSelector:
|
|||||||
# 检查是否允许在此聊天流中使用表达
|
# 检查是否允许在此聊天流中使用表达
|
||||||
if not self.can_use_expression_for_chat(chat_id):
|
if not self.can_use_expression_for_chat(chat_id):
|
||||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
# 1. 获取20个随机表达方式(现在按权重抽取)
|
# 1. 获取20个随机表达方式(现在按权重抽取)
|
||||||
style_exprs = self.get_random_expressions(chat_id, 10)
|
style_exprs = self.get_random_expressions(chat_id, 10)
|
||||||
|
|
||||||
if len(style_exprs) < 20:
|
if len(style_exprs) < 10:
|
||||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
# 2. 构建所有表达方式的索引和情境列表
|
# 2. 构建所有表达方式的索引和情境列表
|
||||||
all_expressions = []
|
all_expressions = []
|
||||||
@@ -218,15 +219,13 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 添加style表达方式
|
# 添加style表达方式
|
||||||
for expr in style_exprs:
|
for expr in style_exprs:
|
||||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
expr = expr.copy()
|
||||||
expr_with_type = expr.copy()
|
all_expressions.append(expr)
|
||||||
expr_with_type["type"] = "style"
|
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||||
all_expressions.append(expr_with_type)
|
|
||||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
|
||||||
|
|
||||||
if not all_expressions:
|
if not all_expressions:
|
||||||
logger.warning("没有找到可用的表达方式")
|
logger.warning("没有找到可用的表达方式")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
all_situations_str = "\n".join(all_situations)
|
all_situations_str = "\n".join(all_situations)
|
||||||
|
|
||||||
@@ -247,8 +246,6 @@ class ExpressionSelector:
|
|||||||
target_message_extra_block=target_message_extra_block,
|
target_message_extra_block=target_message_extra_block,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(prompt)
|
|
||||||
|
|
||||||
# 4. 调用LLM
|
# 4. 调用LLM
|
||||||
try:
|
try:
|
||||||
|
|
||||||
@@ -265,7 +262,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning("LLM返回空结果")
|
logger.warning("LLM返回空结果")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
# 5. 解析结果
|
# 5. 解析结果
|
||||||
result = repair_json(content)
|
result = repair_json(content)
|
||||||
@@ -275,15 +272,17 @@ class ExpressionSelector:
|
|||||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||||
logger.error("LLM返回格式错误")
|
logger.error("LLM返回格式错误")
|
||||||
logger.info(f"LLM返回结果: \n{content}")
|
logger.info(f"LLM返回结果: \n{content}")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
selected_indices = result["selected_situations"]
|
selected_indices = result["selected_situations"]
|
||||||
|
|
||||||
# 根据索引获取完整的表达方式
|
# 根据索引获取完整的表达方式
|
||||||
valid_expressions = []
|
valid_expressions = []
|
||||||
|
selected_ids = []
|
||||||
for idx in selected_indices:
|
for idx in selected_indices:
|
||||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||||
expression = all_expressions[idx - 1] # 索引从1开始
|
expression = all_expressions[idx - 1] # 索引从1开始
|
||||||
|
selected_ids.append(expression["id"])
|
||||||
valid_expressions.append(expression)
|
valid_expressions.append(expression)
|
||||||
|
|
||||||
# 对选中的所有表达方式,一次性更新count数
|
# 对选中的所有表达方式,一次性更新count数
|
||||||
@@ -291,11 +290,11 @@ class ExpressionSelector:
|
|||||||
self.update_expressions_count_batch(valid_expressions, 0.006)
|
self.update_expressions_count_batch(valid_expressions, 0.006)
|
||||||
|
|
||||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||||
return valid_expressions
|
return valid_expressions , selected_ids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,298 +0,0 @@
|
|||||||
import json
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from typing import List, Dict, Tuple, Optional, Any
|
|
||||||
from json_repair import repair_json
|
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.config.config import global_config, model_config
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.database.database_model import Expression
|
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
|
||||||
|
|
||||||
logger = get_logger("expression_selector")
|
|
||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
|
||||||
expression_evaluation_prompt = """
|
|
||||||
以下是正在进行的聊天内容:
|
|
||||||
{chat_observe_info}
|
|
||||||
|
|
||||||
你的名字是{bot_name}{target_message}
|
|
||||||
|
|
||||||
你知道以下这些表达方式,梗和说话方式:
|
|
||||||
{all_situations}
|
|
||||||
|
|
||||||
现在,请你根据聊天记录从中挑选合适的表达方式,梗和说话方式,组织一条回复风格指导,指导的目的是在组织回复的时候提供一些语言风格和梗上的参考。
|
|
||||||
请在reply_style_guide中以平文本输出指导,不要浮夸,并在selected_expressions中说明在指导中你挑选了哪些表达方式,梗和说话方式,以json格式输出:
|
|
||||||
例子:
|
|
||||||
{{
|
|
||||||
"reply_style_guide": "...",
|
|
||||||
"selected_expressions": [2, 3, 4, 7]
|
|
||||||
}}
|
|
||||||
请严格按照JSON格式输出,不要包含其他内容:
|
|
||||||
"""
|
|
||||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
|
||||||
|
|
||||||
|
|
||||||
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
|
|
||||||
"""按权重随机抽样"""
|
|
||||||
if not population or not weights or k <= 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if len(population) <= k:
|
|
||||||
return population.copy()
|
|
||||||
|
|
||||||
# 使用累积权重的方法进行加权抽样
|
|
||||||
selected = []
|
|
||||||
population_copy = population.copy()
|
|
||||||
weights_copy = weights.copy()
|
|
||||||
|
|
||||||
for _ in range(k):
|
|
||||||
if not population_copy:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 选择一个元素
|
|
||||||
chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
|
|
||||||
selected.append(population_copy.pop(chosen_idx))
|
|
||||||
weights_copy.pop(chosen_idx)
|
|
||||||
|
|
||||||
return selected
|
|
||||||
|
|
||||||
|
|
||||||
class ExpressionSelector:
|
|
||||||
def __init__(self):
|
|
||||||
self.llm_model = LLMRequest(
|
|
||||||
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
|
||||||
)
|
|
||||||
|
|
||||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
检查指定聊天流是否允许使用表达
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天流ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否允许使用表达
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
|
||||||
return use_expression
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"检查表达使用权限失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
|
||||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
|
||||||
try:
|
|
||||||
parts = stream_config_str.split(":")
|
|
||||||
if len(parts) != 3:
|
|
||||||
return None
|
|
||||||
platform = parts[0]
|
|
||||||
id_str = parts[1]
|
|
||||||
stream_type = parts[2]
|
|
||||||
is_group = stream_type == "group"
|
|
||||||
if is_group:
|
|
||||||
components = [platform, str(id_str)]
|
|
||||||
else:
|
|
||||||
components = [platform, str(id_str), "private"]
|
|
||||||
key = "_".join(components)
|
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
|
||||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
|
||||||
groups = global_config.expression.expression_groups
|
|
||||||
for group in groups:
|
|
||||||
group_chat_ids = []
|
|
||||||
for stream_config_str in group:
|
|
||||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
|
||||||
group_chat_ids.append(chat_id_candidate)
|
|
||||||
if chat_id in group_chat_ids:
|
|
||||||
return group_chat_ids
|
|
||||||
return [chat_id]
|
|
||||||
|
|
||||||
def get_random_expressions(
|
|
||||||
self, chat_id: str, total_num: int
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
# sourcery skip: extract-duplicate-method, move-assign
|
|
||||||
# 支持多chat_id合并抽选
|
|
||||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
|
||||||
|
|
||||||
# 优化:一次性查询所有相关chat_id的表达方式
|
|
||||||
style_query = Expression.select().where(
|
|
||||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
|
||||||
)
|
|
||||||
|
|
||||||
style_exprs = [
|
|
||||||
{
|
|
||||||
"situation": expr.situation,
|
|
||||||
"style": expr.style,
|
|
||||||
"count": expr.count,
|
|
||||||
"last_active_time": expr.last_active_time,
|
|
||||||
"source_id": expr.chat_id,
|
|
||||||
"type": "style",
|
|
||||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
|
||||||
}
|
|
||||||
for expr in style_query
|
|
||||||
]
|
|
||||||
|
|
||||||
# 按权重抽样(使用count作为权重)
|
|
||||||
if style_exprs:
|
|
||||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
|
||||||
selected_style = weighted_sample(style_exprs, style_weights, total_num)
|
|
||||||
else:
|
|
||||||
selected_style = []
|
|
||||||
return selected_style
|
|
||||||
|
|
||||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
|
||||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
|
||||||
if not expressions_to_update:
|
|
||||||
return
|
|
||||||
updates_by_key = {}
|
|
||||||
for expr in expressions_to_update:
|
|
||||||
source_id: str = expr.get("source_id") # type: ignore
|
|
||||||
expr_type: str = expr.get("type", "style")
|
|
||||||
situation: str = expr.get("situation") # type: ignore
|
|
||||||
style: str = expr.get("style") # type: ignore
|
|
||||||
if not source_id or not situation or not style:
|
|
||||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
|
||||||
continue
|
|
||||||
key = (source_id, expr_type, situation, style)
|
|
||||||
if key not in updates_by_key:
|
|
||||||
updates_by_key[key] = expr
|
|
||||||
for chat_id, expr_type, situation, style in updates_by_key:
|
|
||||||
query = Expression.select().where(
|
|
||||||
(Expression.chat_id == chat_id)
|
|
||||||
& (Expression.type == expr_type)
|
|
||||||
& (Expression.situation == situation)
|
|
||||||
& (Expression.style == style)
|
|
||||||
)
|
|
||||||
if query.exists():
|
|
||||||
expr_obj = query.get()
|
|
||||||
current_count = expr_obj.count
|
|
||||||
new_count = min(current_count + increment, 5.0)
|
|
||||||
expr_obj.count = new_count
|
|
||||||
expr_obj.last_active_time = time.time()
|
|
||||||
expr_obj.save()
|
|
||||||
logger.debug(
|
|
||||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def select_suitable_expressions_llm(
|
|
||||||
self,
|
|
||||||
chat_id: str,
|
|
||||||
chat_info: str,
|
|
||||||
max_num: int = 10,
|
|
||||||
target_message: Optional[str] = None,
|
|
||||||
) -> Tuple[str, List[Dict[str, Any]]]:
|
|
||||||
# sourcery skip: inline-variable, list-comprehension
|
|
||||||
"""使用LLM选择适合的表达方式"""
|
|
||||||
|
|
||||||
# 检查是否允许在此聊天流中使用表达
|
|
||||||
if not self.can_use_expression_for_chat(chat_id):
|
|
||||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
|
||||||
return "", []
|
|
||||||
|
|
||||||
# 1. 获取20个随机表达方式(现在按权重抽取)
|
|
||||||
style_exprs = self.get_random_expressions(chat_id, 10)
|
|
||||||
|
|
||||||
# 2. 构建所有表达方式的索引和情境列表
|
|
||||||
all_expressions = []
|
|
||||||
all_situations = []
|
|
||||||
|
|
||||||
# 添加style表达方式
|
|
||||||
for expr in style_exprs:
|
|
||||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
|
||||||
expr_with_type = expr.copy()
|
|
||||||
expr_with_type["type"] = "style"
|
|
||||||
all_expressions.append(expr_with_type)
|
|
||||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
|
||||||
|
|
||||||
if not all_expressions:
|
|
||||||
logger.warning("没有找到可用的表达方式")
|
|
||||||
return "", []
|
|
||||||
|
|
||||||
all_situations_str = "\n".join(all_situations)
|
|
||||||
|
|
||||||
if target_message:
|
|
||||||
target_message_str = f",现在你想要回复消息:{target_message}"
|
|
||||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
|
||||||
else:
|
|
||||||
target_message_str = ""
|
|
||||||
target_message_extra_block = ""
|
|
||||||
|
|
||||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
|
||||||
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
|
|
||||||
bot_name=global_config.bot.nickname,
|
|
||||||
chat_observe_info=chat_info,
|
|
||||||
all_situations=all_situations_str,
|
|
||||||
max_num=max_num,
|
|
||||||
target_message=target_message_str,
|
|
||||||
target_message_extra_block=target_message_extra_block,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(prompt)
|
|
||||||
|
|
||||||
# 4. 调用LLM
|
|
||||||
try:
|
|
||||||
|
|
||||||
# start_time = time.time()
|
|
||||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
|
||||||
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
|
||||||
|
|
||||||
# logger.info(f"模型名称: {model_name}")
|
|
||||||
logger.info(f"LLM返回结果: {content}")
|
|
||||||
# if reasoning_content:
|
|
||||||
# logger.info(f"LLM推理: {reasoning_content}")
|
|
||||||
# else:
|
|
||||||
# logger.info(f"LLM推理: 无")
|
|
||||||
|
|
||||||
if not content:
|
|
||||||
logger.warning("LLM返回空结果")
|
|
||||||
return "", []
|
|
||||||
|
|
||||||
# 5. 解析结果
|
|
||||||
result = repair_json(content)
|
|
||||||
if isinstance(result, str):
|
|
||||||
result = json.loads(result)
|
|
||||||
|
|
||||||
if not isinstance(result, dict) or "reply_style_guide" not in result or "selected_expressions" not in result:
|
|
||||||
logger.error("LLM返回格式错误")
|
|
||||||
logger.info(f"LLM返回结果: \n{content}")
|
|
||||||
return "", []
|
|
||||||
|
|
||||||
reply_style_guide = result["reply_style_guide"]
|
|
||||||
selected_expressions = result["selected_expressions"]
|
|
||||||
|
|
||||||
# 根据索引获取完整的表达方式
|
|
||||||
valid_expressions = []
|
|
||||||
for idx in selected_expressions:
|
|
||||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
|
||||||
expression = all_expressions[idx - 1] # 索引从1开始
|
|
||||||
valid_expressions.append(expression)
|
|
||||||
|
|
||||||
# 对选中的所有表达方式,一次性更新count数
|
|
||||||
if valid_expressions:
|
|
||||||
self.update_expressions_count_batch(valid_expressions, 0.006)
|
|
||||||
|
|
||||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
|
||||||
return reply_style_guide, valid_expressions
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
|
||||||
return "", []
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
|
||||||
|
|
||||||
try:
|
|
||||||
expression_selector = ExpressionSelector()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ExpressionSelector初始化失败: {e}")
|
|
||||||
@@ -4,7 +4,7 @@ import urllib3
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any, List
|
||||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -421,6 +421,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
apply_set_reply_logic: bool = False,
|
apply_set_reply_logic: bool = False,
|
||||||
reply_to: Optional[str] = None,
|
reply_to: Optional[str] = None,
|
||||||
|
selected_expressions:List[int] = None,
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -445,6 +446,8 @@ class MessageSending(MessageProcessBase):
|
|||||||
self.display_message = display_message
|
self.display_message = display_message
|
||||||
|
|
||||||
self.interest_value = 0.0
|
self.interest_value = 0.0
|
||||||
|
|
||||||
|
self.selected_expressions = selected_expressions
|
||||||
|
|
||||||
def build_reply(self):
|
def build_reply(self):
|
||||||
"""设置回复消息"""
|
"""设置回复消息"""
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class MessageStorage:
|
|||||||
is_command = False
|
is_command = False
|
||||||
key_words = ""
|
key_words = ""
|
||||||
key_words_lite = ""
|
key_words_lite = ""
|
||||||
|
selected_expressions = message.selected_expressions
|
||||||
else:
|
else:
|
||||||
filtered_display_message = ""
|
filtered_display_message = ""
|
||||||
interest_value = message.interest_value
|
interest_value = message.interest_value
|
||||||
@@ -79,6 +80,7 @@ class MessageStorage:
|
|||||||
# 序列化关键词列表为JSON字符串
|
# 序列化关键词列表为JSON字符串
|
||||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||||
|
selected_expressions = ""
|
||||||
|
|
||||||
chat_info_dict = chat_stream.to_dict()
|
chat_info_dict = chat_stream.to_dict()
|
||||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||||
@@ -127,6 +129,7 @@ class MessageStorage:
|
|||||||
is_command=is_command,
|
is_command=is_command,
|
||||||
key_words=key_words,
|
key_words=key_words,
|
||||||
key_words_lite=key_words_lite,
|
key_words_lite=key_words_lite,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("存储消息失败")
|
logger.exception("存储消息失败")
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class DefaultReplyer:
|
|||||||
from_plugin: bool = True,
|
from_plugin: bool = True,
|
||||||
stream_id: Optional[str] = None,
|
stream_id: Optional[str] = None,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[Dict[str, Any]] = None,
|
||||||
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], List[Dict[str, Any]]]:
|
||||||
# sourcery skip: merge-nested-ifs
|
# sourcery skip: merge-nested-ifs
|
||||||
"""
|
"""
|
||||||
回复器 (Replier): 负责生成回复文本的核心逻辑。
|
回复器 (Replier): 负责生成回复文本的核心逻辑。
|
||||||
@@ -186,7 +186,7 @@ class DefaultReplyer:
|
|||||||
try:
|
try:
|
||||||
# 3. 构建 Prompt
|
# 3. 构建 Prompt
|
||||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||||
prompt = await self.build_prompt_reply_context(
|
prompt,selected_expressions = await self.build_prompt_reply_context(
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
choosen_actions=choosen_actions,
|
choosen_actions=choosen_actions,
|
||||||
@@ -197,7 +197,7 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
logger.warning("构建prompt失败,跳过回复生成")
|
logger.warning("构建prompt失败,跳过回复生成")
|
||||||
return False, None, None
|
return False, None, None, []
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.plugin_system.core.events_manager import events_manager
|
||||||
|
|
||||||
if not from_plugin:
|
if not from_plugin:
|
||||||
@@ -229,16 +229,16 @@ class DefaultReplyer:
|
|||||||
except Exception as llm_e:
|
except Exception as llm_e:
|
||||||
# 精简报错信息
|
# 精简报错信息
|
||||||
logger.error(f"LLM 生成失败: {llm_e}")
|
logger.error(f"LLM 生成失败: {llm_e}")
|
||||||
return False, None, prompt # LLM 调用失败则无法生成回复
|
return False, None, prompt, selected_expressions # LLM 调用失败则无法生成回复
|
||||||
|
|
||||||
return True, llm_response, prompt
|
return True, llm_response, prompt, selected_expressions
|
||||||
|
|
||||||
except UserWarning as uw:
|
except UserWarning as uw:
|
||||||
raise uw
|
raise uw
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"回复生成意外失败: {e}")
|
logger.error(f"回复生成意外失败: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False, None, prompt
|
return False, None, prompt, selected_expressions
|
||||||
|
|
||||||
async def rewrite_reply_with_context(
|
async def rewrite_reply_with_context(
|
||||||
self,
|
self,
|
||||||
@@ -302,7 +302,7 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
return person.build_relationship(points_num=5)
|
return person.build_relationship(points_num=5)
|
||||||
|
|
||||||
async def build_expression_habits(self, chat_history: str, target: str) -> str:
|
async def build_expression_habits(self, chat_history: str, target: str) -> Tuple[str, List[int]]:
|
||||||
"""构建表达习惯块
|
"""构建表达习惯块
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -315,11 +315,11 @@ class DefaultReplyer:
|
|||||||
# 检查是否允许在此聊天流中使用表达
|
# 检查是否允许在此聊天流中使用表达
|
||||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||||
if not use_expression:
|
if not use_expression:
|
||||||
return ""
|
return "", []
|
||||||
style_habits = []
|
style_habits = []
|
||||||
# 使用从处理器传来的选中表达方式
|
# 使用从处理器传来的选中表达方式
|
||||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||||
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
|
||||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -343,7 +343,7 @@ class DefaultReplyer:
|
|||||||
)
|
)
|
||||||
expression_habits_block += f"{style_habits_str}\n"
|
expression_habits_block += f"{style_habits_str}\n"
|
||||||
|
|
||||||
return f"{expression_habits_title}\n{expression_habits_block}"
|
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
|
||||||
|
|
||||||
async def build_memory_block(self, chat_history: str, target: str) -> str:
|
async def build_memory_block(self, chat_history: str, target: str) -> str:
|
||||||
"""构建记忆块
|
"""构建记忆块
|
||||||
@@ -636,9 +636,8 @@ class DefaultReplyer:
|
|||||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||||
action_descriptions += "\n"
|
action_descriptions += "\n"
|
||||||
|
|
||||||
|
choosen_action_descriptions = ""
|
||||||
if choosen_actions:
|
if choosen_actions:
|
||||||
action_descriptions += "根据聊天情况,你决定在回复的同时做以下这些动作:\n"
|
|
||||||
|
|
||||||
for action in choosen_actions:
|
for action in choosen_actions:
|
||||||
action_name = action.get('action_type', 'unknown_action')
|
action_name = action.get('action_type', 'unknown_action')
|
||||||
if action_name =="reply":
|
if action_name =="reply":
|
||||||
@@ -646,9 +645,11 @@ class DefaultReplyer:
|
|||||||
action_description = action.get('reason', '无描述')
|
action_description = action.get('reason', '无描述')
|
||||||
reasoning = action.get('reasoning', '无原因')
|
reasoning = action.get('reasoning', '无原因')
|
||||||
|
|
||||||
|
choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
|
||||||
action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
|
|
||||||
|
|
||||||
|
if choosen_action_descriptions:
|
||||||
|
action_descriptions += "根据聊天情况,你决定在回复的同时做以下这些动作:\n"
|
||||||
|
action_descriptions += choosen_action_descriptions
|
||||||
|
|
||||||
return action_descriptions
|
return action_descriptions
|
||||||
|
|
||||||
@@ -661,7 +662,7 @@ class DefaultReplyer:
|
|||||||
choosen_actions: Optional[List[Dict[str, Any]]] = None,
|
choosen_actions: Optional[List[Dict[str, Any]]] = None,
|
||||||
enable_tool: bool = True,
|
enable_tool: bool = True,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[Dict[str, Any]] = None,
|
||||||
) -> str:
|
) -> Tuple[str, List[int]]:
|
||||||
"""
|
"""
|
||||||
构建回复器上下文
|
构建回复器上下文
|
||||||
|
|
||||||
@@ -759,7 +760,7 @@ class DefaultReplyer:
|
|||||||
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
||||||
logger.info(f"在回复前的步骤耗时: {'; '.join(timing_logs)}")
|
logger.info(f"在回复前的步骤耗时: {'; '.join(timing_logs)}")
|
||||||
|
|
||||||
expression_habits_block = results_dict["expression_habits"]
|
expression_habits_block, selected_expressions = results_dict["expression_habits"]
|
||||||
relation_info = results_dict["relation_info"]
|
relation_info = results_dict["relation_info"]
|
||||||
memory_block = results_dict["memory_block"]
|
memory_block = results_dict["memory_block"]
|
||||||
tool_info = results_dict["tool_info"]
|
tool_info = results_dict["tool_info"]
|
||||||
@@ -831,7 +832,7 @@ class DefaultReplyer:
|
|||||||
reply_style=global_config.personality.reply_style,
|
reply_style=global_config.personality.reply_style,
|
||||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
)
|
),selected_expressions
|
||||||
else:
|
else:
|
||||||
return await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
"replyer_prompt",
|
"replyer_prompt",
|
||||||
@@ -852,7 +853,7 @@ class DefaultReplyer:
|
|||||||
reply_style=global_config.personality.reply_style,
|
reply_style=global_config.personality.reply_style,
|
||||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
)
|
),selected_expressions
|
||||||
|
|
||||||
async def build_prompt_rewrite_context(
|
async def build_prompt_rewrite_context(
|
||||||
self,
|
self,
|
||||||
@@ -860,7 +861,7 @@ class DefaultReplyer:
|
|||||||
reason: str,
|
reason: str,
|
||||||
reply_to: str,
|
reply_to: str,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[Dict[str, Any]] = None,
|
||||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
) -> Tuple[str, List[int]]: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||||
chat_stream = self.chat_stream
|
chat_stream = self.chat_stream
|
||||||
chat_id = chat_stream.stream_id
|
chat_id = chat_stream.stream_id
|
||||||
is_group_chat = bool(chat_stream.group_info)
|
is_group_chat = bool(chat_stream.group_info)
|
||||||
@@ -893,7 +894,7 @@ class DefaultReplyer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 并行执行2个构建任务
|
# 并行执行2个构建任务
|
||||||
expression_habits_block, relation_info = await asyncio.gather(
|
(expression_habits_block, selected_expressions), relation_info = await asyncio.gather(
|
||||||
self.build_expression_habits(chat_talking_prompt_half, target),
|
self.build_expression_habits(chat_talking_prompt_half, target),
|
||||||
self.build_relation_info(sender, target),
|
self.build_relation_info(sender, target),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -169,6 +169,8 @@ class Messages(BaseModel):
|
|||||||
is_picid = BooleanField(default=False)
|
is_picid = BooleanField(default=False)
|
||||||
is_command = BooleanField(default=False)
|
is_command = BooleanField(default=False)
|
||||||
is_notify = BooleanField(default=False)
|
is_notify = BooleanField(default=False)
|
||||||
|
|
||||||
|
selected_expressions = TextField(null=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
# database = db # 继承自 BaseModel
|
# database = db # 继承自 BaseModel
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class PromptBuilder:
|
|||||||
|
|
||||||
# 使用从处理器传来的选中表达方式
|
# 使用从处理器传来的选中表达方式
|
||||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||||
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
selected_expressions ,_ = await expression_selector.select_suitable_expressions_llm(
|
||||||
chat_stream.stream_id, chat_history, max_num=12, target_message=target
|
chat_stream.stream_id, chat_history, max_num=12, target_message=target
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ class ChatMood:
|
|||||||
|
|
||||||
class MoodRegressionTask(AsyncTask):
|
class MoodRegressionTask(AsyncTask):
|
||||||
def __init__(self, mood_manager: "MoodManager"):
|
def __init__(self, mood_manager: "MoodManager"):
|
||||||
super().__init__(task_name="MoodRegressionTask", run_interval=30)
|
super().__init__(task_name="MoodRegressionTask", run_interval=45)
|
||||||
self.mood_manager = mood_manager
|
self.mood_manager = mood_manager
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
@@ -202,8 +202,8 @@ class MoodRegressionTask(AsyncTask):
|
|||||||
if mood.last_change_time == 0:
|
if mood.last_change_time == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if now - mood.last_change_time > 180:
|
if now - mood.last_change_time > 200:
|
||||||
if mood.regression_count >= 3:
|
if mood.regression_count >= 2:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.debug(f"{mood.log_prefix} 开始情绪回归, 第 {mood.regression_count + 1} 次")
|
logger.debug(f"{mood.log_prefix} 开始情绪回归, 第 {mood.regression_count + 1} 次")
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def init_prompt():
|
|||||||
}}
|
}}
|
||||||
]
|
]
|
||||||
|
|
||||||
如果没有,就只输出空数组:[]
|
如果没有,就只输出空json:{{}}
|
||||||
""",
|
""",
|
||||||
"relation_points",
|
"relation_points",
|
||||||
)
|
)
|
||||||
@@ -77,7 +77,7 @@ def init_prompt():
|
|||||||
"attitude": 0,
|
"attitude": 0,
|
||||||
"confidence": 0.5
|
"confidence": 0.5
|
||||||
}}
|
}}
|
||||||
如果无法看出对方对你的态度,就只输出空数组:[]
|
如果无法看出对方对你的态度,就只输出空数组:{{}}
|
||||||
|
|
||||||
现在,请你输出:
|
现在,请你输出:
|
||||||
""",
|
""",
|
||||||
@@ -111,7 +111,7 @@ def init_prompt():
|
|||||||
"neuroticism": 0,
|
"neuroticism": 0,
|
||||||
"confidence": 0.5
|
"confidence": 0.5
|
||||||
}}
|
}}
|
||||||
如果无法看出对方的神经质程度,就只输出空数组:[]
|
如果无法看出对方的神经质程度,就只输出空数组:{{}}
|
||||||
|
|
||||||
现在,请你输出:
|
现在,请你输出:
|
||||||
""",
|
""",
|
||||||
@@ -163,7 +163,7 @@ class RelationshipManager:
|
|||||||
points_data = json.loads(points)
|
points_data = json.loads(points)
|
||||||
|
|
||||||
# 只处理正确的格式,错误格式直接跳过
|
# 只处理正确的格式,错误格式直接跳过
|
||||||
if points_data == "none" or not points_data or (isinstance(points_data, str) and points_data.lower() == "none") or (isinstance(points_data, list) and len(points_data) == 0):
|
if not points_data or (isinstance(points_data, list) and len(points_data) == 0):
|
||||||
points_list = []
|
points_list = []
|
||||||
elif isinstance(points_data, list):
|
elif isinstance(points_data, list):
|
||||||
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
|
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
|
||||||
@@ -263,7 +263,7 @@ class RelationshipManager:
|
|||||||
attitude = repair_json(attitude)
|
attitude = repair_json(attitude)
|
||||||
attitude_data = json.loads(attitude)
|
attitude_data = json.loads(attitude)
|
||||||
|
|
||||||
if attitude_data == "none" or not attitude_data or (isinstance(attitude_data, str) and attitude_data.lower() == "none") or (isinstance(attitude_data, list) and len(attitude_data) == 0):
|
if not attitude_data or (isinstance(attitude_data, list) and len(attitude_data) == 0):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 确保 attitude_data 是字典格式
|
# 确保 attitude_data 是字典格式
|
||||||
@@ -309,7 +309,7 @@ class RelationshipManager:
|
|||||||
neuroticism = repair_json(neuroticism)
|
neuroticism = repair_json(neuroticism)
|
||||||
neuroticism_data = json.loads(neuroticism)
|
neuroticism_data = json.loads(neuroticism)
|
||||||
|
|
||||||
if neuroticism_data == "none" or not neuroticism_data or (isinstance(neuroticism_data, str) and neuroticism_data.lower() == "none") or (isinstance(neuroticism_data, list) and len(neuroticism_data) == 0):
|
if not neuroticism_data or (isinstance(neuroticism_data, list) and len(neuroticism_data) == 0):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 确保 neuroticism_data 是字典格式
|
# 确保 neuroticism_data 是字典格式
|
||||||
|
|||||||
@@ -84,7 +84,8 @@ async def generate_reply(
|
|||||||
return_prompt: bool = False,
|
return_prompt: bool = False,
|
||||||
request_type: str = "generator_api",
|
request_type: str = "generator_api",
|
||||||
from_plugin: bool = True,
|
from_plugin: bool = True,
|
||||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
return_expressions: bool = False,
|
||||||
|
) -> Tuple[bool, List[Tuple[str, Any]], Optional[Tuple[str, List[Dict[str, Any]]]]]:
|
||||||
"""生成回复
|
"""生成回复
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -123,7 +124,7 @@ async def generate_reply(
|
|||||||
reply_reason = action_data.get("reason", "")
|
reply_reason = action_data.get("reason", "")
|
||||||
|
|
||||||
# 调用回复器生成回复
|
# 调用回复器生成回复
|
||||||
success, llm_response_dict, prompt = await replyer.generate_reply_with_context(
|
success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
choosen_actions=choosen_actions,
|
choosen_actions=choosen_actions,
|
||||||
@@ -144,10 +145,16 @@ async def generate_reply(
|
|||||||
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
|
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
|
||||||
|
|
||||||
if return_prompt:
|
if return_prompt:
|
||||||
return success, reply_set, prompt
|
if return_expressions:
|
||||||
|
return success, reply_set, (prompt, selected_expressions)
|
||||||
|
else:
|
||||||
|
return success, reply_set, prompt
|
||||||
else:
|
else:
|
||||||
return success, reply_set, None
|
if return_expressions:
|
||||||
|
return success, reply_set, (None, selected_expressions)
|
||||||
|
else:
|
||||||
|
return success, reply_set, None
|
||||||
|
|
||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
raise ve
|
raise ve
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Union, Dict, Any
|
from typing import Optional, Union, Dict, Any, List
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
# 导入依赖
|
# 导入依赖
|
||||||
@@ -49,6 +49,7 @@ async def _send_to_target(
|
|||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[Dict[str, Any]] = None,
|
||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
show_log: bool = True,
|
show_log: bool = True,
|
||||||
|
selected_expressions:List[int] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定目标发送消息的内部实现
|
"""向指定目标发送消息的内部实现
|
||||||
|
|
||||||
@@ -121,6 +122,7 @@ async def _send_to_target(
|
|||||||
is_emoji=(message_type == "emoji"),
|
is_emoji=(message_type == "emoji"),
|
||||||
thinking_start_time=current_time,
|
thinking_start_time=current_time,
|
||||||
reply_to=reply_to_platform_id,
|
reply_to=reply_to_platform_id,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送消息
|
# 发送消息
|
||||||
@@ -208,6 +210,7 @@ async def text_to_stream(
|
|||||||
set_reply: bool = False,
|
set_reply: bool = False,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[Dict[str, Any]] = None,
|
||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
|
selected_expressions:List[int] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送文本消息
|
"""向指定流发送文本消息
|
||||||
|
|
||||||
@@ -230,6 +233,7 @@ async def text_to_stream(
|
|||||||
set_reply=set_reply,
|
set_reply=set_reply,
|
||||||
reply_message=reply_message,
|
reply_message=reply_message,
|
||||||
storage_message=storage_message,
|
storage_message=storage_message,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user