Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -77,7 +77,6 @@ class DefaultExpressor:
# TODO: API-Adapter修改标记
self.express_model = LLMRequest(
model=global_config.model.replyer_1,
max_tokens=256,
request_type="focus.expressor",
)
self.heart_fc_sender = HeartFCSender()

@@ -29,13 +29,13 @@ def init_prompt() -> None:
4. 思考有没有特殊的梗,一并总结成语言风格
5. 例子仅供参考,请严格根据群聊内容总结!!!
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
当"xxx"时,可以"xxx", xxx不超过10个字
当"xxxxxx"时,可以"xxxxxx", xxxxxx不超过20个字

例如:
当"表示十分惊叹,有些意外"时,使用"我嘞个xxxx"
当"对某件事表示十分惊叹,有些意外"时,使用"我嘞个xxxx"
当"表示讽刺的赞同,不想讲道理"时,使用"对对对"
当"想说明某个观点,但懒得明说,或者不便明说",使用"懂的都懂"
当"表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"

注意不要总结你自己(SELF)的发言
现在请你概括
@@ -70,7 +70,6 @@ class ExpressionLearner:
self.express_learn_model: LLMRequest = LLMRequest(
model=global_config.model.replyer_1,
temperature=0.1,
max_tokens=256,
request_type="expressor.learner",
)
@@ -280,6 +279,39 @@ class ExpressionLearner:
new_expr["last_active_time"] = current_time
old_data.append(new_expr)

# 处理超限问题
if len(old_data) > MAX_EXPRESSION_COUNT:
# 计算每个表达方式的权重(count的倒数,这样count越小的越容易被选中)
weights = [1 / (expr.get("count", 1) + 0.1) for expr in old_data]
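# Illustrative arithmetic (not part of the original commit): with this formula an expression
# seen once gets weight 1/1.1 ≈ 0.91 while one seen 10 times gets 1/10.1 ≈ 0.10, so
# rarely-used expressions are roughly nine times more likely to be picked for removal.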

# 随机选择要移除的表达方式,避免重复索引
remove_count = len(old_data) - MAX_EXPRESSION_COUNT

# 使用一种不会选到重复索引的方法
indices = list(range(len(old_data)))

# 备选方案:numpy.random.choice(replace=False) 也能做到;这里用 set 配合 random.choices 按权重抽样,保证不会选到重复索引
remove_set = set()
total_attempts = 0

# 尝试按权重随机选择,直到选够数量
while len(remove_set) < remove_count and total_attempts < len(old_data) * 2:
idx = random.choices(indices, weights=weights, k=1)[0]
remove_set.add(idx)
total_attempts += 1

# 如果没选够,随机补充
if len(remove_set) < remove_count:
remaining = set(indices) - remove_set
remove_set.update(random.sample(list(remaining), remove_count - len(remove_set)))

remove_indices = list(remove_set)

# 从后往前删除,避免索引变化
for idx in sorted(remove_indices, reverse=True):
old_data.pop(idx)

with open(file_path, "w", encoding="utf-8") as f:
json.dump(old_data, f, ensure_ascii=False, indent=2)
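A minimal sketch of the numpy alternative hinted at in the comment above, assuming numpy were importable in this module (it is not imported in the hunk shown): numpy.random.choice with replace=False draws all indices to remove in one call, using probabilities normalized from the same inverse-count weights.

import numpy as np

weights = np.array([1 / (expr.get("count", 1) + 0.1) for expr in old_data])
probabilities = weights / weights.sum()  # np.random.choice 要求 p 归一化为概率
remove_count = len(old_data) - MAX_EXPRESSION_COUNT
remove_indices = np.random.choice(len(old_data), size=remove_count, replace=False, p=probabilities)
for idx in sorted(remove_indices.tolist(), reverse=True):
    old_data.pop(idx)  # 从后往前删除,避免索引变化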
@@ -96,13 +96,14 @@ class CycleDetail:
|
||||
or "group"
|
||||
)
|
||||
|
||||
current_time_minute = time.strftime("%Y%m%d_%H%M", time.localtime())
|
||||
try:
|
||||
self.log_cycle_to_file(
|
||||
log_dir + self.prefix + f"/{current_time_minute}_cycle_" + str(self.cycle_id) + ".json"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"写入文件日志,可能是群名称包含非法字符: {e}")
|
||||
# current_time_minute = time.strftime("%Y%m%d_%H%M", time.localtime())
|
||||
|
||||
# try:
|
||||
# self.log_cycle_to_file(
|
||||
# log_dir + self.prefix + f"/{current_time_minute}_cycle_" + str(self.cycle_id) + ".json"
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.warning(f"写入文件日志,可能是群名称包含非法字符: {e}")
|
||||
|
||||
def log_cycle_to_file(self, file_path: str):
|
||||
"""将循环信息写入文件"""
|
||||
|
||||
@@ -441,31 +441,33 @@ class HeartFChatting:
|
||||
"observations": self.observations,
|
||||
}
|
||||
|
||||
with Timer("调整动作", cycle_timers):
|
||||
# 处理特殊的观察
|
||||
await self.action_modifier.modify_actions(observations=self.observations)
|
||||
await self.action_observation.observe()
|
||||
self.observations.append(self.action_observation)
|
||||
# 根据配置决定是否并行执行调整动作、回忆和处理器阶段
|
||||
|
||||
# 根据配置决定是否并行执行回忆和处理器阶段
|
||||
# print(global_config.focus_chat.parallel_processing)
|
||||
if global_config.focus_chat.parallel_processing:
|
||||
# 并行执行回忆和处理器阶段
|
||||
with Timer("并行回忆和处理", cycle_timers):
|
||||
memory_task = asyncio.create_task(self.memory_activator.activate_memory(self.observations))
|
||||
processor_task = asyncio.create_task(self._process_processors(self.observations, []))
|
||||
|
||||
# 等待两个任务完成
|
||||
running_memorys, (all_plan_info, processor_time_costs) = await asyncio.gather(
|
||||
memory_task, processor_task
|
||||
# 并行执行调整动作、回忆和处理器阶段
|
||||
with Timer("并行调整动作、处理", cycle_timers):
|
||||
# 创建并行任务
|
||||
async def modify_actions_task():
|
||||
# 调用完整的动作修改流程
|
||||
await self.action_modifier.modify_actions(
|
||||
observations=self.observations,
|
||||
)
|
||||
else:
|
||||
# 串行执行
|
||||
with Timer("回忆", cycle_timers):
|
||||
running_memorys = await self.memory_activator.activate_memory(self.observations)
|
||||
|
||||
await self.action_observation.observe()
|
||||
self.observations.append(self.action_observation)
|
||||
return True
|
||||
|
||||
# 创建三个并行任务
|
||||
action_modify_task = asyncio.create_task(modify_actions_task())
|
||||
memory_task = asyncio.create_task(self.memory_activator.activate_memory(self.observations))
|
||||
processor_task = asyncio.create_task(self._process_processors(self.observations, []))
|
||||
|
||||
# 等待三个任务完成
|
||||
_, running_memorys, (all_plan_info, processor_time_costs) = await asyncio.gather(
|
||||
action_modify_task, memory_task, processor_task
|
||||
)
|
||||
|
||||
|
||||
|
||||
with Timer("执行 信息处理器", cycle_timers):
|
||||
all_plan_info, processor_time_costs = await self._process_processors(self.observations, running_memorys)
|
||||
|
||||
loop_processor_info = {
|
||||
"all_plan_info": all_plan_info,
|
||||
|
||||
@@ -106,7 +106,8 @@ class HeartFCSender:
|
||||
and not message.is_private_message()
|
||||
and message.reply.processed_plain_text != "[System Trigger Context]"
|
||||
):
|
||||
message.set_reply(message.reply)
|
||||
# message.set_reply(message.reply)
|
||||
message.set_reply()
|
||||
logger.debug(f"[{chat_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}...")
|
||||
|
||||
await message.process()
|
||||
|
||||
@@ -31,7 +31,6 @@ class ChattingInfoProcessor(BaseProcessor):
|
||||
self.model_summary = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
temperature=0.7,
|
||||
max_tokens=300,
|
||||
request_type="focus.observation.chat",
|
||||
)
|
||||
|
||||
@@ -64,7 +63,7 @@ class ChattingInfoProcessor(BaseProcessor):
|
||||
obs_info = ObsInfo()
|
||||
|
||||
# 改为异步任务,不阻塞主流程
|
||||
asyncio.create_task(self.chat_compress(obs))
|
||||
# asyncio.create_task(self.chat_compress(obs))
|
||||
|
||||
# 设置说话消息
|
||||
if hasattr(obs, "talking_message_str"):
|
||||
|
||||
@@ -69,7 +69,6 @@ class MindProcessor(BaseProcessor):
|
||||
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.planner,
|
||||
max_tokens=800,
|
||||
request_type="focus.processor.chat_mind",
|
||||
)
|
||||
|
||||
|
||||
@@ -13,32 +13,71 @@ from typing import List, Optional
|
||||
from typing import Dict
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from src.chat.focus_chat.info.relation_info import RelationInfo
|
||||
from json_repair import repair_json
|
||||
from src.person_info.person_info import person_info_manager
|
||||
import json
|
||||
import asyncio
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
||||
|
||||
logger = get_logger("processor")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
relationship_prompt = """
|
||||
{name_block}
|
||||
|
||||
你和别人的关系信息是,请从这些信息中提取出你和别人的关系的原文:
|
||||
{relation_prompt}
|
||||
请只从上面这些信息中提取出。
|
||||
|
||||
|
||||
现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容:
|
||||
<聊天记录>
|
||||
{chat_observe_info}
|
||||
</聊天记录>
|
||||
|
||||
现在请你根据现有的信息,总结你和群里的人的关系
|
||||
1. 根据聊天记录的需要,精简你和其他人的关系并输出
|
||||
2. 根据聊天记录,如果需要提及你和某个人的关系,请输出你和这个人之间的关系
|
||||
3. 如果没有特别需要提及的关系,就不用输出这个人的关系
|
||||
<调取记录>
|
||||
{info_cache_block}
|
||||
</调取记录>
|
||||
|
||||
输出内容平淡一些,说中文。
|
||||
请注意不要输出多余内容(包括前后缀,括号(),表情包,at或 @等 )。只输出关系内容,记得明确说明这是你的关系。
|
||||
{name_block}
|
||||
请你阅读聊天记录,查看是否需要调取某个人的信息,这个人可以是出现在聊天记录中的,也可以是记录中提到的人。
|
||||
你不同程度上认识群聊里的人,以及他们谈论到的人,你可以根据聊天记录,回忆起有关他们的信息,帮助你参与聊天
|
||||
1.你需要提供用户名,以及你想要提取的信息名称类型来进行调取
|
||||
2.你也可以完全不输出任何信息
|
||||
3.阅读调取记录,如果已经回忆过某个人的信息,请不要重复调取,除非你忘记了
|
||||
|
||||
请以json格式输出,例如:
|
||||
|
||||
{{
|
||||
"用户A": "昵称",
|
||||
"用户A": "性别",
|
||||
"用户B": "对你的态度",
|
||||
"用户C": "你和ta最近做的事",
|
||||
"用户D": "你对ta的印象",
|
||||
}}
|
||||
|
||||
|
||||
请严格按照以下输出格式,不要输出多余内容,person_name可以有多个:
|
||||
{{
|
||||
"person_name": "信息名称",
|
||||
"person_name": "信息名称",
|
||||
}}
|
||||
|
||||
"""
|
||||
Prompt(relationship_prompt, "relationship_prompt")
|
||||
|
||||
fetch_info_prompt = """
|
||||
|
||||
{name_block}
|
||||
以下是你对{person_name}的了解,请你从中提取用户的有关"{info_type}"的信息,如果用户没有相关信息,请输出none:
|
||||
<对{person_name}的总体了解>
|
||||
{person_impression}
|
||||
</对{person_name}的总体了解>
|
||||
|
||||
<你记得{person_name}最近的事>
|
||||
{points_text}
|
||||
</你记得{person_name}最近的事>
|
||||
|
||||
请严格按照以下json输出格式,不要输出多余内容:
|
||||
{{
|
||||
{info_json_str}
|
||||
}}
|
||||
"""
|
||||
Prompt(fetch_info_prompt, "fetch_info_prompt")
|
||||
|
||||
|
||||
|
||||
class RelationshipProcessor(BaseProcessor):
|
||||
@@ -48,11 +87,14 @@ class RelationshipProcessor(BaseProcessor):
|
||||
super().__init__()
|
||||
|
||||
self.subheartflow_id = subheartflow_id
|
||||
self.info_fetching_cache: List[Dict[str, any]] = []
|
||||
self.info_fetched_cache: Dict[str, Dict[str, any]] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}}
|
||||
self.person_engaged_cache: List[Dict[str, any]] = [] # [{person_id: str, start_time: float, rounds: int}]
|
||||
self.grace_period_rounds = 5
|
||||
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.relation,
|
||||
max_tokens=800,
|
||||
request_type="relation",
|
||||
request_type="focus.relationship",
|
||||
)
|
||||
|
||||
name = chat_manager.get_stream_name(self.subheartflow_id)
|
||||
@@ -81,84 +123,288 @@ class RelationshipProcessor(BaseProcessor):
|
||||
return [relation_info]
|
||||
|
||||
async def relation_identify(
|
||||
self, observations: Optional[List[Observation]] = None,
|
||||
self,
|
||||
observations: Optional[List[Observation]] = None,
|
||||
):
|
||||
"""
|
||||
在回复前进行思考,生成内心想法并收集工具调用结果
|
||||
|
||||
参数:
|
||||
observations: 观察信息
|
||||
|
||||
返回:
|
||||
如果return_prompt为False:
|
||||
tuple: (current_mind, past_mind) 当前想法和过去的想法列表
|
||||
如果return_prompt为True:
|
||||
tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt
|
||||
"""
|
||||
# 0. 从观察信息中提取所需数据
|
||||
# 需要兼容私聊
|
||||
|
||||
if observations is None:
|
||||
observations = []
|
||||
for observation in observations:
|
||||
if isinstance(observation, ChattingObservation):
|
||||
# 获取聊天元信息
|
||||
is_group_chat = observation.is_group_chat
|
||||
chat_target_info = observation.chat_target_info
|
||||
chat_target_name = "对方" # 私聊默认名称
|
||||
if not is_group_chat and chat_target_info:
|
||||
# 优先使用person_name,其次user_nickname,最后回退到默认值
|
||||
chat_target_name = (
|
||||
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or chat_target_name
|
||||
)
|
||||
# 获取聊天内容
|
||||
chat_observe_info = observation.get_observe_info()
|
||||
person_list = observation.person_list
|
||||
chat_observe_info = ""
|
||||
current_time = time.time()
|
||||
if observations:
|
||||
for observation in observations:
|
||||
if isinstance(observation, ChattingObservation):
|
||||
chat_observe_info = observation.get_observe_info()
|
||||
break
|
||||
|
||||
nickname_str = ""
|
||||
for nicknames in global_config.bot.alias_names:
|
||||
nickname_str += f"{nicknames},"
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
if is_group_chat:
|
||||
relation_prompt_init = "你对群聊里的人的印象是:\n"
|
||||
else:
|
||||
relation_prompt_init = "你对对方的印象是:\n"
|
||||
|
||||
relation_prompt = ""
|
||||
for person in person_list:
|
||||
relation_prompt += f"{await relationship_manager.build_relationship_info(person, is_id=True)}\n"
|
||||
# 1. 处理person_engaged_cache
|
||||
for record in list(self.person_engaged_cache):
|
||||
record["rounds"] += 1
|
||||
time_elapsed = current_time - record["start_time"]
|
||||
message_count = len(get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time))
|
||||
|
||||
if relation_prompt:
|
||||
relation_prompt = relation_prompt_init + relation_prompt
|
||||
else:
|
||||
relation_prompt = relation_prompt_init + "没有特别在意的人\n"
|
||||
if (record["rounds"] > 50 or
|
||||
time_elapsed > 1800 or # 30分钟
|
||||
message_count > 75):
|
||||
logger.info(f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。")
|
||||
asyncio.create_task(
|
||||
self.update_impression_on_cache_expiry(
|
||||
record["person_id"],
|
||||
self.subheartflow_id,
|
||||
record["start_time"],
|
||||
current_time
|
||||
)
|
||||
)
|
||||
self.person_engaged_cache.remove(record)
|
||||
|
||||
# 2. 减少info_fetched_cache中所有信息的TTL
|
||||
for person_id in list(self.info_fetched_cache.keys()):
|
||||
for info_type in list(self.info_fetched_cache[person_id].keys()):
|
||||
self.info_fetched_cache[person_id][info_type]["ttl"] -= 1
|
||||
if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0:
|
||||
# 在删除前查找匹配的info_fetching_cache记录
|
||||
matched_record = None
|
||||
min_time_diff = float('inf')
|
||||
for record in self.info_fetching_cache:
|
||||
if (record["person_id"] == person_id and
|
||||
record["info_type"] == info_type and
|
||||
not record["forget"]):
|
||||
time_diff = abs(record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"])
|
||||
if time_diff < min_time_diff:
|
||||
min_time_diff = time_diff
|
||||
matched_record = record
|
||||
|
||||
if matched_record:
|
||||
matched_record["forget"] = True
|
||||
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已过期,标记为遗忘。")
|
||||
|
||||
del self.info_fetched_cache[person_id][info_type]
|
||||
if not self.info_fetched_cache[person_id]:
|
||||
del self.info_fetched_cache[person_id]
|
||||
|
||||
# 5. 为需要处理的人员准备LLM prompt
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
info_cache_block = ""
|
||||
if self.info_fetching_cache:
|
||||
for info_fetching in self.info_fetching_cache:
|
||||
if info_fetching["forget"]:
|
||||
info_cache_block += f"在{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info_fetching['start_time']))},你回忆了[{info_fetching['person_name']}]的[{info_fetching['info_type']}],但是现在你忘记了\n"
|
||||
else:
|
||||
info_cache_block += f"在{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info_fetching['start_time']))},你回忆了[{info_fetching['person_name']}]的[{info_fetching['info_type']}],还记着呢\n"
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("relationship_prompt")).format(
|
||||
name_block=name_block,
|
||||
relation_prompt=relation_prompt,
|
||||
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
||||
chat_observe_info=chat_observe_info,
|
||||
info_cache_block=info_cache_block,
|
||||
)
|
||||
|
||||
# print(prompt)
|
||||
|
||||
content = ""
|
||||
|
||||
try:
|
||||
logger.info(f"{self.log_prefix} 人物信息prompt: \n{prompt}\n")
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
if not content:
|
||||
if content:
|
||||
print(f"content: {content}")
|
||||
content_json = json.loads(repair_json(content))
|
||||
|
||||
for person_name, info_type in content_json.items():
|
||||
person_id = person_info_manager.get_person_id_by_person_name(person_name)
|
||||
if person_id:
|
||||
self.info_fetching_cache.append({
|
||||
"person_id": person_id,
|
||||
"person_name": person_name,
|
||||
"info_type": info_type,
|
||||
"start_time": time.time(),
|
||||
"forget": False,
|
||||
})
|
||||
if len(self.info_fetching_cache) > 20:
|
||||
self.info_fetching_cache.pop(0)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 未找到用户 {person_name} 的ID,跳过调取信息。")
|
||||
|
||||
logger.info(f"{self.log_prefix} 调取用户 {person_name} 的 {info_type} 信息。")
|
||||
|
||||
self.person_engaged_cache.append({
|
||||
"person_id": person_id,
|
||||
"start_time": time.time(),
|
||||
"rounds": 0
|
||||
})
|
||||
asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time()))
|
||||
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} LLM返回空结果,关系识别失败。")
|
||||
|
||||
except Exception as e:
|
||||
# 处理总体异常
|
||||
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
content = "关系识别过程中出现错误"
|
||||
|
||||
if content == "None":
|
||||
content = ""
|
||||
# 记录初步思考结果
|
||||
logger.info(f"{self.log_prefix} 关系识别prompt: \n{prompt}\n")
|
||||
logger.info(f"{self.log_prefix} 关系识别: {content}")
|
||||
# 7. 合并缓存和新处理的信息
|
||||
persons_infos_str = ""
|
||||
# 处理已获取到的信息
|
||||
if self.info_fetched_cache:
|
||||
for person_id in self.info_fetched_cache:
|
||||
person_infos_str = ""
|
||||
for info_type in self.info_fetched_cache[person_id]:
|
||||
person_name = self.info_fetched_cache[person_id][info_type]["person_name"]
|
||||
if not self.info_fetched_cache[person_id][info_type]["unknow"]:
|
||||
info_content = self.info_fetched_cache[person_id][info_type]["info"]
|
||||
person_infos_str += f"[{info_type}]:{info_content};"
|
||||
else:
|
||||
person_infos_str += f"你不了解{person_name}有关[{info_type}]的信息,不要胡乱回答;"
|
||||
if person_infos_str:
|
||||
persons_infos_str += f"你对 {person_name} 的了解:{person_infos_str}\n"
|
||||
|
||||
# 处理正在调取但还没有结果的项目
|
||||
pending_info_dict = {}
|
||||
for record in self.info_fetching_cache:
|
||||
if not record["forget"]:
|
||||
current_time = time.time()
|
||||
# 只处理不超过2分钟的调取请求,避免过期请求一直显示
|
||||
if current_time - record["start_time"] <= 120:  # 2分钟内的请求
|
||||
person_id = record["person_id"]
|
||||
person_name = record["person_name"]
|
||||
info_type = record["info_type"]
|
||||
|
||||
# 检查是否已经在info_fetched_cache中有结果
|
||||
if (person_id in self.info_fetched_cache and
|
||||
info_type in self.info_fetched_cache[person_id]):
|
||||
continue
|
||||
|
||||
# 按人物组织正在调取的信息
|
||||
if person_name not in pending_info_dict:
|
||||
pending_info_dict[person_name] = []
|
||||
pending_info_dict[person_name].append(info_type)
|
||||
|
||||
# 添加正在调取的信息到返回字符串
|
||||
for person_name, info_types in pending_info_dict.items():
|
||||
info_types_str = "、".join(info_types)
|
||||
persons_infos_str += f"你正在识图回忆有关 {person_name} 的 {info_types_str} 信息,稍等一下再回答...\n"
|
||||
|
||||
return content
|
||||
return persons_infos_str
|
||||
|
||||
async def fetch_person_info(self, person_id: str, info_types: list[str], start_time: float):
|
||||
"""
|
||||
获取某个人的信息
|
||||
"""
|
||||
# 检查缓存中是否已存在且未过期的信息
|
||||
info_types_to_fetch = []
|
||||
|
||||
for info_type in info_types:
|
||||
if (person_id in self.info_fetched_cache and
|
||||
info_type in self.info_fetched_cache[person_id]):
|
||||
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已存在且未过期,跳过调取。")
|
||||
continue
|
||||
info_types_to_fetch.append(info_type)
|
||||
|
||||
if not info_types_to_fetch:
|
||||
return
|
||||
|
||||
nickname_str = ",".join(global_config.bot.alias_names)
|
||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
info_type_str = ""
|
||||
info_json_str = ""
|
||||
for info_type in info_types_to_fetch:
|
||||
info_type_str += f"{info_type},"
|
||||
info_json_str += f"\"{info_type}\": \"信息内容\","
|
||||
info_type_str = info_type_str[:-1]
|
||||
info_json_str = info_json_str[:-1]
|
||||
|
||||
person_impression = await person_info_manager.get_value(person_id, "impression")
|
||||
if not person_impression:
|
||||
impression_block = "你对ta没有什么深刻的印象"
|
||||
else:
|
||||
impression_block = f"{person_impression}"
|
||||
|
||||
|
||||
points = await person_info_manager.get_value(person_id, "points")
|
||||
|
||||
if points:
|
||||
points_text = "\n".join([
|
||||
f"{point[2]}:{point[0]}"
|
||||
for point in points
|
||||
])
|
||||
else:
|
||||
points_text = "你不记得ta最近发生了什么"
|
||||
|
||||
|
||||
prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format(
|
||||
name_block=name_block,
|
||||
info_type=info_type_str,
|
||||
person_impression=impression_block,
|
||||
person_name=person_name,
|
||||
info_json_str=info_json_str,
|
||||
points_text=points_text,
|
||||
)
|
||||
|
||||
try:
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
# logger.info(f"{self.log_prefix} fetch_person_info prompt: \n{prompt}\n")
|
||||
logger.info(f"{self.log_prefix} fetch_person_info 结果: {content}")
|
||||
|
||||
if content:
|
||||
try:
|
||||
content_json = json.loads(repair_json(content))
|
||||
for info_type, info_content in content_json.items():
|
||||
if info_content != "none" and info_content:
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
self.info_fetched_cache[person_id][info_type] = {
|
||||
"info": info_content,
|
||||
"ttl": 10,
|
||||
"start_time": start_time,
|
||||
"person_name": person_name,
|
||||
"unknow": False,
|
||||
}
|
||||
else:
|
||||
if person_id not in self.info_fetched_cache:
|
||||
self.info_fetched_cache[person_id] = {}
|
||||
|
||||
self.info_fetched_cache[person_id][info_type] = {
|
||||
"info":"unknow",
|
||||
"ttl": 10,
|
||||
"start_time": start_time,
|
||||
"person_name": person_name,
|
||||
"unknow": True,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 解析LLM返回的信息时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} LLM返回空结果,获取用户 {person_name} 的 {info_type_str} 信息失败。")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行LLM请求获取用户信息时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def update_impression_on_cache_expiry(
|
||||
self, person_id: str, chat_id: str, start_time: float, end_time: float
|
||||
):
|
||||
"""
|
||||
在缓存过期时,获取聊天记录并更新用户印象
|
||||
"""
|
||||
logger.info(f"缓存过期,开始为 {person_id} 更新印象。时间范围:{start_time} -> {end_time}")
|
||||
try:
|
||||
|
||||
|
||||
impression_messages = get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time)
|
||||
if impression_messages:
|
||||
logger.info(f"为 {person_id} 获取到 {len(impression_messages)} 条消息用于印象更新。")
|
||||
await relationship_manager.update_person_impression(
|
||||
person_id=person_id, timestamp=end_time, bot_engaged_messages=impression_messages
|
||||
)
|
||||
else:
|
||||
logger.info(f"在指定时间范围内没有找到 {person_id} 的消息,不更新印象。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为 {person_id} 更新印象时发生错误: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -20,22 +20,27 @@ logger = get_logger("processor")
|
||||
|
||||
def init_prompt():
|
||||
indentify_prompt = """
|
||||
<聊天记录>
|
||||
{chat_observe_info}
|
||||
</聊天记录>
|
||||
|
||||
<人格>
|
||||
{name_block}
|
||||
你是一个AI,但是你伪装成了一个人类,不要让别人发现这一点
|
||||
请参考以下人格,不要被当前聊天内容中的内容误导:
|
||||
你是一个AI,但是你伪装成了一个人类,不要让别人发现这一点,你的性格是:
|
||||
{prompt_personality}
|
||||
{indentify_block}
|
||||
</人格>
|
||||
|
||||
以下是正在进行的聊天内容:
|
||||
现在是{time_now},你正在参与聊天
|
||||
{chat_observe_info}
|
||||
请区分聊天记录的内容和你稳定的人格,聊天记录是现在发生的事情,人格是你稳定的独特的特质。
|
||||
|
||||
现在请你输出对自己的描述:请严格遵守以下规则
|
||||
{name_block}
|
||||
现在请你提取你人格的关键信息,提取成一串文本:
|
||||
1. 根据聊天记录,输出与聊天记录相关的自我描述,包括人格,形象等等,对人格形象进行精简
|
||||
2. 思考有没有内容与你的描述相关
|
||||
3. 如果没有明显相关内容,请输出十几个字的简短自我描述
|
||||
|
||||
现在请输出你的自我描述,请注意不要输出多余内容(包括前后缀,括号(),表情包,at或 @等 ):
|
||||
现在请输出你的自我描述,格式是:“你是.....,你.................(描述)”
|
||||
请注意不要输出多余内容(包括前后缀,括号(),表情包,at或 @等 ):
|
||||
|
||||
"""
|
||||
Prompt(indentify_prompt, "indentify_prompt")
|
||||
@@ -51,7 +56,6 @@ class SelfProcessor(BaseProcessor):
|
||||
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.relation,
|
||||
max_tokens=800,
|
||||
request_type="focus.processor.self_identify",
|
||||
)
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ class ToolProcessor(BaseProcessor):
|
||||
self.log_prefix = f"[{subheartflow_id}:ToolExecutor] "
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.focus_tool_use,
|
||||
max_tokens=500,
|
||||
request_type="focus.processor.tool",
|
||||
)
|
||||
self.structured_info = []
|
||||
|
||||
@@ -61,7 +61,6 @@ class WorkingMemoryProcessor(BaseProcessor):
|
||||
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.planner,
|
||||
max_tokens=800,
|
||||
request_type="focus.processor.working_memory",
|
||||
)
|
||||
|
||||
|
||||
@@ -72,7 +72,6 @@ class MemoryActivator:
|
||||
self.summary_model = LLMRequest(
|
||||
model=global_config.model.memory_summary,
|
||||
temperature=0.7,
|
||||
max_tokens=50,
|
||||
request_type="focus.memory_activator",
|
||||
)
|
||||
self.running_memory = []
|
||||
|
||||
@@ -41,6 +41,9 @@ class ActionManager:
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = self._default_actions.copy()
|
||||
|
||||
# 添加系统核心动作
|
||||
self._add_system_core_actions()
|
||||
|
||||
def _load_registered_actions(self) -> None:
|
||||
"""
|
||||
@@ -59,7 +62,22 @@ class ActionManager:
|
||||
action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {})
|
||||
action_require: list[str] = getattr(action_class, "action_require", [])
|
||||
associated_types: list[str] = getattr(action_class, "associated_types", [])
|
||||
is_default: bool = getattr(action_class, "default", False)
|
||||
is_enabled: bool = getattr(action_class, "enable_plugin", True)
|
||||
|
||||
# 获取激活类型相关属性
|
||||
focus_activation_type: str = getattr(action_class, "focus_activation_type", "always")
|
||||
normal_activation_type: str = getattr(action_class, "normal_activation_type", "always")
|
||||
|
||||
random_probability: float = getattr(action_class, "random_activation_probability", 0.3)
|
||||
llm_judge_prompt: str = getattr(action_class, "llm_judge_prompt", "")
|
||||
activation_keywords: list[str] = getattr(action_class, "activation_keywords", [])
|
||||
keyword_case_sensitive: bool = getattr(action_class, "keyword_case_sensitive", False)
|
||||
|
||||
# 获取模式启用属性
|
||||
mode_enable: str = getattr(action_class, "mode_enable", "all")
|
||||
|
||||
# 获取并行执行属性
|
||||
parallel_action: bool = getattr(action_class, "parallel_action", False)
|
||||
|
||||
if action_name and action_description:
|
||||
# 创建动作信息字典
|
||||
@@ -68,13 +86,21 @@ class ActionManager:
|
||||
"parameters": action_parameters,
|
||||
"require": action_require,
|
||||
"associated_types": associated_types,
|
||||
"focus_activation_type": focus_activation_type,
|
||||
"normal_activation_type": normal_activation_type,
|
||||
"random_probability": random_probability,
|
||||
"llm_judge_prompt": llm_judge_prompt,
|
||||
"activation_keywords": activation_keywords,
|
||||
"keyword_case_sensitive": keyword_case_sensitive,
|
||||
"mode_enable": mode_enable,
|
||||
"parallel_action": parallel_action,
|
||||
}
|
||||
|
||||
# 添加到所有已注册的动作
|
||||
self._registered_actions[action_name] = action_info
|
||||
|
||||
# 添加到默认动作(如果是默认动作)
|
||||
if is_default:
|
||||
# 添加到默认动作(如果启用插件)
|
||||
if is_enabled:
|
||||
self._default_actions[action_name] = action_info
|
||||
|
||||
# logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
|
||||
@@ -200,9 +226,34 @@ class ActionManager:
|
||||
return self._default_actions.copy()
|
||||
|
||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取当前正在使用的动作集"""
|
||||
"""获取当前正在使用的动作集合"""
|
||||
return self._using_actions.copy()
|
||||
|
||||
def get_using_actions_for_mode(self, mode: str) -> Dict[str, ActionInfo]:
|
||||
"""
|
||||
根据聊天模式获取可用的动作集合
|
||||
|
||||
Args:
|
||||
mode: 聊天模式 ("focus", "normal", "all")
|
||||
|
||||
Returns:
|
||||
Dict[str, ActionInfo]: 在指定模式下可用的动作集合
|
||||
"""
|
||||
filtered_actions = {}
|
||||
|
||||
for action_name, action_info in self._using_actions.items():
|
||||
action_mode = action_info.get("mode_enable", "all")
|
||||
|
||||
# 检查动作是否在当前模式下启用
|
||||
if action_mode == "all" or action_mode == mode:
|
||||
filtered_actions[action_name] = action_info
|
||||
logger.debug(f"动作 {action_name} 在模式 {mode} 下可用 (mode_enable: {action_mode})")
|
||||
else:
|
||||
logger.debug(f"动作 {action_name} 在模式 {mode} 下不可用 (mode_enable: {action_mode})")
|
||||
|
||||
logger.info(f"模式 {mode} 下可用动作: {list(filtered_actions.keys())}")
|
||||
return filtered_actions
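# Illustrative usage (not part of this commit): a caller such as the planner or action modifier
# can request only the actions enabled for the current chat mode. "action_manager" below stands
# for any ActionManager instance; ChatMode is the enum added in base_action in this same commit.
focus_actions = action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
for action_name, action_info in focus_actions.items():
    logger.debug(f"动作 {action_name} 可用,parallel_action={action_info.get('parallel_action', False)}")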
|
||||
|
||||
def add_action_to_using(self, action_name: str) -> bool:
|
||||
"""
|
||||
添加已注册的动作到当前使用的动作集
|
||||
@@ -294,6 +345,36 @@ class ActionManager:
|
||||
def restore_default_actions(self) -> None:
|
||||
"""恢复默认动作集到使用集"""
|
||||
self._using_actions = self._default_actions.copy()
|
||||
# 添加系统核心动作(即使enable_plugin为False的系统动作)
|
||||
self._add_system_core_actions()
|
||||
|
||||
def _add_system_core_actions(self) -> None:
|
||||
"""
|
||||
添加系统核心动作到使用集
|
||||
系统核心动作是那些enable_plugin为False但是系统必需的动作
|
||||
"""
|
||||
system_core_actions = ["exit_focus_chat"] # 可以根据需要扩展
|
||||
|
||||
for action_name in system_core_actions:
|
||||
if action_name in self._registered_actions and action_name not in self._using_actions:
|
||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
logger.info(f"添加系统核心动作到使用集: {action_name}")
|
||||
|
||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
||||
"""
|
||||
根据需要添加系统动作到使用集
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
if action_name in self._registered_actions and action_name not in self._using_actions:
|
||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
logger.info(f"临时添加系统动作到使用集: {action_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_action(self, action_name: str) -> Optional[Type[BaseAction]]:
|
||||
"""
|
||||
|
||||
@@ -8,6 +8,18 @@ logger = get_logger("base_action")
|
||||
_ACTION_REGISTRY: Dict[str, Type["BaseAction"]] = {}
|
||||
_DEFAULT_ACTIONS: Dict[str, str] = {}
|
||||
|
||||
# 动作激活类型枚举
|
||||
class ActionActivationType:
|
||||
ALWAYS = "always" # 默认参与到planner
|
||||
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
|
||||
RANDOM = "random" # 随机启用action到planner
|
||||
KEYWORD = "keyword" # 关键词触发启用action到planner
|
||||
|
||||
# 聊天模式枚举
|
||||
class ChatMode:
|
||||
FOCUS = "focus" # Focus聊天模式
|
||||
NORMAL = "normal" # Normal聊天模式
|
||||
ALL = "all" # 所有聊天模式
|
||||
|
||||
def register_action(cls):
|
||||
"""
|
||||
@@ -18,6 +30,10 @@ def register_action(cls):
|
||||
class MyAction(BaseAction):
|
||||
action_name = "my_action"
|
||||
action_description = "我的动作"
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
normal_activation_type = ActionActivationType.ALWAYS
|
||||
mode_enable = ChatMode.ALL
|
||||
parallel_action = False
|
||||
...
|
||||
"""
|
||||
# 检查类是否有必要的属性
|
||||
@@ -27,7 +43,7 @@ def register_action(cls):
|
||||
|
||||
action_name = cls.action_name
|
||||
action_description = cls.action_description
|
||||
is_default = getattr(cls, "default", False)
|
||||
is_enabled = getattr(cls, "enable_plugin", True) # 默认启用插件
|
||||
|
||||
if not action_name or not action_description:
|
||||
logger.error(f"动作类 {cls.__name__} 的 action_name 或 action_description 为空")
|
||||
@@ -36,11 +52,11 @@ def register_action(cls):
|
||||
# 将动作类注册到全局注册表
|
||||
_ACTION_REGISTRY[action_name] = cls
|
||||
|
||||
# 如果是默认动作,添加到默认动作集
|
||||
if is_default:
|
||||
# 如果启用插件,添加到默认动作集
|
||||
if is_enabled:
|
||||
_DEFAULT_ACTIONS[action_name] = action_description
|
||||
|
||||
logger.info(f"已注册动作: {action_name} -> {cls.__name__},默认: {is_default}")
|
||||
logger.info(f"已注册动作: {action_name} -> {cls.__name__},插件启用: {is_enabled}")
|
||||
return cls
|
||||
|
||||
|
||||
@@ -65,10 +81,33 @@ class BaseAction(ABC):
|
||||
self.action_description: str = "基础动作"
|
||||
self.action_parameters: dict = {}
|
||||
self.action_require: list[str] = []
|
||||
|
||||
# 动作激活类型设置
|
||||
# Focus模式下的激活类型,默认为always
|
||||
self.focus_activation_type: str = ActionActivationType.ALWAYS
|
||||
# Normal模式下的激活类型,默认为always
|
||||
self.normal_activation_type: str = ActionActivationType.ALWAYS
|
||||
|
||||
# 随机激活的概率(0.0-1.0),用于RANDOM激活类型
|
||||
self.random_activation_probability: float = 0.3
|
||||
# LLM判定的提示词,用于LLM_JUDGE激活类型
|
||||
self.llm_judge_prompt: str = ""
|
||||
# 关键词触发列表,用于KEYWORD激活类型
|
||||
self.activation_keywords: list[str] = []
|
||||
# 关键词匹配是否区分大小写
|
||||
self.keyword_case_sensitive: bool = False
|
||||
|
||||
# 模式启用设置:指定在哪些聊天模式下启用此动作
|
||||
# 可选值: "focus"(仅Focus模式), "normal"(仅Normal模式), "all"(所有模式)
|
||||
self.mode_enable: str = ChatMode.ALL
|
||||
|
||||
# 并行执行设置:仅在Normal模式下生效,设置为True的动作可以与回复动作并行执行
|
||||
# 而不是替代回复动作,适用于图片生成、TTS、禁言等不需要覆盖回复的动作
|
||||
self.parallel_action: bool = False
|
||||
|
||||
self.associated_types: list[str] = []
|
||||
|
||||
self.default: bool = False
|
||||
self.enable_plugin: bool = True # 是否启用插件,默认启用
|
||||
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("action_taken")
|
||||
|
||||
@@ -29,7 +28,25 @@ class EmojiAction(BaseAction):
|
||||
|
||||
associated_types: list[str] = ["emoji"]
|
||||
|
||||
default = True
|
||||
enable_plugin = True
|
||||
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
normal_activation_type = ActionActivationType.RANDOM
|
||||
|
||||
random_activation_probability = global_config.normal_chat.emoji_chance
|
||||
|
||||
parallel_action = True
|
||||
|
||||
|
||||
llm_judge_prompt = """
|
||||
判定是否需要使用表情动作的条件:
|
||||
1. 用户明确要求使用表情包
|
||||
2. 这是一个适合表达强烈情绪的场合
|
||||
3. 不要发送太多表情包,如果你已经发送过多个表情包
|
||||
"""
|
||||
|
||||
# 模式启用设置 - 表情动作在所有聊天模式下可用
|
||||
mode_enable = ChatMode.ALL
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -130,4 +147,4 @@ class EmojiAction(BaseAction):
|
||||
elif type == "emoji":
|
||||
reply_text += data
|
||||
|
||||
return success, reply_text
|
||||
return success, reply_text
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ChatMode
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -25,7 +25,11 @@ class ExitFocusChatAction(BaseAction):
|
||||
"当前内容不需要持续专注关注,你决定退出专注聊天",
|
||||
"聊天内容已经完成,你决定退出专注聊天",
|
||||
]
|
||||
default = False
|
||||
# 退出专注聊天是系统核心功能,不是插件,但默认不启用(需要特定条件触发)
|
||||
enable_plugin = False
|
||||
|
||||
# 模式启用设置 - 退出专注聊天动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.FOCUS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
@@ -28,7 +28,13 @@ class NoReplyAction(BaseAction):
|
||||
"你连续发送了太多消息,且无人回复",
|
||||
"想要休息一下",
|
||||
]
|
||||
default = True
|
||||
enable_plugin = True
|
||||
|
||||
# 激活类型设置
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
|
||||
# 模式启用设置 - no_reply动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.FOCUS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||||
|
||||
logger = get_logger("action_taken")
|
||||
|
||||
# 常量定义
|
||||
WAITING_TIME_THRESHOLD = 1200 # 等待新消息时间阈值,单位秒
|
||||
|
||||
|
||||
@register_action
|
||||
class NoReplyAction(BaseAction):
|
||||
"""不回复动作处理类
|
||||
|
||||
处理决定不回复的动作。
|
||||
"""
|
||||
|
||||
action_name = "no_reply"
|
||||
action_description = "不回复"
|
||||
action_parameters = {}
|
||||
action_require = [
|
||||
"话题无关/无聊/不感兴趣/不懂",
|
||||
"聊天记录中最新一条消息是你自己发的且无人回应你",
|
||||
"你连续发送了太多消息,且无人回复",
|
||||
]
|
||||
default = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
observations: List[Observation],
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化不回复动作处理器
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_data: 动作数据
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
observations: 观察列表
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
"""
|
||||
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
|
||||
self.observations = observations
|
||||
self.log_prefix = log_prefix
|
||||
self._shutting_down = shutting_down
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
处理不回复的情况
|
||||
|
||||
工作流程:
|
||||
1. 等待新消息、超时或关闭信号
|
||||
2. 根据等待结果更新连续不回复计数
|
||||
3. 如果达到阈值,触发回调
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 空字符串)
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 决定不回复: {self.reasoning}")
|
||||
|
||||
observation = self.observations[0] if self.observations else None
|
||||
|
||||
try:
|
||||
with Timer("等待新消息", self.cycle_timers):
|
||||
# 等待新消息、超时或关闭信号,并获取结果
|
||||
await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix)
|
||||
|
||||
return True, "" # 不回复动作没有回复文本
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 处理 'no_reply' 时等待被中断 (CancelledError)")
|
||||
raise
|
||||
except Exception as e: # 捕获调用管理器或其他地方可能发生的错误
|
||||
logger.error(f"{self.log_prefix} 处理 'no_reply' 时发生错误: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False, ""
|
||||
|
||||
async def _wait_for_new_message(self, observation: ChattingObservation, thinking_id: str, log_prefix: str) -> bool:
|
||||
"""
|
||||
等待新消息 或 检测到关闭信号
|
||||
|
||||
参数:
|
||||
observation: 观察实例
|
||||
thinking_id: 思考ID
|
||||
log_prefix: 日志前缀
|
||||
|
||||
返回:
|
||||
bool: 是否检测到新消息 (如果因关闭信号退出则返回 False)
|
||||
"""
|
||||
wait_start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
# --- 在每次循环开始时检查关闭标志 ---
|
||||
if self._shutting_down:
|
||||
logger.info(f"{log_prefix} 等待新消息时检测到关闭信号,中断等待。")
|
||||
return False # 表示因为关闭而退出
|
||||
# -----------------------------------
|
||||
|
||||
thinking_id_timestamp = parse_thinking_id_to_timestamp(thinking_id)
|
||||
|
||||
# 检查新消息
|
||||
if await observation.has_new_messages_since(thinking_id_timestamp):
|
||||
logger.info(f"{log_prefix} 检测到新消息")
|
||||
return True
|
||||
|
||||
# 检查超时 (放在检查新消息和关闭之后)
|
||||
if asyncio.get_event_loop().time() - wait_start_time > WAITING_TIME_THRESHOLD:
|
||||
logger.warning(f"{log_prefix} 等待新消息超时({WAITING_TIME_THRESHOLD}秒)")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 短暂休眠,让其他任务有机会运行,并能更快响应取消或关闭
|
||||
await asyncio.sleep(0.5) # 缩短休眠时间
|
||||
except asyncio.CancelledError:
|
||||
# 如果在休眠时被取消,再次检查关闭标志
|
||||
# 如果是正常关闭,则不需要警告
|
||||
if not self._shutting_down:
|
||||
logger.warning(f"{log_prefix} _wait_for_new_message 的休眠被意外取消")
|
||||
# 无论如何,重新抛出异常,让上层处理
|
||||
raise
|
||||
@@ -1,6 +1,6 @@
|
||||
import traceback
|
||||
from typing import Tuple, Dict, List, Any, Optional
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action # noqa F401
|
||||
from typing import Tuple, Dict, List, Any, Optional, Union, Type
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode # noqa F401
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
from src.common.logger_manager import get_logger
|
||||
@@ -12,6 +12,9 @@ import os
|
||||
import inspect
|
||||
import toml # 导入 toml 库
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database import db
|
||||
from peewee import Model, DoesNotExist
|
||||
import json
|
||||
import time
|
||||
|
||||
# 以下为类型注解需要
|
||||
@@ -30,6 +33,17 @@ class PluginAction(BaseAction):
|
||||
"""
|
||||
|
||||
action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名
|
||||
|
||||
# 默认激活类型设置,插件可以覆盖
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
normal_activation_type = ActionActivationType.ALWAYS
|
||||
random_activation_probability: float = 0.3
|
||||
llm_judge_prompt: str = ""
|
||||
activation_keywords: list[str] = []
|
||||
keyword_case_sensitive: bool = False
|
||||
|
||||
# 默认模式启用设置 - 插件动作默认在所有模式下可用,插件可以覆盖
|
||||
mode_enable = ChatMode.ALL
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -348,7 +362,6 @@ class PluginAction(BaseAction):
|
||||
self,
|
||||
prompt: str,
|
||||
model_config: Dict[str, Any],
|
||||
max_tokens: int = 2000,
|
||||
request_type: str = "plugin.generate",
|
||||
**kwargs
|
||||
) -> Tuple[bool, str]:
|
||||
@@ -372,7 +385,6 @@ class PluginAction(BaseAction):
|
||||
|
||||
llm_request = LLMRequest(
|
||||
model=model_config,
|
||||
max_tokens=max_tokens,
|
||||
request_type=request_type,
|
||||
**kwargs
|
||||
)
|
||||
@@ -436,3 +448,332 @@ class PluginAction(BaseAction):
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def db_query(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
query_type: str = "get",
|
||||
filters: Dict[str, Any] = None,
|
||||
data: Dict[str, Any] = None,
|
||||
limit: int = None,
|
||||
order_by: List[str] = None,
|
||||
single_result: bool = False
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
|
||||
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
|
||||
filters: 过滤条件字典,键为字段名,值为要匹配的值
|
||||
data: 用于创建或更新的数据字典
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
根据查询类型返回不同的结果:
|
||||
- "get": 返回查询结果列表或单个结果(如果 single_result=True)
|
||||
- "create": 返回创建的记录
|
||||
- "update": 返回受影响的行数
|
||||
- "delete": 返回受影响的行数
|
||||
- "count": 返回记录数量
|
||||
|
||||
示例:
|
||||
# 查询最近10条消息
|
||||
messages = await self.db_query(
|
||||
Messages,
|
||||
query_type="get",
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by=["-time"]
|
||||
)
|
||||
|
||||
# 创建一条记录
|
||||
new_record = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="create",
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
|
||||
)
|
||||
|
||||
# 更新记录
|
||||
updated_count = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="update",
|
||||
filters={"action_id": "123"},
|
||||
data={"action_done": True}
|
||||
)
|
||||
|
||||
# 删除记录
|
||||
deleted_count = await self.db_query(
|
||||
ActionRecords,
|
||||
query_type="delete",
|
||||
filters={"action_id": "123"}
|
||||
)
|
||||
|
||||
# 计数
|
||||
count = await self.db_query(
|
||||
Messages,
|
||||
query_type="count",
|
||||
filters={"chat_id": chat_stream.stream_id}
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 构建基本查询
|
||||
if query_type in ["get", "update", "delete", "count"]:
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 执行查询
|
||||
if query_type == "get":
|
||||
# 应用排序
|
||||
if order_by:
|
||||
for field in order_by:
|
||||
if field.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, field[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, field))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
# 创建记录
|
||||
record = model_class.create(**data)
|
||||
# 返回创建的记录
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get()
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
|
||||
# 更新记录
|
||||
return query.update(**data).execute()
|
||||
|
||||
elif query_type == "delete":
|
||||
# 删除记录
|
||||
return query.delete().execute()
|
||||
|
||||
elif query_type == "count":
|
||||
# 计数
|
||||
return query.count()
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的查询类型: {query_type}")
|
||||
|
||||
except DoesNotExist:
|
||||
# 记录不存在
|
||||
if query_type == "get" and single_result:
|
||||
return None
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
|
||||
async def db_raw_query(
|
||||
self,
|
||||
sql: str,
|
||||
params: List[Any] = None,
|
||||
fetch_results: bool = True
|
||||
) -> Union[List[Dict[str, Any]], int, None]:
|
||||
"""执行原始SQL查询
|
||||
|
||||
警告: 使用此方法需要小心,确保SQL语句已正确构造以避免SQL注入风险。
|
||||
|
||||
Args:
|
||||
sql: 原始SQL查询字符串
|
||||
params: 查询参数列表,用于替换SQL中的占位符
|
||||
fetch_results: 是否获取查询结果,对于SELECT查询设为True,对于
|
||||
UPDATE/INSERT/DELETE等操作设为False
|
||||
|
||||
Returns:
|
||||
如果fetch_results为True,返回查询结果列表;
|
||||
如果fetch_results为False,返回受影响的行数;
|
||||
如果出错,返回None
|
||||
"""
|
||||
try:
|
||||
cursor = db.execute_sql(sql, params or [])
|
||||
|
||||
if fetch_results:
|
||||
# 获取列名
|
||||
columns = [col[0] for col in cursor.description]
|
||||
|
||||
# 构建结果字典列表
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
results.append(dict(zip(columns, row)))
|
||||
|
||||
return results
|
||||
else:
|
||||
# 返回受影响的行数
|
||||
return cursor.rowcount
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
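# Illustrative usage (not part of this commit): a parameterized read via db_raw_query. The table
# and column names here are hypothetical, and the "?" placeholder style depends on the underlying
# database driver, so adapt both to the real schema before use.
recent_done = await self.db_raw_query(
    "SELECT action_name, time FROM action_records WHERE action_done = ? ORDER BY time DESC LIMIT ?",
    params=[1, 5],
    fetch_results=True,
)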
|
||||
|
||||
async def db_save(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
data: Dict[str, Any],
|
||||
key_field: str = None,
|
||||
key_value: Any = None
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类,如ActionRecords, Messages等
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名,例如"action_id"
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存后的记录数据
|
||||
None: 如果操作失败
|
||||
|
||||
示例:
|
||||
# 创建或更新一条记录
|
||||
record = await self.db_save(
|
||||
ActionRecords,
|
||||
{
|
||||
"action_id": "123",
|
||||
"time": time.time(),
|
||||
"action_name": "TestAction",
|
||||
"action_done": True
|
||||
},
|
||||
key_field="action_id",
|
||||
key_value="123"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
# 查找现有记录
|
||||
existing_records = list(model_class.select().where(
|
||||
getattr(model_class, key_field) == key_value
|
||||
).limit(1))
|
||||
|
||||
if existing_records:
|
||||
# 更新现有记录
|
||||
existing_record = existing_records[0]
|
||||
for field, value in data.items():
|
||||
setattr(existing_record, field, value)
|
||||
existing_record.save()
|
||||
|
||||
# 返回更新后的记录
|
||||
updated_record = model_class.select().where(
|
||||
model_class.id == existing_record.id
|
||||
).dicts().get()
|
||||
return updated_record
|
||||
|
||||
# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
|
||||
new_record = model_class.create(**data)
|
||||
|
||||
# 返回创建的记录
|
||||
created_record = model_class.select().where(
|
||||
model_class.id == new_record.id
|
||||
).dicts().get()
|
||||
return created_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def db_get(
|
||||
self,
|
||||
model_class: Type[Model],
|
||||
filters: Dict[str, Any] = None,
|
||||
order_by: str = None,
|
||||
limit: int = None
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
这是db_query方法的简化版本,专注于数据检索操作。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
|
||||
limit: 结果数量限制,如果为1则返回单个记录而不是列表
|
||||
|
||||
Returns:
|
||||
如果limit=1,返回单个记录字典或None;
|
||||
否则返回记录字典列表或空列表。
|
||||
|
||||
示例:
|
||||
# 获取单个记录
|
||||
record = await self.db_get(
|
||||
ActionRecords,
|
||||
filters={"action_id": "123"},
|
||||
limit=1
|
||||
)
|
||||
|
||||
# 获取最近10条记录
|
||||
records = await self.db_get(
|
||||
Messages,
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
order_by="-time",
|
||||
limit=10
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
if order_by.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, order_by[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, order_by))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if limit == 1:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None if limit == 1 else []
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
|
||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode
|
||||
from typing import Tuple, List
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
|
||||
@@ -11,6 +11,7 @@ from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||
import time
|
||||
import traceback
|
||||
from src.common.database.database_model import ActionRecords
|
||||
import re
|
||||
|
||||
logger = get_logger("action_taken")
|
||||
|
||||
@@ -25,16 +26,23 @@ class ReplyAction(BaseAction):
|
||||
action_name: str = "reply"
|
||||
action_description: str = "当你想要参与回复或者聊天"
|
||||
action_parameters: dict[str:str] = {
|
||||
"target": "如果你要明确回复特定某人的某句话,请在target参数中中指定那句话的原始文本(非必须,仅文本,不包含发送者)(可选)",
|
||||
"reply_to": "如果是明确回复某个人的发言,请在reply_to参数中指定,格式:(用户名:发言内容),如果不是,reply_to的值设为none"
|
||||
}
|
||||
action_require: list[str] = [
|
||||
"你想要闲聊或者随便附和",
|
||||
"有人提到你",
|
||||
"如果你刚刚进行了回复,不要对同一个话题重复回应"
|
||||
]
|
||||
|
||||
associated_types: list[str] = ["text", "emoji"]
|
||||
associated_types: list[str] = ["text"]
|
||||
|
||||
default = True
|
||||
enable_plugin = True
|
||||
|
||||
# 激活类型设置
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
|
||||
# 模式启用设置 - 回复动作只在Focus模式下使用
|
||||
mode_enable = ChatMode.FOCUS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -99,7 +107,6 @@ class ReplyAction(BaseAction):
|
||||
{
|
||||
"text": "你好啊" # 文本内容列表(可选)
|
||||
"target": "锚定消息", # 锚定消息的文本内容
|
||||
"emojis": "微笑" # 表情关键词列表(可选)
|
||||
}
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 决定回复: {self.reasoning}")
|
||||
@@ -108,19 +115,29 @@ class ReplyAction(BaseAction):
|
||||
chatting_observation: ChattingObservation = next(
|
||||
obs for obs in self.observations if isinstance(obs, ChattingObservation)
|
||||
)
|
||||
if reply_data.get("target"):
|
||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
|
||||
# sender = ""
|
||||
target = ""
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
# sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
anchor_message = chatting_observation.search_message_by_text(target)
|
||||
else:
|
||||
anchor_message = None
|
||||
|
||||
# 如果没有找到锚点消息,创建一个占位符
|
||||
if not anchor_message:
|
||||
|
||||
if anchor_message:
|
||||
anchor_message.update_chat_stream(self.chat_stream)
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
|
||||
anchor_message = await create_empty_anchor_message(
|
||||
self.chat_stream.platform, self.chat_stream.group_info, self.chat_stream
|
||||
)
|
||||
else:
|
||||
anchor_message.update_chat_stream(self.chat_stream)
|
||||
|
||||
|
||||
success, reply_set = await self.replyer.deal_reply(
|
||||
cycle_timers=cycle_timers,
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from typing import List, Optional, Any
|
||||
from typing import List, Optional, Any, Dict
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.message_receive.chat_stream import chat_manager
|
||||
from typing import Dict
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.focus_chat.planners.actions.base_action import ActionActivationType, ChatMode
|
||||
import random
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
@@ -15,25 +19,47 @@ logger = get_logger("action_manager")
|
||||
class ActionModifier:
|
||||
"""动作处理器
|
||||
|
||||
用于处理Observation对象,将其转换为ObsInfo对象。
|
||||
用于处理Observation对象和根据激活类型处理actions。
|
||||
集成了原有的modify_actions功能和新的激活类型处理功能。
|
||||
支持并行判定和智能缓存优化。
|
||||
"""
|
||||
|
||||
log_prefix = "动作处理"
|
||||
|
||||
def __init__(self, action_manager: ActionManager):
|
||||
"""初始化观察处理器"""
|
||||
"""初始化动作处理器"""
|
||||
self.action_manager = action_manager
|
||||
self.all_actions = self.action_manager.get_registered_actions()
|
||||
self.all_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
|
||||
# 用于LLM判定的小模型
|
||||
self.llm_judge = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
request_type="action.judge",
|
||||
)
|
||||
|
||||
# 缓存相关属性
|
||||
self._llm_judge_cache = {} # 缓存LLM判定结果
|
||||
self._cache_expiry_time = 30 # 缓存过期时间(秒)
|
||||
self._last_context_hash = None # 上次上下文的哈希值
|
||||
|
||||
async def modify_actions(
|
||||
self,
|
||||
observations: Optional[List[Observation]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
# 处理Observation对象
|
||||
"""
|
||||
完整的动作修改流程,整合传统观察处理和新的激活类型判定
|
||||
|
||||
这个方法处理完整的动作管理流程:
|
||||
1. 基于观察的传统动作修改(循环历史分析、类型匹配等)
|
||||
2. 基于激活类型的智能动作判定,最终确定可用动作集
|
||||
|
||||
处理后,ActionManager 将包含最终的可用动作集,供规划器直接使用
|
||||
"""
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
# === 第一阶段:传统观察处理 ===
|
||||
if observations:
|
||||
# action_info = ActionInfo()
|
||||
# all_actions = None
|
||||
hfc_obs = None
|
||||
chat_obs = None
|
||||
|
||||
@@ -43,28 +69,31 @@ class ActionModifier:
|
||||
hfc_obs = obs
|
||||
if isinstance(obs, ChattingObservation):
|
||||
chat_obs = obs
|
||||
chat_content = obs.talking_message_str_truncate
|
||||
|
||||
# 合并所有动作变更
|
||||
merged_action_changes = {"add": [], "remove": []}
|
||||
reasons = []
|
||||
|
||||
# 处理HFCloopObservation
|
||||
# 处理HFCloopObservation - 传统的循环历史分析
|
||||
if hfc_obs:
|
||||
obs = hfc_obs
|
||||
all_actions = self.all_actions
|
||||
# 获取适用于FOCUS模式的动作
|
||||
all_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
action_changes = await self.analyze_loop_actions(obs)
|
||||
if action_changes["add"] or action_changes["remove"]:
|
||||
# 合并动作变更
|
||||
merged_action_changes["add"].extend(action_changes["add"])
|
||||
merged_action_changes["remove"].extend(action_changes["remove"])
|
||||
reasons.append("基于循环历史分析")
|
||||
|
||||
# 详细记录循环历史分析的变更原因
|
||||
for action_name in action_changes["add"]:
|
||||
logger.info(f"{self.log_prefix}添加动作: {action_name},原因: 循环历史分析建议添加")
|
||||
for action_name in action_changes["remove"]:
|
||||
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: 循环历史分析建议移除")
|
||||
|
||||
# 收集变更原因
|
||||
# if action_changes["add"]:
|
||||
# reasons.append(f"添加动作{action_changes['add']}因为检测到大量无回复")
|
||||
# if action_changes["remove"]:
|
||||
# reasons.append(f"移除动作{action_changes['remove']}因为检测到连续回复")
|
||||
|
||||
# 处理ChattingObservation
|
||||
# 处理ChattingObservation - 传统的类型匹配检查
|
||||
if chat_obs:
|
||||
obs = chat_obs
|
||||
# 检查动作的关联类型
|
||||
@@ -76,30 +105,432 @@ class ActionModifier:
|
||||
if data.get("associated_types"):
|
||||
if not chat_context.check_types(data["associated_types"]):
|
||||
type_mismatched_actions.append(action_name)
|
||||
logger.debug(f"{self.log_prefix} 动作 {action_name} 关联类型不匹配,移除该动作")
|
||||
associated_types_str = ", ".join(data["associated_types"])
|
||||
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str})")
|
||||
|
||||
if type_mismatched_actions:
|
||||
# 合并到移除列表中
|
||||
merged_action_changes["remove"].extend(type_mismatched_actions)
|
||||
reasons.append(f"移除动作{type_mismatched_actions}因为关联类型不匹配")
|
||||
reasons.append("基于关联类型检查")
|
||||
|
||||
# 应用传统的动作变更到ActionManager
|
||||
for action_name in merged_action_changes["add"]:
|
||||
if action_name in self.action_manager.get_registered_actions():
|
||||
self.action_manager.add_action_to_using(action_name)
|
||||
logger.debug(f"{self.log_prefix} 添加动作: {action_name}, 原因: {reasons}")
|
||||
logger.debug(f"{self.log_prefix}应用添加动作: {action_name},原因集合: {reasons}")
|
||||
|
||||
for action_name in merged_action_changes["remove"]:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix} 移除动作: {action_name}, 原因: {reasons}")
|
||||
logger.debug(f"{self.log_prefix}应用移除动作: {action_name},原因集合: {reasons}")
|
||||
|
||||
# 如果有任何动作变更,设置到action_info中
|
||||
# if merged_action_changes["add"] or merged_action_changes["remove"]:
|
||||
# action_info.set_action_changes(merged_action_changes)
|
||||
# action_info.set_reason(" | ".join(reasons))
|
||||
logger.info(f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}")
|
||||
|
||||
# processed_infos.append(action_info)
|
||||
# === 第二阶段:激活类型判定 ===
|
||||
# 如果提供了聊天上下文,则进行激活类型判定
|
||||
if chat_content is not None:
|
||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理,且适用于FOCUS模式)
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
all_registered_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
|
||||
# 构建完整的动作信息
|
||||
current_actions_with_info = {}
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name in all_registered_actions:
|
||||
current_actions_with_info[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
# 应用激活类型判定
|
||||
final_activated_actions = await self._apply_activation_type_filtering(
|
||||
current_actions_with_info,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
# 更新ActionManager,移除未激活的动作
|
||||
actions_to_remove = []
|
||||
removal_reasons = {}
|
||||
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name not in final_activated_actions:
|
||||
actions_to_remove.append(action_name)
|
||||
# 确定移除原因
|
||||
if action_name in all_registered_actions:
|
||||
action_info = all_registered_actions[action_name]
|
||||
activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
if activation_type == ActionActivationType.RANDOM:
|
||||
probability = action_info.get("random_probability", 0.3)
|
||||
removal_reasons[action_name] = f"RANDOM类型未触发(概率{probability})"
|
||||
elif activation_type == ActionActivationType.LLM_JUDGE:
|
||||
removal_reasons[action_name] = "LLM判定未激活"
|
||||
elif activation_type == ActionActivationType.KEYWORD:
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
removal_reasons[action_name] = f"关键词未匹配(关键词: {keywords})"
|
||||
else:
|
||||
removal_reasons[action_name] = "激活判定未通过"
|
||||
else:
|
||||
removal_reasons[action_name] = "动作信息不完整"
|
||||
|
||||
for action_name in actions_to_remove:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
reason = removal_reasons.get(action_name, "未知原因")
|
||||
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: {reason}")
|
||||
|
||||
logger.info(f"{self.log_prefix}激活类型判定完成,最终可用动作: {list(final_activated_actions.keys())}")
|
||||
|
||||
logger.info(f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}")
|
||||
|
||||
# return processed_infos
|
||||
async def _apply_activation_type_filtering(
|
||||
self,
|
||||
actions_with_info: Dict[str, Any],
|
||||
chat_content: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
应用激活类型过滤逻辑,支持四种激活类型的并行处理
|
||||
|
||||
Args:
|
||||
actions_with_info: 带完整信息的动作字典
|
||||
observed_messages_str: 观察到的聊天消息
|
||||
chat_context: 聊天上下文信息
|
||||
extra_context: 额外的上下文信息
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 过滤后激活的actions字典
|
||||
"""
|
||||
activated_actions = {}
|
||||
|
||||
# 分类处理不同激活类型的actions
|
||||
always_actions = {}
|
||||
random_actions = {}
|
||||
llm_judge_actions = {}
|
||||
keyword_actions = {}
|
||||
|
||||
for action_name, action_info in actions_with_info.items():
|
||||
activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
always_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.RANDOM:
|
||||
random_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.LLM_JUDGE:
|
||||
llm_judge_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.KEYWORD:
|
||||
keyword_actions[action_name] = action_info
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
|
||||
|
||||
# 1. 处理ALWAYS类型(直接激活)
|
||||
for action_name, action_info in always_actions.items():
|
||||
activated_actions[action_name] = action_info
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活")
|
||||
|
||||
# 2. 处理RANDOM类型
|
||||
for action_name, action_info in random_actions.items():
|
||||
probability = action_info.get("random_probability", 0.3)
|
||||
should_activate = random.random() < probability
|
||||
if should_activate:
|
||||
activated_actions[action_name] = action_info
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发(概率{probability})")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发(概率{probability})")
|
||||
|
||||
# 3. 处理KEYWORD类型(快速判定)
|
||||
for action_name, action_info in keyword_actions.items():
|
||||
should_activate = self._check_keyword_activation(
|
||||
action_name,
|
||||
action_info,
|
||||
chat_content,
|
||||
)
|
||||
if should_activate:
|
||||
activated_actions[action_name] = action_info
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: KEYWORD类型匹配关键词({keywords})")
|
||||
else:
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词({keywords})")
|
||||
|
||||
# 4. 处理LLM_JUDGE类型(并行判定)
|
||||
if llm_judge_actions:
|
||||
# 直接并行处理所有LLM判定actions
|
||||
llm_results = await self._process_llm_judge_actions_parallel(
|
||||
llm_judge_actions,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
# 添加激活的LLM判定actions
|
||||
for action_name, should_activate in llm_results.items():
|
||||
if should_activate:
|
||||
activated_actions[action_name] = llm_judge_actions[action_name]
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: LLM_JUDGE类型判定通过")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: LLM_JUDGE类型判定未通过")
|
||||
|
||||
logger.debug(f"{self.log_prefix}激活类型过滤完成: {list(activated_actions.keys())}")
|
||||
return activated_actions
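# 补充示意(独立示例,非本仓库代码):一份假设的动作配置,字段名取自上面过滤逻辑
# 实际读取的键(focus_activation_type / random_probability / activation_keywords /
# keyword_case_sensitive / llm_judge_prompt)。真实代码中激活类型是 ActionActivationType
# 枚举成员,这里用小写字符串占位,仅说明四种类型各自需要哪些字段;动作名均为虚构。
HYPOTHETICAL_ACTIONS = {
    "reply": {
        "focus_activation_type": "always",      # ALWAYS:无条件进入候选
    },
    "emoji": {
        "focus_activation_type": "random",      # RANDOM:按 random_probability 概率触发
        "random_probability": 0.3,
    },
    "poke_user": {
        "focus_activation_type": "keyword",     # KEYWORD:聊天内容命中关键词才激活
        "activation_keywords": ["戳", "poke"],
        "keyword_case_sensitive": False,
    },
    "deep_search": {
        "focus_activation_type": "llm_judge",   # LLM_JUDGE:由小模型按场景判定
        "llm_judge_prompt": "只有当用户明确提出需要查资料时才激活",
    },
}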
|
||||
|
||||
async def process_actions_for_planner(
|
||||
self,
|
||||
observed_messages_str: str = "",
|
||||
chat_context: Optional[str] = None,
|
||||
extra_context: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
[已废弃] 此方法现在已被整合到 modify_actions() 中
|
||||
|
||||
为了保持向后兼容性而保留,但建议直接使用 ActionManager.get_using_actions()
|
||||
规划器应该直接从 ActionManager 获取最终的可用动作集,而不是调用此方法
|
||||
|
||||
新的架构:
|
||||
1. 主循环调用 modify_actions() 处理完整的动作管理流程
|
||||
2. 规划器直接使用 ActionManager.get_using_actions() 获取最终动作集
|
||||
"""
|
||||
logger.warning(f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()")
|
||||
|
||||
# 为了向后兼容,仍然返回当前使用的动作集
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
|
||||
# 构建完整的动作信息
|
||||
result = {}
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name in all_registered_actions:
|
||||
result[action_name] = all_registered_actions[action_name]
|
||||
|
||||
return result
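# 补充示意(独立示例,非本仓库代码):按上文废弃说明所描述的新架构,调用方的典型流程
# 大致如下。此片段假设在本项目环境内运行,且 ActionManager 可无参构造
# (构造签名未出现在本段 diff 中,属于假设)。
from src.chat.focus_chat.planners.action_manager import ActionManager
from src.chat.focus_chat.planners.modify_actions import ActionModifier

async def get_final_actions(observations):
    action_manager = ActionManager()           # 假设:无参构造
    modifier = ActionModifier(action_manager)
    # 1. 主循环侧:传统观察处理 + 激活类型判定,一次完成
    await modifier.modify_actions(observations=observations)
    # 2. 规划器侧:直接取最终动作集,不再调用已废弃的 process_actions_for_planner()
    return action_manager.get_using_actions()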
|
||||
|
||||
def _generate_context_hash(self, chat_content: str) -> str:
|
||||
"""生成上下文的哈希值用于缓存"""
|
||||
context_content = f"{chat_content}"
|
||||
return hashlib.md5(context_content.encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
|
||||
async def _process_llm_judge_actions_parallel(
|
||||
self,
|
||||
llm_judge_actions: Dict[str, Any],
|
||||
chat_content: str = "",
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
并行处理LLM判定actions,支持智能缓存
|
||||
|
||||
Args:
|
||||
llm_judge_actions: 需要LLM判定的actions
|
||||
chat_content: 观察到的聊天内容
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: action名称到激活结果的映射
|
||||
"""
|
||||
|
||||
# 生成当前上下文的哈希值
|
||||
current_context_hash = self._generate_context_hash(chat_content)
|
||||
current_time = time.time()
|
||||
|
||||
results = {}
|
||||
tasks_to_run = {}
|
||||
|
||||
# 检查缓存
|
||||
for action_name, action_info in llm_judge_actions.items():
|
||||
cache_key = f"{action_name}_{current_context_hash}"
|
||||
|
||||
# 检查是否有有效的缓存
|
||||
if (cache_key in self._llm_judge_cache and
|
||||
current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time):
|
||||
|
||||
results[action_name] = self._llm_judge_cache[cache_key]["result"]
|
||||
logger.debug(f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}")
|
||||
else:
|
||||
# 需要进行LLM判定
|
||||
tasks_to_run[action_name] = action_info
|
||||
|
||||
# 如果有需要运行的任务,并行执行
|
||||
if tasks_to_run:
|
||||
logger.debug(f"{self.log_prefix}并行执行LLM判定,任务数: {len(tasks_to_run)}")
|
||||
|
||||
# 创建并行任务
|
||||
tasks = []
|
||||
task_names = []
|
||||
|
||||
for action_name, action_info in tasks_to_run.items():
|
||||
task = self._llm_judge_action(
|
||||
action_name,
|
||||
action_info,
|
||||
chat_content,
|
||||
)
|
||||
tasks.append(task)
|
||||
task_names.append(action_name)
|
||||
|
||||
# 并行执行所有任务
|
||||
try:
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理结果并更新缓存
|
||||
for i, (action_name, result) in enumerate(zip(task_names, task_results)):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
||||
results[action_name] = False
|
||||
else:
|
||||
results[action_name] = result
|
||||
|
||||
# 更新缓存
|
||||
cache_key = f"{action_name}_{current_context_hash}"
|
||||
self._llm_judge_cache[cache_key] = {
|
||||
"result": result,
|
||||
"timestamp": current_time
|
||||
}
|
||||
|
||||
logger.debug(f"{self.log_prefix}并行LLM判定完成,耗时: {time.time() - current_time:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
|
||||
# 如果并行执行失败,为所有任务返回False
|
||||
for action_name in tasks_to_run.keys():
|
||||
results[action_name] = False
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_expired_cache(current_time)
|
||||
|
||||
return results
|
||||
|
||||
def _cleanup_expired_cache(self, current_time: float):
|
||||
"""清理过期的缓存条目"""
|
||||
expired_keys = []
|
||||
for cache_key, cache_data in self._llm_judge_cache.items():
|
||||
if current_time - cache_data["timestamp"] > self._cache_expiry_time:
|
||||
expired_keys.append(cache_key)
|
||||
|
||||
for key in expired_keys:
|
||||
del self._llm_judge_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了 {len(expired_keys)} 个过期缓存条目")
|
||||
|
||||
async def _llm_judge_action(
|
||||
self,
|
||||
action_name: str,
|
||||
action_info: Dict[str, Any],
|
||||
chat_content: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
使用LLM判定是否应该激活某个action
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
chat_content: 观察到的聊天内容
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
try:
|
||||
# 构建判定提示词
|
||||
action_description = action_info.get("description", "")
|
||||
action_require = action_info.get("require", [])
|
||||
custom_prompt = action_info.get("llm_judge_prompt", "")
|
||||
|
||||
|
||||
# 构建基础判定提示词
|
||||
base_prompt = f"""
|
||||
你需要判断在当前聊天情况下,是否应该激活名为"{action_name}"的动作。
|
||||
|
||||
动作描述:{action_description}
|
||||
|
||||
动作使用场景:
|
||||
"""
|
||||
for req in action_require:
|
||||
base_prompt += f"- {req}\n"
|
||||
|
||||
if custom_prompt:
|
||||
base_prompt += f"\n额外判定条件:\n{custom_prompt}\n"
|
||||
|
||||
if chat_content:
|
||||
base_prompt += f"\n当前聊天记录:\n{chat_content}\n"
|
||||
|
||||
|
||||
base_prompt += """
|
||||
请根据以上信息判断是否应该激活这个动作。
|
||||
只需要回答"是"或"否",不要有其他内容。
|
||||
"""
|
||||
|
||||
# 调用LLM进行判定
|
||||
response, _ = await self.llm_judge.generate_response_async(prompt=base_prompt)
|
||||
|
||||
# 解析响应
|
||||
response = response.strip().lower()
|
||||
|
||||
# print(base_prompt)
|
||||
print(f"LLM判定动作 {action_name}:响应='{response}'")
|
||||
|
||||
|
||||
should_activate = "是" in response or "yes" in response or "true" in response
|
||||
|
||||
logger.debug(f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}")
|
||||
return should_activate
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}LLM判定动作 {action_name} 时出错: {e}")
|
||||
# 出错时默认不激活
|
||||
return False
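# 补充示意(独立示例,非本仓库代码):上面的解析只做"包含"判断("是"/"yes"/"true"),
# 注意像"不是"这样的否定回答也包含"是",会被判为激活;若需要更严格可改为全等比较。
def parse_judge_response(response: str) -> bool:
    response = response.strip().lower()
    return "是" in response or "yes" in response or "true" in response

assert parse_judge_response("是") is True
assert parse_judge_response("否") is False
assert parse_judge_response("不是") is True   # 包含"是",属于潜在误判场景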
|
||||
|
||||
def _check_keyword_activation(
|
||||
self,
|
||||
action_name: str,
|
||||
action_info: Dict[str, Any],
|
||||
chat_content: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否匹配关键词触发条件
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
chat_content: 观察到的聊天内容
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
activation_keywords = action_info.get("activation_keywords", [])
|
||||
case_sensitive = action_info.get("keyword_case_sensitive", False)
|
||||
|
||||
if not activation_keywords:
|
||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||
return False
|
||||
|
||||
# 构建检索文本
|
||||
search_text = ""
|
||||
if chat_content:
|
||||
search_text += chat_content
|
||||
# if chat_context:
|
||||
# search_text += f" {chat_context}"
|
||||
# if extra_context:
|
||||
# search_text += f" {extra_context}"
|
||||
|
||||
# 如果不区分大小写,转换为小写
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
# 检查每个关键词
|
||||
matched_keywords = []
|
||||
for keyword in activation_keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
matched_keywords.append(keyword)
|
||||
|
||||
if matched_keywords:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
|
||||
return True
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
||||
return False
|
||||
|
||||
async def analyze_loop_actions(self, obs: HFCloopObservation) -> Dict[str, List[str]]:
|
||||
"""分析最近的循环内容并决定动作的增减
|
||||
@@ -129,8 +560,6 @@ class ActionModifier:
|
||||
reply_sequence.append(action_type == "reply")
|
||||
|
||||
# 检查no_reply比例
|
||||
# print(f"no_reply_count: {no_reply_count}, len(recent_cycles): {len(recent_cycles)}")
|
||||
# print(1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111)
|
||||
if len(recent_cycles) >= (5 * global_config.chat.exit_focus_threshold) and (
|
||||
no_reply_count / len(recent_cycles)
|
||||
) >= (0.8 * global_config.chat.exit_focus_threshold):
|
||||
@@ -138,6 +567,8 @@ class ActionModifier:
|
||||
result["add"].append("exit_focus_chat")
|
||||
result["remove"].append("no_reply")
|
||||
result["remove"].append("reply")
|
||||
no_reply_ratio = no_reply_count / len(recent_cycles)
|
||||
logger.info(f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f},达到退出聊天阈值,将添加exit_focus_chat并移除no_reply/reply动作")
|
||||
|
||||
# 计算连续回复的相关阈值
|
||||
|
||||
@@ -162,34 +593,37 @@ class ActionModifier:
|
||||
if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num):
|
||||
# 如果最近max_reply_num次都是reply,直接移除
|
||||
result["remove"].append("reply")
|
||||
reply_count = len(last_max_reply_num) - no_reply_count
|
||||
logger.info(
|
||||
f"最近{len(last_max_reply_num)}次回复中,有{no_reply_count}次no_reply,{len(last_max_reply_num) - no_reply_count}次reply,直接移除"
|
||||
f"{self.log_prefix}移除reply动作,原因: 连续回复过多(最近{len(last_max_reply_num)}次全是reply,超过阈值{max_reply_num})"
|
||||
)
|
||||
elif len(last_max_reply_num) >= sec_thres_reply_num and all(last_max_reply_num[-sec_thres_reply_num:]):
|
||||
# 如果最近sec_thres_reply_num次都是reply,40%概率移除
|
||||
if random.random() < 0.4 / global_config.focus_chat.consecutive_replies:
|
||||
removal_probability = 0.4 / global_config.focus_chat.consecutive_replies
|
||||
if random.random() < removal_probability:
|
||||
result["remove"].append("reply")
|
||||
logger.info(
|
||||
f"最近{len(last_max_reply_num)}次回复中,有{no_reply_count}次no_reply,{len(last_max_reply_num) - no_reply_count}次reply,{0.4 / global_config.focus_chat.consecutive_replies}概率移除,移除"
|
||||
f"{self.log_prefix}移除reply动作,原因: 连续回复较多(最近{sec_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"最近{len(last_max_reply_num)}次回复中,有{no_reply_count}次no_reply,{len(last_max_reply_num) - no_reply_count}次reply,{0.4 / global_config.focus_chat.consecutive_replies}概率移除,不移除"
|
||||
f"{self.log_prefix}连续回复检测:最近{sec_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,未触发"
|
||||
)
|
||||
elif len(last_max_reply_num) >= one_thres_reply_num and all(last_max_reply_num[-one_thres_reply_num:]):
|
||||
# 如果最近one_thres_reply_num次都是reply,20%概率移除
|
||||
if random.random() < 0.2 / global_config.focus_chat.consecutive_replies:
|
||||
removal_probability = 0.2 / global_config.focus_chat.consecutive_replies
|
||||
if random.random() < removal_probability:
|
||||
result["remove"].append("reply")
|
||||
logger.info(
|
||||
f"最近{len(last_max_reply_num)}次回复中,有{no_reply_count}次no_reply,{len(last_max_reply_num) - no_reply_count}次reply,{0.2 / global_config.focus_chat.consecutive_replies}概率移除,移除"
|
||||
f"{self.log_prefix}移除reply动作,原因: 连续回复检测(最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"最近{len(last_max_reply_num)}次回复中,有{no_reply_count}次no_reply,{len(last_max_reply_num) - no_reply_count}次reply,{0.2 / global_config.focus_chat.consecutive_replies}概率移除,不移除"
|
||||
f"{self.log_prefix}连续回复检测:最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,未触发"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"最近{len(last_max_reply_num)}次回复中,有{no_reply_count}次no_reply,{len(last_max_reply_num) - no_reply_count}次reply,无需移除"
|
||||
f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常"
|
||||
)
|
||||
|
||||
return result
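# 补充示意(独立示例,非本仓库代码):用假设的配置值代入上面的两个阈值公式,
# 便于理解各参数的量级;真实取值来自 global_config,与此示例无关。
exit_focus_threshold = 1.0    # 假设值:global_config.chat.exit_focus_threshold
consecutive_replies = 2.0     # 假设值:global_config.focus_chat.consecutive_replies

min_cycles = 5 * exit_focus_threshold                  # 至少 5 个最近循环才考虑退出
min_no_reply_ratio = 0.8 * exit_focus_threshold        # no_reply 占比达到 0.8 时触发 exit_focus_chat
sec_removal_probability = 0.4 / consecutive_replies    # 连续回复达到次级阈值时,0.2 概率移除 reply
one_removal_probability = 0.2 / consecutive_replies    # 连续回复达到一级阈值时,0.1 概率移除 reply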
|
||||
|
||||
@@ -15,6 +15,8 @@ from src.common.logger_manager import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.individuality.individuality import individuality
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
from src.chat.focus_chat.planners.modify_actions import ActionModifier
|
||||
from src.chat.focus_chat.planners.actions.base_action import ChatMode
|
||||
from json_repair import repair_json
|
||||
from src.chat.focus_chat.planners.base_planner import BasePlanner
|
||||
from datetime import datetime
|
||||
@@ -31,8 +33,6 @@ def init_prompt():
|
||||
{self_info_block}
|
||||
请记住你的性格,身份和特点。
|
||||
|
||||
{relation_info_block}
|
||||
|
||||
{extra_info_block}
|
||||
{memory_str}
|
||||
|
||||
@@ -42,6 +42,8 @@ def init_prompt():
|
||||
|
||||
{chat_content_block}
|
||||
|
||||
{relation_info_block}
|
||||
|
||||
{cycle_info_block}
|
||||
|
||||
{moderation_prompt}
|
||||
@@ -141,8 +143,19 @@ class ActionPlanner(BasePlanner):
|
||||
# elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo
|
||||
# extra_info.append(info.get_processed_info())
|
||||
|
||||
# 获取当前可用的动作
|
||||
current_available_actions = self.action_manager.get_using_actions()
|
||||
# 获取经过modify_actions处理后的最终可用动作集
|
||||
# 注意:动作的激活判定现在在主循环的modify_actions中完成
|
||||
# 使用Focus模式过滤动作
|
||||
current_available_actions_dict = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict.keys():
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
# 如果没有可用动作或只有no_reply动作,直接返回no_reply
|
||||
if not current_available_actions or (
|
||||
@@ -181,7 +194,7 @@ class ActionPlanner(BasePlanner):
|
||||
prompt = f"{prompt}"
|
||||
llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
|
||||
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
|
||||
|
||||
@@ -225,7 +238,10 @@ class ActionPlanner(BasePlanner):
|
||||
extra_info_block = ""
|
||||
|
||||
action_data["extra_info_block"] = extra_info_block
|
||||
|
||||
|
||||
if relation_info:
|
||||
action_data["relation_info_block"] = relation_info
|
||||
|
||||
# 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data
|
||||
|
||||
if extracted_action not in current_available_actions:
|
||||
|
||||
@@ -23,6 +23,9 @@ from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
|
||||
import random
|
||||
from datetime import datetime
|
||||
import re
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import numpy as np
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
@@ -32,19 +35,19 @@ def init_prompt():
|
||||
"""
|
||||
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
|
||||
{style_habbits}
|
||||
|
||||
请你根据情景使用以下句法:
|
||||
{grammar_habbits}
|
||||
|
||||
{extra_info_block}
|
||||
|
||||
{relation_info_block}
|
||||
|
||||
{time_block}
|
||||
你现在正在群里聊天,以下是群里正在进行的聊天内容:
|
||||
{chat_info}
|
||||
|
||||
以上是聊天内容,你需要了解聊天记录中的内容
|
||||
|
||||
{chat_target}
|
||||
{identity},在这聊天中,"{target_message}"引起了你的注意,你想要在群里发言或者回复这条消息。
|
||||
{chat_info}
|
||||
{reply_target_block}
|
||||
{identity}
|
||||
你需要使用合适的语言习惯和句法,参考聊天内容,组织一条日常且口语化的回复。注意不要复读你说过的话。
|
||||
{config_expression_style},请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。
|
||||
{keywords_reaction_prompt}
|
||||
@@ -57,20 +60,17 @@ def init_prompt():
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{extra_info_block}
|
||||
|
||||
{time_block}
|
||||
你现在正在聊天,以下是你和对方正在进行的聊天内容:
|
||||
{chat_info}
|
||||
|
||||
以上是聊天内容,你需要了解聊天记录中的内容
|
||||
|
||||
{chat_target}
|
||||
{identity},在这聊天中,"{target_message}"引起了你的注意,你想要发言或者回复这条消息。
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。注意不要复读你说过的话。
|
||||
你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
|
||||
{style_habbits}
|
||||
{grammar_habbits}
|
||||
{extra_info_block}
|
||||
{time_block}
|
||||
{chat_target}
|
||||
{chat_info}
|
||||
现在"{sender_name}"说的:{target_message}。引起了你的注意,你想要发言或者回复这条消息。
|
||||
{identity},
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。注意不要复读你说过的话。
|
||||
你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
|
||||
|
||||
|
||||
{config_expression_style},请注意不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容。
|
||||
{keywords_reaction_prompt}
|
||||
@@ -88,8 +88,7 @@ class DefaultReplyer:
|
||||
# TODO: API-Adapter修改标记
|
||||
self.express_model = LLMRequest(
|
||||
model=global_config.model.replyer_1,
|
||||
max_tokens=256,
|
||||
request_type="focus.expressor",
|
||||
request_type="focus.replyer",
|
||||
)
|
||||
self.heart_fc_sender = HeartFCSender()
|
||||
|
||||
@@ -151,12 +150,6 @@ class DefaultReplyer:
|
||||
action_data=action_data,
|
||||
)
|
||||
|
||||
# with Timer("选择表情", cycle_timers):
|
||||
# emoji_keyword = action_data.get("emojis", [])
|
||||
# emoji_base64 = await self._choose_emoji(emoji_keyword)
|
||||
# if emoji_base64:
|
||||
# reply.append(("emoji", emoji_base64))
|
||||
|
||||
if reply:
|
||||
with Timer("发送消息", cycle_timers):
|
||||
sent_msg_list = await self.send_response_messages(
|
||||
@@ -247,22 +240,22 @@ class DefaultReplyer:
|
||||
|
||||
# 2. 获取信息捕捉器
|
||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||
|
||||
# --- Determine sender_name for private chat ---
|
||||
sender_name_for_prompt = "某人" # Default for group or if info unavailable
|
||||
if not self.is_group_chat and self.chat_target_info:
|
||||
# Prioritize person_name, then nickname
|
||||
sender_name_for_prompt = (
|
||||
self.chat_target_info.get("person_name")
|
||||
or self.chat_target_info.get("user_nickname")
|
||||
or sender_name_for_prompt
|
||||
)
|
||||
# --- End determining sender_name ---
|
||||
|
||||
target_message = action_data.get("target", "")
|
||||
|
||||
reply_to = action_data.get("reply_to", "none")
|
||||
|
||||
sender = ""
|
||||
targer = ""
|
||||
if ":" in reply_to or ":" in reply_to:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
targer = parts[1].strip()
|
||||
|
||||
identity = action_data.get("identity", "")
|
||||
extra_info_block = action_data.get("extra_info_block", "")
|
||||
|
||||
relation_info_block = action_data.get("relation_info_block", "")
|
||||
|
||||
# 3. 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_focus(
|
||||
@@ -270,9 +263,10 @@ class DefaultReplyer:
|
||||
# in_mind_reply=in_mind_reply,
|
||||
identity=identity,
|
||||
extra_info_block=extra_info_block,
|
||||
relation_info_block=relation_info_block,
|
||||
reason=reason,
|
||||
sender_name=sender_name_for_prompt, # Pass determined name
|
||||
target_message=target_message,
|
||||
sender_name=sender, # Pass determined name
|
||||
target_message=targer,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
)
|
||||
|
||||
@@ -286,8 +280,7 @@ class DefaultReplyer:
|
||||
|
||||
try:
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
# TODO: API-Adapter修改标记
|
||||
# logger.info(f"{self.log_prefix}[Replier-{thinking_id}]\nPrompt:\n{prompt}\n")
|
||||
logger.info(f"{self.log_prefix}Prompt:\n{prompt}\n")
|
||||
content, (reasoning_content, model_name) = await self.express_model.generate_response_async(prompt)
|
||||
|
||||
# logger.info(f"prompt: {prompt}")
|
||||
@@ -331,9 +324,11 @@ class DefaultReplyer:
|
||||
sender_name,
|
||||
# in_mind_reply,
|
||||
extra_info_block,
|
||||
relation_info_block,
|
||||
identity,
|
||||
target_message,
|
||||
config_expression_style,
|
||||
# stuation,
|
||||
) -> str:
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
@@ -362,15 +357,16 @@ class DefaultReplyer:
|
||||
grammar_habbits = []
|
||||
# 1. learnt_expressions加权随机选3条
|
||||
if learnt_style_expressions:
|
||||
weights = [expr["count"] for expr in learnt_style_expressions]
|
||||
selected_learnt = weighted_sample_no_replacement(learnt_style_expressions, weights, 4)
|
||||
for expr in selected_learnt:
|
||||
# 使用相似度匹配选择最相似的表达
|
||||
similar_exprs = find_similar_expressions(target_message, learnt_style_expressions, 3)
|
||||
for expr in similar_exprs:
|
||||
# print(f"expr: {expr}")
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
style_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
# 2. learnt_grammar_expressions加权随机选3条
|
||||
# 2. learnt_grammar_expressions加权随机选2条
|
||||
if learnt_grammar_expressions:
|
||||
weights = [expr["count"] for expr in learnt_grammar_expressions]
|
||||
selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 4)
|
||||
selected_learnt = weighted_sample_no_replacement(learnt_grammar_expressions, weights, 2)
|
||||
for expr in selected_learnt:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
grammar_habbits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
@@ -382,6 +378,8 @@ class DefaultReplyer:
|
||||
|
||||
style_habbits_str = "\n".join(style_habbits)
|
||||
grammar_habbits_str = "\n".join(grammar_habbits)
|
||||
|
||||
|
||||
|
||||
# 关键词检测与反应
|
||||
keywords_reaction_prompt = ""
|
||||
@@ -413,6 +411,16 @@ class DefaultReplyer:
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
# logger.debug("开始构建 focus prompt")
|
||||
|
||||
if sender_name:
|
||||
reply_target_block = f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
elif target_message:
|
||||
reply_target_block = f"现在{target_message}引起了你的注意,你想要在群里发言或者回复这条消息。"
|
||||
else:
|
||||
reply_target_block = "现在,你想要在群里发言或者回复消息。"
|
||||
|
||||
|
||||
|
||||
|
||||
# --- Choose template based on chat type ---
|
||||
if is_group_chat:
|
||||
@@ -428,7 +436,9 @@ class DefaultReplyer:
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
extra_info_block=extra_info_block,
|
||||
relation_info_block=relation_info_block,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
# bot_name=global_config.bot.nickname,
|
||||
# prompt_personality="",
|
||||
# reason=reason,
|
||||
@@ -436,6 +446,7 @@ class DefaultReplyer:
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
identity=identity,
|
||||
target_message=target_message,
|
||||
sender_name=sender_name,
|
||||
config_expression_style=config_expression_style,
|
||||
)
|
||||
else: # Private chat
|
||||
@@ -448,7 +459,9 @@ class DefaultReplyer:
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
extra_info_block=extra_info_block,
|
||||
relation_info_block=relation_info_block,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
# bot_name=global_config.bot.nickname,
|
||||
# prompt_personality="",
|
||||
# reason=reason,
|
||||
@@ -456,6 +469,7 @@ class DefaultReplyer:
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
identity=identity,
|
||||
target_message=target_message,
|
||||
sender_name=sender_name,
|
||||
config_expression_style=config_expression_style,
|
||||
)
|
||||
|
||||
@@ -599,6 +613,8 @@ class DefaultReplyer:
|
||||
platform=self.chat_stream.platform,
|
||||
)
|
||||
|
||||
# await anchor_message.process()
|
||||
|
||||
bot_message = MessageSending(
|
||||
message_id=message_id, # 使用片段的唯一ID
|
||||
chat_stream=self.chat_stream,
|
||||
@@ -649,4 +665,35 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
return selected
|
||||
|
||||
|
||||
def find_similar_expressions(input_text: str, expressions: List[Dict], top_k: int = 3) -> List[Dict]:
|
||||
"""使用TF-IDF和余弦相似度找出与输入文本最相似的top_k个表达方式"""
|
||||
if not expressions:
|
||||
return []
|
||||
|
||||
# 准备文本数据
|
||||
texts = [expr['situation'] for expr in expressions]
|
||||
texts.append(input_text) # 添加输入文本
|
||||
|
||||
# 使用TF-IDF向量化
|
||||
vectorizer = TfidfVectorizer()
|
||||
tfidf_matrix = vectorizer.fit_transform(texts)
|
||||
|
||||
# 计算余弦相似度
|
||||
similarity_matrix = cosine_similarity(tfidf_matrix)
|
||||
|
||||
# 获取输入文本的相似度分数(最后一行)
|
||||
scores = similarity_matrix[-1][:-1] # 排除与自身的相似度
|
||||
|
||||
# 获取top_k的索引
|
||||
top_indices = np.argsort(scores)[::-1][:top_k]
|
||||
|
||||
# 获取相似表达
|
||||
similar_exprs = []
|
||||
for idx in top_indices:
|
||||
if scores[idx] > 0: # 只保留有相似度的
|
||||
similar_exprs.append(expressions[idx])
|
||||
|
||||
return similar_exprs
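# 补充示意(用法示例,数据为编造):在本模块内可这样调用 find_similar_expressions。
# 注意 TfidfVectorizer 默认按词边界分词,未分词的中文往往整句成为一个词项,
# 相似度可能普遍为 0 而返回空列表;这里刻意用带空格的文本,让示例能体现共享词项的效果。
if __name__ == "__main__":
    demo_expressions = [
        {"situation": "对 游戏 操作 表示 夸赞", "style": "这么强!", "count": 2},
        {"situation": "表示 讽刺 的 赞同", "style": "对对对", "count": 5},
    ]
    top = find_similar_expressions("他 这波 游戏 操作 太强 了", demo_expressions, top_k=1)
    print([expr["style"] for expr in top])   # 预期输出: ['这么强!']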
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -35,7 +35,6 @@ class MemoryManager:
|
||||
self.llm_summarizer = LLMRequest(
|
||||
model=global_config.model.focus_working_memory,
|
||||
temperature=0.3,
|
||||
max_tokens=512,
|
||||
request_type="focus.processor.working_memory",
|
||||
)
|
||||
|
||||
|
||||
@@ -132,13 +132,17 @@ class ChattingObservation(Observation):
|
||||
# logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
|
||||
break
|
||||
else:
|
||||
similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio()
|
||||
raw_message = message.get("raw_message")
|
||||
if raw_message:
|
||||
similarity = difflib.SequenceMatcher(None, text, raw_message).ratio()
|
||||
else:
|
||||
similarity = difflib.SequenceMatcher(None, text, message.get("processed_plain_text", "")).ratio()
|
||||
msg_list.append({"message": message, "similarity": similarity})
|
||||
# logger.debug(f"对锚定消息检查:message: {message['processed_plain_text']},similarity: {similarity}")
|
||||
if not find_msg:
|
||||
if msg_list:
|
||||
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
if msg_list[0]["similarity"] >= 0.5: # 只返回相似度大于等于0.5的消息
|
||||
if msg_list[0]["similarity"] >= 0.9: # 只返回相似度大于等于0.5的消息
|
||||
find_msg = msg_list[0]["message"]
|
||||
else:
|
||||
logger.debug("没有找到锚定消息,相似度低")
|
||||
@@ -191,6 +195,7 @@ class ChattingObservation(Observation):
|
||||
"detailed_plain_text": find_msg.get("processed_plain_text"),
|
||||
"processed_plain_text": find_msg.get("processed_plain_text"),
|
||||
}
|
||||
# print(f"message_dict: {message_dict}")
|
||||
find_rec_msg = MessageRecv(message_dict)
|
||||
# logger.debug(f"锚定消息处理后:find_rec_msg: {find_rec_msg}")
|
||||
return find_rec_msg
|
||||
|
||||
@@ -27,7 +27,7 @@ from rich.progress import (
|
||||
)
|
||||
|
||||
install(extra_lines=3)
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
EMBEDDING_DATA_DIR = (
|
||||
os.path.join(ROOT_PATH, "data", "embedding")
|
||||
if global_config["persistence"]["embedding_data_dir"] is None
|
||||
@@ -6,7 +6,7 @@ from .global_logger import logger
|
||||
from . import prompt_template
|
||||
from .lpmmconfig import global_config, INVALID_ENTITY
|
||||
from .llm_client import LLMClient
|
||||
from .utils.json_fix import new_fix_broken_generated_json
|
||||
from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json
|
||||
|
||||
|
||||
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
||||
@@ -31,7 +31,7 @@ from .lpmmconfig import (
|
||||
|
||||
from .global_logger import logger
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
KG_DIR = (
|
||||
os.path.join(ROOT_PATH, "data/rag")
|
||||
if global_config["persistence"]["rag_data_dir"] is None
|
||||
@@ -1,10 +1,10 @@
|
||||
from .src.lpmmconfig import PG_NAMESPACE, global_config
|
||||
from .src.embedding_store import EmbeddingManager
|
||||
from .src.llm_client import LLMClient
|
||||
from .src.mem_active_manager import MemoryActiveManager
|
||||
from .src.qa_manager import QAManager
|
||||
from .src.kg_manager import KGManager
|
||||
from .src.global_logger import logger
|
||||
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.llm_client import LLMClient
|
||||
from src.chat.knowledge.mem_active_manager import MemoryActiveManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
# try:
|
||||
# import quick_algo
|
||||
# except ImportError:
|
||||
|
||||
@@ -45,7 +45,7 @@ def _load_config(config, config_file_path):
|
||||
if "llm_providers" in file_config:
|
||||
for provider in file_config["llm_providers"]:
|
||||
if provider["name"] not in config["llm_providers"]:
|
||||
config["llm_providers"][provider["name"]] = dict()
|
||||
config["llm_providers"][provider["name"]] = {}
|
||||
config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"]
|
||||
config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"]
|
||||
|
||||
@@ -135,6 +135,6 @@ global_config = dict(
|
||||
# _load_config(global_config, parser.parse_args().config_path)
|
||||
# file_path = os.path.abspath(__file__)
|
||||
# dir_path = os.path.dirname(file_path)
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml")
|
||||
_load_config(global_config, config_path)
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
|
||||
from .global_logger import logger
|
||||
from .lpmmconfig import global_config
|
||||
from .utils.hash import get_sha256
|
||||
from src.chat.knowledge.utils import get_sha256
|
||||
|
||||
|
||||
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
|
||||
@@ -108,8 +108,8 @@ class MessageRecv(Message):
|
||||
self.raw_message = message_dict.get("raw_message")
|
||||
|
||||
# 处理消息内容
|
||||
self.processed_plain_text = "" # 初始化为空字符串
|
||||
self.detailed_plain_text = "" # 初始化为空字符串
|
||||
self.processed_plain_text = message_dict.get("processed_plain_text", "") # 初始化为空字符串
|
||||
self.detailed_plain_text = message_dict.get("detailed_plain_text", "") # 初始化为空字符串
|
||||
self.is_emoji = False
|
||||
|
||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||
@@ -217,7 +217,9 @@ class MessageProcessBase(Message):
|
||||
return f"[@{seg.data}]"
|
||||
elif seg.type == "reply":
|
||||
if self.reply and hasattr(self.reply, "processed_plain_text"):
|
||||
return f"[回复:{self.reply.processed_plain_text}]"
|
||||
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
|
||||
# print(f"reply: {self.reply}")
|
||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]"
|
||||
return None
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
|
||||
@@ -301,28 +301,26 @@ class NormalChat:
|
||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||
info_catcher.catch_decide_to_response(message)
|
||||
|
||||
# 如果启用planner,预先修改可用actions(避免在并行任务中重复调用)
|
||||
available_actions = None
|
||||
if self.enable_planner:
|
||||
try:
|
||||
await self.action_modifier.modify_actions_for_normal_chat(
|
||||
self.chat_stream, self.recent_replies, message.processed_plain_text
|
||||
)
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.stream_name}] 获取available_actions失败: {e}")
|
||||
available_actions = None
|
||||
|
||||
# 定义并行执行的任务
|
||||
async def generate_normal_response():
|
||||
"""生成普通回复"""
|
||||
try:
|
||||
# 如果启用planner,获取可用actions
|
||||
enable_planner = self.enable_planner
|
||||
available_actions = None
|
||||
|
||||
if enable_planner:
|
||||
try:
|
||||
await self.action_modifier.modify_actions_for_normal_chat(
|
||||
self.chat_stream, self.recent_replies
|
||||
)
|
||||
available_actions = self.action_manager.get_using_actions()
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.stream_name}] 获取available_actions失败: {e}")
|
||||
available_actions = None
|
||||
|
||||
return await self.gpt.generate_response(
|
||||
message=message,
|
||||
thinking_id=thinking_id,
|
||||
enable_planner=enable_planner,
|
||||
enable_planner=self.enable_planner,
|
||||
available_actions=available_actions,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -336,38 +334,37 @@ class NormalChat:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 并行执行动作修改和规划准备
|
||||
async def modify_actions():
|
||||
"""修改可用动作集合"""
|
||||
return await self.action_modifier.modify_actions_for_normal_chat(
|
||||
self.chat_stream, self.recent_replies
|
||||
)
|
||||
|
||||
async def prepare_planning():
|
||||
"""准备规划所需的信息"""
|
||||
return self._get_sender_name(message)
|
||||
|
||||
# 并行执行动作修改和准备工作
|
||||
_, sender_name = await asyncio.gather(modify_actions(), prepare_planning())
|
||||
# 获取发送者名称(动作修改已在并行执行前完成)
|
||||
sender_name = self._get_sender_name(message)
|
||||
|
||||
no_action = {
|
||||
"action_result": {"action_type": "no_action", "action_data": {}, "reasoning": "规划器初始化默认", "is_parallel": True},
|
||||
"chat_context": "",
|
||||
"action_prompt": "",
|
||||
}
|
||||
|
||||
|
||||
# 检查是否应该跳过规划
|
||||
if self.action_modifier.should_skip_planning():
|
||||
logger.debug(f"[{self.stream_name}] 没有可用动作,跳过规划")
|
||||
return None
|
||||
self.action_type = "no_action"
|
||||
return no_action
|
||||
|
||||
# 执行规划
|
||||
plan_result = await self.planner.plan(message, sender_name)
|
||||
action_type = plan_result["action_result"]["action_type"]
|
||||
action_data = plan_result["action_result"]["action_data"]
|
||||
reasoning = plan_result["action_result"]["reasoning"]
|
||||
is_parallel = plan_result["action_result"].get("is_parallel", False)
|
||||
|
||||
logger.info(f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}")
|
||||
logger.info(f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}")
|
||||
self.action_type = action_type # 更新实例属性
|
||||
self.is_parallel_action = is_parallel # 新增:保存并行执行标志
|
||||
|
||||
# 如果规划器决定不执行任何动作
|
||||
if action_type == "no_action":
|
||||
logger.debug(f"[{self.stream_name}] Planner决定不执行任何额外动作")
|
||||
return None
|
||||
return no_action
|
||||
elif action_type == "change_to_focus_chat":
|
||||
logger.info(f"[{self.stream_name}] Planner决定切换到focus聊天模式")
|
||||
return None
|
||||
@@ -379,14 +376,15 @@ class NormalChat:
|
||||
else:
|
||||
logger.warning(f"[{self.stream_name}] 额外动作 {action_type} 执行失败")
|
||||
|
||||
return {"action_type": action_type, "action_data": action_data, "reasoning": reasoning}
|
||||
return {"action_type": action_type, "action_data": action_data, "reasoning": reasoning, "is_parallel": is_parallel}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] Planner执行失败: {e}")
|
||||
return None
|
||||
return no_action
|
||||
|
||||
# 并行执行回复生成和动作规划
|
||||
self.action_type = None # 初始化动作类型
|
||||
self.is_parallel_action = False # 初始化并行动作标志
|
||||
with Timer("并行生成回复和规划", timing_results):
|
||||
response_set, plan_result = await asyncio.gather(
|
||||
generate_normal_response(), plan_and_execute_actions(), return_exceptions=True
|
||||
@@ -403,12 +401,15 @@ class NormalChat:
|
||||
if isinstance(plan_result, Exception):
|
||||
logger.error(f"[{self.stream_name}] 动作规划异常: {plan_result}")
|
||||
elif plan_result:
|
||||
logger.debug(f"[{self.stream_name}] 额外动作处理完成: {plan_result['action_type']}")
|
||||
|
||||
logger.debug(f"[{self.stream_name}] 额外动作处理完成: {self.action_type}")
|
||||
|
||||
if not response_set or (
|
||||
self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"]
|
||||
self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action
|
||||
):
|
||||
logger.info(f"[{self.stream_name}] 模型未生成回复内容")
|
||||
if not response_set:
|
||||
logger.info(f"[{self.stream_name}] 模型未生成回复内容")
|
||||
elif self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action:
|
||||
logger.info(f"[{self.stream_name}] 模型选择其他动作(非并行动作)")
|
||||
# 如果模型未生成回复,移除思考消息
|
||||
container = await message_manager.get_container(self.stream_id) # 使用 self.stream_id
|
||||
for msg in container.messages[:]:
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import List, Any
|
||||
from typing import List, Any, Dict
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
from src.chat.focus_chat.planners.actions.base_action import ActionActivationType, ChatMode
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
import time
|
||||
|
||||
logger = get_logger("normal_chat_action_modifier")
|
||||
|
||||
@@ -9,6 +14,7 @@ class NormalChatActionModifier:
|
||||
"""Normal Chat动作修改器
|
||||
|
||||
负责根据Normal Chat的上下文和状态动态调整可用的动作集合
|
||||
实现与Focus Chat类似的动作激活策略,但将LLM_JUDGE转换为概率激活以提升性能
|
||||
"""
|
||||
|
||||
def __init__(self, action_manager: ActionManager, stream_id: str, stream_name: str):
|
||||
@@ -25,9 +31,14 @@ class NormalChatActionModifier:
|
||||
self,
|
||||
chat_stream,
|
||||
recent_replies: List[dict],
|
||||
message_content: str,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""为Normal Chat修改可用动作集合
|
||||
|
||||
实现动作激活策略:
|
||||
1. 基于关联类型的动态过滤
|
||||
2. 基于激活类型的智能判定(LLM_JUDGE转为概率激活)
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
@@ -35,24 +46,19 @@ class NormalChatActionModifier:
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
|
||||
# 合并所有动作变更
|
||||
merged_action_changes = {"add": [], "remove": []}
|
||||
reasons = []
|
||||
merged_action_changes = {"add": [], "remove": []}
|
||||
type_mismatched_actions = [] # 在外层定义避免作用域问题
|
||||
|
||||
self.action_manager.restore_default_actions()
|
||||
|
||||
# 1. 移除Normal Chat不适用的动作
|
||||
excluded_actions = ["exit_focus_chat_action", "no_reply", "reply"]
|
||||
for action_name in excluded_actions:
|
||||
if action_name in self.action_manager.get_using_actions():
|
||||
merged_action_changes["remove"].append(action_name)
|
||||
reasons.append(f"移除{action_name}(Normal Chat不适用)")
|
||||
|
||||
# 2. 检查动作的关联类型
|
||||
# 第一阶段:基于关联类型的动态过滤
|
||||
if chat_stream:
|
||||
chat_context = chat_stream.context if hasattr(chat_stream, "context") else None
|
||||
if chat_context:
|
||||
type_mismatched_actions = []
|
||||
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
# 获取Normal模式下的可用动作(已经过滤了mode_enable)
|
||||
current_using_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
# print(f"current_using_actions: {current_using_actions}")
|
||||
for action_name in current_using_actions.keys():
|
||||
if action_name in self.all_actions:
|
||||
data = self.all_actions[action_name]
|
||||
@@ -65,26 +71,218 @@ class NormalChatActionModifier:
|
||||
merged_action_changes["remove"].extend(type_mismatched_actions)
|
||||
reasons.append(f"移除{type_mismatched_actions}(关联类型不匹配)")
|
||||
|
||||
# 应用动作变更
|
||||
# 第二阶段:应用激活类型判定
|
||||
# 构建聊天内容 - 使用与planner一致的方式
|
||||
chat_content = ""
|
||||
if chat_stream and hasattr(chat_stream, 'stream_id'):
|
||||
try:
|
||||
# 获取消息历史,使用与normal_chat_planner相同的方法
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.focus_chat.observation_context_size, # 使用相同的配置
|
||||
)
|
||||
|
||||
# 构建可读的聊天上下文
|
||||
chat_content = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
logger.debug(f"{self.log_prefix} 成功构建聊天内容,长度: {len(chat_content)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{self.log_prefix} 构建聊天内容失败: {e}")
|
||||
chat_content = ""
|
||||
|
||||
# 获取当前Normal模式下的动作集进行激活判定
|
||||
current_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
|
||||
# print(f"current_actions: {current_actions}")
|
||||
# print(f"chat_content: {chat_content}")
|
||||
final_activated_actions = await self._apply_normal_activation_filtering(
|
||||
current_actions,
|
||||
chat_content,
|
||||
message_content
|
||||
)
|
||||
# print(f"final_activated_actions: {final_activated_actions}")
|
||||
|
||||
# 统一处理所有需要移除的动作,避免重复移除
|
||||
all_actions_to_remove = set() # 使用set避免重复
|
||||
|
||||
# 添加关联类型不匹配的动作
|
||||
if type_mismatched_actions:
|
||||
all_actions_to_remove.update(type_mismatched_actions)
|
||||
|
||||
# 添加激活类型判定未通过的动作
|
||||
for action_name in current_actions.keys():
|
||||
if action_name not in final_activated_actions:
|
||||
all_actions_to_remove.add(action_name)
|
||||
|
||||
# 统计移除原因(避免重复)
|
||||
activation_failed_actions = [name for name in current_actions.keys() if name not in final_activated_actions and name not in type_mismatched_actions]
|
||||
if activation_failed_actions:
|
||||
reasons.append(f"移除{activation_failed_actions}(激活类型判定未通过)")
|
||||
|
||||
# 统一执行移除操作
|
||||
for action_name in all_actions_to_remove:
|
||||
success = self.action_manager.remove_action_from_using(action_name)
|
||||
if success:
|
||||
logger.debug(f"{self.log_prefix} 移除动作: {action_name}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 动作 {action_name} 已经不在使用集中,跳过移除")
|
||||
|
||||
# 应用动作添加(如果有的话)
|
||||
for action_name in merged_action_changes["add"]:
|
||||
if action_name in self.all_actions and action_name not in excluded_actions:
|
||||
if action_name in self.all_actions:
|
||||
success = self.action_manager.add_action_to_using(action_name)
|
||||
if success:
|
||||
logger.debug(f"{self.log_prefix} 添加动作: {action_name}")
|
||||
|
||||
for action_name in merged_action_changes["remove"]:
|
||||
success = self.action_manager.remove_action_from_using(action_name)
|
||||
if success:
|
||||
logger.debug(f"{self.log_prefix} 移除动作: {action_name}")
|
||||
|
||||
# 记录变更原因
|
||||
if merged_action_changes["add"] or merged_action_changes["remove"]:
|
||||
if reasons:
|
||||
logger.info(f"{self.log_prefix} 动作调整完成: {' | '.join(reasons)}")
|
||||
logger.debug(f"{self.log_prefix} 当前可用动作: {list(self.action_manager.get_using_actions().keys())}")
|
||||
|
||||
# 获取最终的Normal模式可用动作并记录
|
||||
final_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
|
||||
logger.debug(f"{self.log_prefix} 当前Normal模式可用动作: {list(final_actions.keys())}")
|
||||
|
||||
async def _apply_normal_activation_filtering(
|
||||
self,
|
||||
actions_with_info: Dict[str, Any],
|
||||
chat_content: str = "",
|
||||
message_content: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
应用Normal模式的激活类型过滤逻辑
|
||||
|
||||
与Focus模式的区别:
|
||||
1. LLM_JUDGE类型转换为概率激活(避免LLM调用)
|
||||
2. RANDOM类型保持概率激活
|
||||
3. KEYWORD类型保持关键词匹配
|
||||
4. ALWAYS类型直接激活
|
||||
|
||||
Args:
|
||||
actions_with_info: 带完整信息的动作字典
|
||||
chat_content: 聊天内容
message_content: 当前这条消息的文本内容
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 过滤后激活的actions字典
|
||||
"""
|
||||
activated_actions = {}
|
||||
|
||||
# 分类处理不同激活类型的actions
|
||||
always_actions = {}
|
||||
random_actions = {}
|
||||
keyword_actions = {}
|
||||
|
||||
for action_name, action_info in actions_with_info.items():
|
||||
# 使用normal_activation_type
|
||||
activation_type = action_info.get("normal_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
if activation_type == ActionActivationType.ALWAYS:
|
||||
always_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.RANDOM or activation_type == ActionActivationType.LLM_JUDGE:
|
||||
random_actions[action_name] = action_info
|
||||
elif activation_type == ActionActivationType.KEYWORD:
|
||||
keyword_actions[action_name] = action_info
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
|
||||
|
||||
# 1. 处理ALWAYS类型(直接激活)
|
||||
for action_name, action_info in always_actions.items():
|
||||
activated_actions[action_name] = action_info
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活")
|
||||
|
||||
# 2. 处理RANDOM类型(概率激活)
|
||||
for action_name, action_info in random_actions.items():
|
||||
probability = action_info.get("random_probability", 0.3)
|
||||
should_activate = random.random() < probability
|
||||
if should_activate:
|
||||
activated_actions[action_name] = action_info
|
||||
logger.info(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发(概率{probability})")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发(概率{probability})")
|
||||
|
||||
# 3. 处理KEYWORD类型(关键词匹配)
|
||||
for action_name, action_info in keyword_actions.items():
|
||||
should_activate = self._check_keyword_activation(
|
||||
action_name,
|
||||
action_info,
|
||||
chat_content,
|
||||
message_content
|
||||
)
|
||||
if should_activate:
|
||||
activated_actions[action_name] = action_info
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
logger.info(f"{self.log_prefix}激活动作: {action_name},原因: KEYWORD类型匹配关键词({keywords})")
|
||||
else:
|
||||
keywords = action_info.get("activation_keywords", [])
|
||||
logger.info(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词({keywords})")
|
||||
# print(f"keywords: {keywords}")
|
||||
# print(f"chat_content: {chat_content}")
|
||||
|
||||
logger.debug(f"{self.log_prefix}Normal模式激活类型过滤完成: {list(activated_actions.keys())}")
|
||||
return activated_actions
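# 补充示意(非本仓库代码):Normal 模式下四种 normal_activation_type 的处理方式小结,
# 其中 LLM_JUDGE 被降级为概率激活以避免额外的 LLM 调用;键值仅为说明性文字。
NORMAL_MODE_ACTIVATION_BEHAVIOR = {
    "ALWAYS": "直接激活",
    "RANDOM": "random.random() < random_probability(默认 0.3)时激活",
    "LLM_JUDGE": "等同 RANDOM:按 random_probability 概率激活,不调用 LLM",
    "KEYWORD": "chat_content + message_content 命中 activation_keywords 时激活",
}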
|
||||
|
||||
def _check_keyword_activation(
|
||||
self,
|
||||
action_name: str,
|
||||
action_info: Dict[str, Any],
|
||||
chat_content: str = "",
|
||||
message_content: str = "",
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否匹配关键词触发条件
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
action_info: 动作信息
|
||||
chat_content: 聊天内容(已经是格式化后的可读消息)
message_content: 当前这条消息的文本内容
|
||||
|
||||
Returns:
|
||||
bool: 是否应该激活此action
|
||||
"""
|
||||
|
||||
activation_keywords = action_info.get("activation_keywords", [])
|
||||
case_sensitive = action_info.get("keyword_case_sensitive", False)
|
||||
|
||||
if not activation_keywords:
|
||||
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
|
||||
return False
|
||||
|
||||
# 使用构建好的聊天内容作为检索文本
|
||||
search_text = chat_content + message_content
|
||||
|
||||
# 如果不区分大小写,转换为小写
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
# 检查每个关键词
|
||||
matched_keywords = []
|
||||
for keyword in activation_keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
matched_keywords.append(keyword)
|
||||
|
||||
|
||||
# print(f"search_text: {search_text}")
|
||||
# print(f"activation_keywords: {activation_keywords}")
|
||||
|
||||
if matched_keywords:
|
||||
logger.info(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
|
||||
return True
|
||||
else:
|
||||
logger.info(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
||||
return False
|
||||
|
||||
    def get_available_actions_count(self) -> int:
        """获取当前可用动作数量(排除默认的no_action)"""
        current_actions = self.action_manager.get_using_actions()
        current_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
        # 排除no_action(如果存在)
        filtered_actions = {k: v for k, v in current_actions.items() if k != "no_action"}
        return len(filtered_actions)

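The case-sensitivity flag above governs both the keyword and the search text. A small standalone illustration of just that matching rule (the function name and sample data are hypothetical, mirroring only the loop above):

    def keyword_match(keywords, text, case_sensitive=False):
        """Return the subset of keywords found in text, honoring case sensitivity."""
        haystack = text if case_sensitive else text.lower()
        return [k for k in keywords if (k if case_sensitive else k.lower()) in haystack]

    print(keyword_match(["Help", "禁言"], "can you HELP me?"))                   # ['Help']
    print(keyword_match(["Help"], "can you HELP me?", case_sensitive=True))      # []
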
@@ -19,19 +19,15 @@ class NormalChatGenerator:
        # TODO: API-Adapter修改标记
        self.model_reasoning = LLMRequest(
            model=global_config.model.replyer_1,
            # temperature=0.7,
            max_tokens=3000,
            request_type="normal.chat_1",
        )
        self.model_normal = LLMRequest(
            model=global_config.model.replyer_2,
            # temperature=global_config.model.replyer_2["temp"],
            max_tokens=256,
            request_type="normal.chat_2",
        )

        self.model_sum = LLMRequest(
            model=global_config.model.memory_summary, temperature=0.7, max_tokens=3000, request_type="relation"
            model=global_config.model.memory_summary, temperature=0.7, request_type="relation"
        )
        self.current_model_type = "r1"  # 默认使用 R1
        self.current_model_name = "unknown model"
@@ -57,7 +53,7 @@ class NormalChatGenerator:
        )

        if model_response:
            logger.debug(f"{global_config.bot.nickname}的原始回复是:{model_response}")
            logger.debug(f"{global_config.bot.nickname}的备选回复是:{model_response}")
            model_response = process_llm_response(model_response)

            return model_response
@@ -7,6 +7,7 @@ from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.individuality.individuality import individuality
from src.chat.focus_chat.planners.action_manager import ActionManager
from src.chat.focus_chat.planners.actions.base_action import ChatMode
from src.chat.message_receive.message import MessageThinking
from json_repair import repair_json
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
@@ -98,16 +99,18 @@ class NormalChatPlanner:

        self_info = name_block + personality_block + identity_block

        # 获取当前可用的动作
        current_available_actions = self.action_manager.get_using_actions()
        # 获取当前可用的动作,使用Normal模式过滤
        current_available_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)

        # 注意:动作的激活判定现在在 normal_chat_action_modifier 中完成
        # 这里直接使用经过 action_modifier 处理后的最终动作集
        # 符合职责分离原则:ActionModifier负责动作管理,Planner专注于决策

        # 如果没有可用动作或只有no_action动作,直接返回no_action
        if not current_available_actions or (
            len(current_available_actions) == 1 and "no_action" in current_available_actions
        ):
            logger.debug(f"{self.log_prefix}规划器: 没有可用动作或只有no_action动作,返回no_action")
        # 如果没有可用动作,直接返回no_action
        if not current_available_actions:
            logger.debug(f"{self.log_prefix}规划器: 没有可用动作,返回no_action")
            return {
                "action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning},
                "action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": True},
                "chat_context": "",
                "action_prompt": "",
            }
@@ -138,7 +141,7 @@ class NormalChatPlanner:
        if not prompt:
            logger.warning(f"{self.log_prefix}规划器: 构建提示词失败")
            return {
                "action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning},
                "action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": False},
                "chat_context": chat_context,
                "action_prompt": "",
            }
@@ -185,13 +188,21 @@ class NormalChatPlanner:

        except Exception as outer_e:
            logger.error(f"{self.log_prefix}规划器异常: {outer_e}")
            chat_context = "无法获取聊天上下文"  # 设置默认值
            prompt = ""  # 设置默认值
            # 设置异常时的默认值
            current_available_actions = {}
            chat_context = "无法获取聊天上下文"
            prompt = ""
            action = "no_action"
            reasoning = "规划器出现异常,使用默认动作"
            action_data = {}

        logger.debug(f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}")
        # 检查动作是否支持并行执行
        is_parallel = False
        if action in current_available_actions:
            action_info = current_available_actions[action]
            is_parallel = action_info.get("parallel_action", False)

        logger.debug(f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}")

        # 恢复到默认动作集
        self.action_manager.restore_actions()
@@ -212,6 +223,7 @@ class NormalChatPlanner:
            "action_type": action,
            "action_data": action_data,
            "reasoning": reasoning,
            "is_parallel": is_parallel,
            "action_record": json.dumps(action_record, ensure_ascii=False)
        }

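Downstream, the new is_parallel flag lets a caller decide whether the chosen action can run alongside normal reply generation. A hedged sketch of such a consumer follows; handle_plan, run_action and generate_reply are hypothetical stand-ins, and the parallel-vs-replace semantics are an assumption read off the flag name rather than taken from the actual NormalChat loop:

    import asyncio

    async def run_action(result: dict) -> None:
        # Hypothetical action executor stand-in.
        print("executing action:", result["action_type"])

    async def generate_reply(plan: dict) -> None:
        # Hypothetical reply generator stand-in.
        print("generating reply for context:", plan["chat_context"])

    async def handle_plan(plan: dict) -> None:
        result = plan["action_result"]
        if result.get("is_parallel"):
            # Assumed semantics: parallel actions run alongside the normal reply.
            await asyncio.gather(run_action(result), generate_reply(plan))
        else:
            # Assumed semantics: non-parallel actions take the turn on their own.
            await run_action(result)

    asyncio.run(handle_plan({
        "action_result": {"action_type": "emoji", "action_data": {}, "reasoning": "", "is_parallel": True},
        "chat_context": "",
        "action_prompt": "",
    }))
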
@@ -304,4 +316,6 @@ class NormalChatPlanner:
        return ""


init_prompt()

@@ -184,7 +184,7 @@ class ImageManager:
                return f"[图片:{cached_description}]"

            # 调用AI获取描述
            prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来,请留意其主题,直观感受,以及是否有擦边色情内容。最多100个字。"
            prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来,请留意其主题,直观感受,输出为一段平文本,最多50字"
            description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)

            if description is None: