春河晴
2025-06-10 16:13:31 +09:00
parent 440e8bf7f3
commit 8d9a88a903
70 changed files with 1598 additions and 1642 deletions

View File

@@ -8,19 +8,22 @@ logger = get_logger("base_action")
_ACTION_REGISTRY: Dict[str, Type["BaseAction"]] = {}
_DEFAULT_ACTIONS: Dict[str, str] = {}
# 动作激活类型枚举
class ActionActivationType:
ALWAYS = "always" # 默认参与到planner
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
RANDOM = "random" # 随机启用action到planner
KEYWORD = "keyword" # 关键词触发启用action到planner
# 聊天模式枚举
class ChatMode:
FOCUS = "focus" # Focus聊天模式
NORMAL = "normal" # Normal聊天模式
ALL = "all" # 所有聊天模式
def register_action(cls):
"""
动作注册装饰器
@@ -81,13 +84,13 @@ class BaseAction(ABC):
self.action_description: str = "基础动作"
self.action_parameters: dict = {}
self.action_require: list[str] = []
# 动作激活类型设置
# Focus模式下的激活类型默认为always
self.focus_activation_type: str = ActionActivationType.ALWAYS
# Normal模式下的激活类型默认为always
self.normal_activation_type: str = ActionActivationType.ALWAYS
# 随机激活的概率(0.0-1.0)用于RANDOM激活类型
self.random_activation_probability: float = 0.3
# LLM判定的提示词用于LLM_JUDGE激活类型
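A minimal sketch of how these activation defaults are meant to be overridden by a concrete action, mirroring the class-level attribute style used by EmojiAction further below; the class name, description and probability are invented for illustration, and only attributes that appear in this diff are used.

from src.chat.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode

@register_action
class GreetingAction(BaseAction):  # hypothetical example, not part of this commit
    action_description: str = "Greet a newly joined member"
    action_parameters: dict = {"target": "who to greet"}
    action_require: list[str] = ["use when a new member joins"]

    # Offered to the planner only when the LLM judge approves it in Focus mode,
    # and with a 20% random chance in Normal mode.
    focus_activation_type = ActionActivationType.LLM_JUDGE
    normal_activation_type = ActionActivationType.RANDOM
    random_activation_probability = 0.2
    mode_enable = ChatMode.ALL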

View File

@@ -4,4 +4,4 @@ from . import no_reply_action # noqa
from . import exit_focus_chat_action # noqa
from . import emoji_action # noqa
# 在此处添加更多动作模块导入

View File

@@ -22,29 +22,26 @@ class EmojiAction(BaseAction):
action_parameters: dict[str:str] = {
"description": "文字描述你想要发送的表情包内容",
}
action_require: list[str] = [
"表达情绪时可以选择使用",
"重点:不要连续发,如果你已经发过[表情包],就不要选择此动作"]
action_require: list[str] = ["表达情绪时可以选择使用", "重点:不要连续发,如果你已经发过[表情包],就不要选择此动作"]
associated_types: list[str] = ["emoji"]
enable_plugin = True
focus_activation_type = ActionActivationType.LLM_JUDGE
normal_activation_type = ActionActivationType.RANDOM
random_activation_probability = global_config.normal_chat.emoji_chance
parallel_action = True
llm_judge_prompt = """
判定是否需要使用表情动作的条件:
1. 用户明确要求使用表情包
2. 这是一个适合表达强烈情绪的场合
3. 不要发送太多表情包,如果你已经发送过多个表情包
"""
# 模式启用设置 - 表情动作只在Focus模式下使用
mode_enable = ChatMode.ALL
@@ -147,4 +144,4 @@ class EmojiAction(BaseAction):
elif type == "emoji":
reply_text += data
return success, reply_text

View File

@@ -27,7 +27,7 @@ class ExitFocusChatAction(BaseAction):
]
# 退出专注聊天是系统核心功能,不是插件,但默认不启用(需要特定条件触发)
enable_plugin = False
# 模式启用设置 - 退出专注聊天动作只在Focus模式下使用
mode_enable = ChatMode.FOCUS

View File

@@ -29,10 +29,10 @@ class NoReplyAction(BaseAction):
"想要休息一下",
]
enable_plugin = True
# 激活类型设置
focus_activation_type = ActionActivationType.ALWAYS
# 模式启用设置 - no_reply动作只在Focus模式下使用
mode_enable = ChatMode.FOCUS

View File

@@ -31,16 +31,16 @@ class ReplyAction(BaseAction):
action_require: list[str] = [
"你想要闲聊或者随便附和",
"有人提到你",
"如果你刚刚进行了回复,不要对同一个话题重复回应"
"如果你刚刚进行了回复,不要对同一个话题重复回应",
]
associated_types: list[str] = ["text"]
enable_plugin = True
# 激活类型设置
focus_activation_type = ActionActivationType.ALWAYS
# 模式启用设置 - 回复动作只在Focus模式下使用
mode_enable = ChatMode.FOCUS
@@ -89,12 +89,12 @@ class ReplyAction(BaseAction):
cycle_timers=self.cycle_timers,
thinking_id=self.thinking_id,
)
await self.store_action_info(
action_build_into_prompt=False,
action_prompt_display=f"{reply_text}",
)
return success, reply_text
async def _handle_reply(
@@ -115,22 +115,22 @@ class ReplyAction(BaseAction):
chatting_observation: ChattingObservation = next(
obs for obs in self.observations if isinstance(obs, ChattingObservation)
)
reply_to = reply_data.get("reply_to", "none")
# sender = ""
target = ""
if ":" in reply_to or "" in reply_to:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r'[:：]', string=reply_to, maxsplit=1)
parts = re.split(pattern=r"[:：]", string=reply_to, maxsplit=1)
if len(parts) == 2:
# sender = parts[0].strip()
target = parts[1].strip()
anchor_message = chatting_observation.search_message_by_text(target)
else:
anchor_message = None
if anchor_message:
anchor_message.update_chat_stream(self.chat_stream)
else:
logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符")
@@ -138,7 +138,6 @@ class ReplyAction(BaseAction):
self.chat_stream.platform, self.chat_stream.group_info, self.chat_stream
)
success, reply_set = await self.replyer.deal_reply(
cycle_timers=cycle_timers,
action_data=reply_data,
@@ -158,8 +157,9 @@ class ReplyAction(BaseAction):
return success, reply_text
async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None:
async def store_action_info(
self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True
) -> None:
"""存储action执行信息到数据库
Args:
@@ -188,9 +188,9 @@ class ReplyAction(BaseAction):
chat_info_platform=chat_stream.platform,
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else ""
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
)
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
except Exception as e:
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
traceback.print_exc()

View File

@@ -1,10 +1,6 @@
import traceback
from typing import Tuple, Dict, List, Any, Optional, Union, Type
from typing import Tuple, Dict, Any, Optional
from src.chat.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode # noqa F401
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.common.logger_manager import get_logger
from src.config.config import global_config
import os
import inspect
import toml # 导入 toml 库
@@ -20,10 +16,10 @@ from src.chat.actions.plugin_api.stream_api import StreamAPI
from src.chat.actions.plugin_api.hearflow_api import HearflowAPI
# 以下为类型注解需要
from src.chat.message_receive.chat_stream import ChatStream # noqa
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor # noqa
from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer # noqa
from src.chat.focus_chat.info.obs_info import ObsInfo # noqa
logger = get_logger("plugin_action")
@@ -35,7 +31,7 @@ class PluginAction(BaseAction, MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, Utils
"""
action_config_file_name: Optional[str] = None # 插件可以覆盖此属性来指定配置文件名
# 默认激活类型设置,插件可以覆盖
focus_activation_type = ActionActivationType.ALWAYS
normal_activation_type = ActionActivationType.ALWAYS
@@ -43,7 +39,7 @@ class PluginAction(BaseAction, MessageAPI, LLMAPI, DatabaseAPI, ConfigAPI, Utils
llm_judge_prompt: str = ""
activation_keywords: list[str] = []
keyword_case_sensitive: bool = False
# 默认模式启用设置 - 插件动作默认在所有模式下可用,插件可以覆盖
mode_enable = ChatMode.ALL
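As a sketch of how a plugin can override these defaults, the hypothetical subclass below switches Normal mode to keyword activation; the class name, keywords, prompt and config file name are invented, and the import path for PluginAction is an assumption based on this file's location.

from src.chat.actions.plugin_action import PluginAction  # assumed module path
from src.chat.actions.base_action import register_action, ActionActivationType, ChatMode

@register_action
class WeatherPluginAction(PluginAction):  # hypothetical plugin action
    action_config_file_name = "weather_config.toml"  # hypothetical config file

    focus_activation_type = ActionActivationType.LLM_JUDGE
    normal_activation_type = ActionActivationType.KEYWORD
    activation_keywords = ["weather", "天气"]
    keyword_case_sensitive = False
    llm_judge_prompt = "Decide whether the user is asking about the weather."
    mode_enable = ChatMode.ALL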

View File

@@ -7,11 +7,11 @@ from src.chat.actions.plugin_api.stream_api import StreamAPI
from src.chat.actions.plugin_api.hearflow_api import HearflowAPI
__all__ = [
'MessageAPI',
'LLMAPI',
'DatabaseAPI',
'ConfigAPI',
'UtilsAPI',
'StreamAPI',
'HearflowAPI',
]
"MessageAPI",
"LLMAPI",
"DatabaseAPI",
"ConfigAPI",
"UtilsAPI",
"StreamAPI",
"HearflowAPI",
]

View File

@@ -5,32 +5,33 @@ from src.person_info.person_info import person_info_manager
logger = get_logger("config_api")
class ConfigAPI:
"""配置API模块
提供了配置读取和用户信息获取等功能
"""
def get_global_config(self, key: str, default: Any = None) -> Any:
"""
安全地从全局配置中获取一个值。
插件应使用此方法读取全局配置,以保证只读和隔离性。
Args:
key: 配置键名
default: 如果配置不存在时返回的默认值
Returns:
Any: 配置值或默认值
"""
return global_config.get(key, default)
async def get_user_id_by_person_name(self, person_name: str) -> tuple[str, str]:
"""根据用户名获取用户ID
Args:
person_name: 用户名
Returns:
tuple[str, str]: (平台, 用户ID)
"""
@@ -38,16 +39,16 @@ class ConfigAPI:
user_id = await person_info_manager.get_value(person_id, "user_id")
platform = await person_info_manager.get_value(person_id, "platform")
return platform, user_id
async def get_person_info(self, person_id: str, key: str, default: Any = None) -> Any:
"""获取用户信息
Args:
person_id: 用户ID
key: 信息键名
default: 默认值
Returns:
Any: 用户信息值或默认值
"""
return await person_info_manager.get_value(person_id, key, default)
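A short usage sketch of the ConfigAPI mixin above; the helper name, config key and person name are placeholders, and only method signatures visible in this diff are used.

async def demo_config_usage(action):
    # 'action' is any object that mixes in ConfigAPI, e.g. a PluginAction instance.
    value = action.get_global_config("some.config.key", default=None)  # placeholder key
    platform, user_id = await action.get_user_id_by_person_name("Alice")  # placeholder name
    nickname = await action.get_person_info("person_id_123", "nickname", default="")  # placeholder id/key
    return value, platform, user_id, nickname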

View File

@@ -8,13 +8,16 @@ from peewee import Model, DoesNotExist
logger = get_logger("database_api")
class DatabaseAPI:
"""数据库API模块
提供了数据库操作相关的功能
"""
async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None:
async def store_action_info(
self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True
) -> None:
"""存储action执行信息到数据库
Args:
@@ -44,13 +47,13 @@ class DatabaseAPI:
chat_info_platform=chat_stream.platform,
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else ""
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
)
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
except Exception as e:
logger.error(f"{self.log_prefix} 存储action信息时出错: {e}")
traceback.print_exc()
async def db_query(
self,
model_class: Type[Model],
@@ -59,12 +62,12 @@ class DatabaseAPI:
data: Dict[str, Any] = None,
limit: int = None,
order_by: List[str] = None,
single_result: bool = False
single_result: bool = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
Args:
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
@@ -73,7 +76,7 @@ class DatabaseAPI:
limit: 限制结果数量
order_by: 排序字段列表,使用字段名,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果(如果 single_result=True
@@ -81,24 +84,24 @@ class DatabaseAPI:
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
示例:
# 查询最近10条消息
messages = await self.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await self.db_query(
ActionRecords,
query_type="create",
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}
)
# 更新记录
updated_count = await self.db_query(
ActionRecords,
@@ -106,14 +109,14 @@ class DatabaseAPI:
filters={"action_id": "123"},
data={"action_done": True}
)
# 删除记录
deleted_count = await self.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await self.db_query(
Messages,
@@ -125,12 +128,12 @@ class DatabaseAPI:
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
@@ -140,56 +143,56 @@ class DatabaseAPI:
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get()
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
# 更新记录
return query.update(**data).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
if query_type == "get" and single_result:
return None
return []
except Exception as e:
logger.error(f"{self.log_prefix} 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
@@ -198,21 +201,18 @@ class DatabaseAPI:
raise "unknown query type"
async def db_raw_query(
self,
sql: str,
params: List[Any] = None,
fetch_results: bool = True
self, sql: str, params: List[Any] = None, fetch_results: bool = True
) -> Union[List[Dict[str, Any]], int, None]:
"""执行原始SQL查询
警告: 使用此方法需要小心确保SQL语句已正确构造以避免SQL注入风险。
Args:
sql: 原始SQL查询字符串
params: 查询参数列表用于替换SQL中的占位符
fetch_results: 是否获取查询结果对于SELECT查询设为True对于
UPDATE/INSERT/DELETE等操作设为False
Returns:
如果fetch_results为True返回查询结果列表
如果fetch_results为False返回受影响的行数
@@ -220,55 +220,51 @@ class DatabaseAPI:
"""
try:
cursor = db.execute_sql(sql, params or [])
if fetch_results:
# 获取列名
columns = [col[0] for col in cursor.description]
# 构建结果字典列表
results = []
for row in cursor.fetchall():
results.append(dict(zip(columns, row)))
return results
else:
# 返回受影响的行数
return cursor.rowcount
except Exception as e:
logger.error(f"{self.log_prefix} 执行原始SQL查询出错: {e}")
traceback.print_exc()
return None
async def db_save(
self,
model_class: Type[Model],
data: Dict[str, Any],
key_field: str = None,
key_value: Any = None
self, model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
) -> Union[Dict[str, Any], None]:
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录。
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名,例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await self.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
@@ -280,58 +276,50 @@ class DatabaseAPI:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
# 查找现有记录
existing_records = list(model_class.select().where(
getattr(model_class, key_field) == key_value
).limit(1))
existing_records = list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
)
if existing_records:
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
setattr(existing_record, field, value)
existing_record.save()
# 返回更新后的记录
updated_record = model_class.select().where(
model_class.id == existing_record.id
).dicts().get()
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(
model_class.id == new_record.id
).dicts().get()
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
return created_record
except Exception as e:
logger.error(f"{self.log_prefix} 保存数据库记录出错: {e}")
traceback.print_exc()
return None
async def db_get(
self,
model_class: Type[Model],
filters: Dict[str, Any] = None,
order_by: str = None,
limit: int = None
self, model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间降序
limit: 结果数量限制如果为1则返回单个记录而不是列表
Returns:
如果limit=1返回单个记录字典或None
否则返回记录字典列表或空列表。
示例:
# 获取单个记录
record = await self.db_get(
@@ -339,7 +327,7 @@ class DatabaseAPI:
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await self.db_get(
Messages,
@@ -351,32 +339,32 @@ class DatabaseAPI:
try:
# 构建查询
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 应用排序
if order_by:
if order_by.startswith("-"):
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
# 应用限制
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if limit == 1:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"{self.log_prefix} 获取数据库记录出错: {e}")
traceback.print_exc()
return None if limit == 1 else []
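The docstrings above already carry examples for db_query, db_save and db_get; the sketch below covers db_raw_query with parameterized SQL. The helper name, table and column names are placeholders, and the '?' placeholder syntax assumes an SQLite-style driver behind peewee.

async def demo_raw_query(api):
    # 'api' is any object that mixes in DatabaseAPI.
    # SELECT with fetch_results=True returns a list of dicts (one per row).
    rows = await api.db_raw_query(
        "SELECT * FROM messages WHERE chat_id = ? ORDER BY time DESC LIMIT 10",
        params=["stream_123"],
        fetch_results=True,
    )
    # Non-SELECT statements with fetch_results=False return the affected row count.
    affected = await api.db_raw_query(
        "UPDATE action_records SET action_done = 1 WHERE action_id = ?",
        params=["123"],
        fetch_results=False,
    )
    return rows, affected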

View File

@@ -1,4 +1,4 @@
from typing import Optional, List, Any, Tuple
from typing import Optional, List, Any
from src.common.logger_manager import get_logger
from src.chat.heart_flow.heartflow import heartflow
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
@@ -8,16 +8,16 @@ logger = get_logger("hearflow_api")
class HearflowAPI:
"""心流API模块
提供与心流和子心流相关的操作接口
"""
async def get_sub_hearflow_by_chat_id(self, chat_id: str) -> Optional[SubHeartflow]:
"""根据chat_id获取指定的sub_hearflow实例
Args:
chat_id: 聊天ID与sub_hearflow的subheartflow_id相同
Returns:
Optional[SubHeartflow]: sub_hearflow实例如果不存在则返回None
"""
@@ -35,11 +35,10 @@ class HearflowAPI:
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流实例时出错: {e}")
return None
def get_all_sub_hearflow_ids(self) -> List[str]:
"""获取所有子心流的ID列表
Returns:
List[str]: 所有子心流的ID列表
"""
@@ -51,10 +50,10 @@ class HearflowAPI:
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流ID列表时出错: {e}")
return []
def get_all_sub_hearflows(self) -> List[SubHeartflow]:
"""获取所有子心流实例
Returns:
List[SubHeartflow]: 所有活跃的子心流实例列表
"""
@@ -66,13 +65,13 @@ class HearflowAPI:
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流实例列表时出错: {e}")
return []
async def get_sub_hearflow_chat_state(self, chat_id: str) -> Optional[ChatState]:
"""获取指定子心流的聊天状态
Args:
chat_id: 聊天ID
Returns:
Optional[ChatState]: 聊天状态如果子心流不存在则返回None
"""
@@ -84,14 +83,14 @@ class HearflowAPI:
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流聊天状态时出错: {e}")
return None
async def set_sub_hearflow_chat_state(self, chat_id: str, target_state: ChatState) -> bool:
"""设置指定子心流的聊天状态
Args:
chat_id: 聊天ID
target_state: 目标状态
Returns:
bool: 是否设置成功
"""
@@ -100,13 +99,13 @@ class HearflowAPI:
except Exception as e:
logger.error(f"{self.log_prefix} 设置子心流聊天状态时出错: {e}")
return False
async def get_sub_hearflow_replyer(self, chat_id: str) -> Optional[Any]:
"""根据chat_id获取指定子心流的replyer实例
Args:
chat_id: 聊天ID
Returns:
Optional[Any]: replyer实例如果不存在则返回None
"""
@@ -116,13 +115,13 @@ class HearflowAPI:
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流replyer时出错: {e}")
return None
async def get_sub_hearflow_expressor(self, chat_id: str) -> Optional[Any]:
"""根据chat_id获取指定子心流的expressor实例
Args:
chat_id: 聊天ID
Returns:
Optional[Any]: expressor实例如果不存在则返回None
"""
@@ -131,4 +130,4 @@ class HearflowAPI:
return expressor
except Exception as e:
logger.error(f"{self.log_prefix} 获取子心流expressor时出错: {e}")
return None
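A brief usage sketch of the HearflowAPI mixin above; the helper name is a placeholder and only methods shown in this diff are called.

async def demo_hearflow_usage(api):
    # 'api' is any object that mixes in HearflowAPI.
    for chat_id in api.get_all_sub_hearflow_ids():
        state = await api.get_sub_hearflow_chat_state(chat_id)     # ChatState or None
        sub_flow = await api.get_sub_hearflow_by_chat_id(chat_id)  # SubHeartflow or None
        print(chat_id, state, sub_flow is not None)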

View File

@@ -5,12 +5,13 @@ from src.config.config import global_config
logger = get_logger("llm_api")
class LLMAPI:
"""LLM API模块
提供了与LLM模型交互的功能
"""
def get_available_models(self) -> Dict[str, Any]:
"""获取所有可用的模型配置
@@ -20,17 +21,13 @@ class LLMAPI:
if not hasattr(global_config, "model"):
logger.error(f"{self.log_prefix} 无法获取模型列表:全局配置中未找到 model 配置")
return {}
models = global_config.model
return models
async def generate_with_model(
self,
prompt: str,
model_config: Dict[str, Any],
request_type: str = "plugin.generate",
**kwargs
self, prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
@@ -45,17 +42,13 @@ class LLMAPI:
"""
try:
logger.info(f"{self.log_prefix} 使用模型生成内容,提示词: {prompt[:100]}...")
llm_request = LLMRequest(
model=model_config,
request_type=request_type,
**kwargs
)
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
return True, response, reasoning, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"{self.log_prefix} {error_msg}")
return False, error_msg, "", ""
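A hedged usage sketch for the LLMAPI mixin; the helper name and the model key are placeholders, since get_available_models is typed as returning Dict[str, Any] and which keys exist depends on the bot configuration.

async def demo_llm_usage(api):
    # 'api' is any object that mixes in LLMAPI.
    models = api.get_available_models()
    model_config = models.get("replyer")  # hypothetical key, depends on bot config
    if not model_config:
        return None
    success, response, reasoning, model_name = await api.generate_with_model(
        prompt="Introduce yourself in one sentence.",
        model_config=model_config,
        request_type="plugin.generate",
    )
    return response if success else None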

View File

@@ -14,17 +14,18 @@ from src.chat.focus_chat.info.obs_info import ObsInfo
# 新增导入
from src.chat.focus_chat.heartFC_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending
from maim_message import Seg, UserInfo, GroupInfo
from maim_message import Seg, UserInfo
from src.config.config import global_config
logger = get_logger("message_api")
class MessageAPI:
"""消息API模块
提供了发送消息、获取消息历史等功能
"""
async def send_message_to_target(
self,
message_type: str,
@@ -35,7 +36,7 @@ class MessageAPI:
display_message: str = "",
) -> bool:
"""直接向指定目标发送消息
Args:
message_type: 消息类型,如"text""image""emoji"
content: 消息内容
@@ -43,7 +44,7 @@ class MessageAPI:
target_id: 目标ID群ID或用户ID
is_group: 是否为群聊True为群聊False为私聊
display_message: 显示消息(可选)
Returns:
bool: 是否发送成功
"""
@@ -53,12 +54,14 @@ class MessageAPI:
# 群聊:从数据库查找对应的聊天流
target_stream = None
for stream_id, stream in chat_manager.streams.items():
if (stream.group_info and
str(stream.group_info.group_id) == str(target_id) and
stream.platform == platform):
if (
stream.group_info
and str(stream.group_info.group_id) == str(target_id)
and stream.platform == platform
):
target_stream = stream
break
if not target_stream:
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到群ID为 {target_id} 的聊天流")
return False
@@ -66,39 +69,39 @@ class MessageAPI:
# 私聊:从数据库查找对应的聊天流
target_stream = None
for stream_id, stream in chat_manager.streams.items():
if (not stream.group_info and
str(stream.user_info.user_id) == str(target_id) and
stream.platform == platform):
if (
not stream.group_info
and str(stream.user_info.user_id) == str(target_id)
and stream.platform == platform
):
target_stream = stream
break
if not target_stream:
logger.error(f"{getattr(self, 'log_prefix', '')} 未找到用户ID为 {target_id} 的私聊流")
return False
# 创建HeartFCSender实例
heart_fc_sender = HeartFCSender()
# 生成消息ID和thinking_id
current_time = time.time()
message_id = f"plugin_msg_{int(current_time * 1000)}"
thinking_id = f"plugin_thinking_{int(current_time * 1000)}"
# 构建机器人用户信息
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=platform,
)
# 创建消息段
message_segment = Seg(type=message_type, data=content)
# 创建空锚点消息(用于回复)
anchor_message = await create_empty_anchor_message(
platform, target_stream.group_info, target_stream
)
anchor_message = await create_empty_anchor_message(platform, target_stream.group_info, target_stream)
# 构建发送消息对象
bot_message = MessageSending(
message_id=message_id,
@@ -112,22 +115,17 @@ class MessageAPI:
is_emoji=(message_type == "emoji"),
thinking_start_time=current_time,
)
# 发送消息
sent_msg = await heart_fc_sender.send_message(
bot_message,
has_thinking=True,
typing=False,
set_reply=False
)
sent_msg = await heart_fc_sender.send_message(bot_message, has_thinking=True, typing=False, set_reply=False)
if sent_msg:
logger.info(f"{getattr(self, 'log_prefix', '')} 成功发送消息到 {platform}:{target_id}")
return True
else:
logger.error(f"{getattr(self, 'log_prefix', '')} 发送消息失败")
return False
except Exception as e:
logger.error(f"{getattr(self, 'log_prefix', '')} 向目标发送消息时出错: {e}")
traceback.print_exc()
@@ -135,42 +133,34 @@ class MessageAPI:
async def send_text_to_group(self, text: str, group_id: str, platform: str = "qq") -> bool:
"""便捷方法:向指定群聊发送文本消息
Args:
text: 要发送的文本内容
group_id: 群聊ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
return await self.send_message_to_target(
message_type="text",
content=text,
platform=platform,
target_id=group_id,
is_group=True
message_type="text", content=text, platform=platform, target_id=group_id, is_group=True
)
async def send_text_to_user(self, text: str, user_id: str, platform: str = "qq") -> bool:
"""便捷方法:向指定用户发送私聊文本消息
Args:
text: 要发送的文本内容
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
return await self.send_message_to_target(
message_type="text",
content=text,
platform=platform,
target_id=user_id,
is_group=False
message_type="text", content=text, platform=platform, target_id=user_id, is_group=False
)
async def send_message(self, type: str, data: str, target: Optional[str] = "", display_message: str = "") -> bool:
"""发送消息的简化方法
@@ -288,7 +278,9 @@ class MessageAPI:
return success
async def send_message_by_replyer(self, target: Optional[str] = None, extra_info_block: Optional[str] = None) -> bool:
async def send_message_by_replyer(
self, target: Optional[str] = None, extra_info_block: Optional[str] = None
) -> bool:
"""通过replyer发送消息的简化方法
Args:
@@ -381,4 +373,4 @@ class MessageAPI:
}
messages.append(simple_msg)
return messages
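A short sketch of the two convenience wrappers above; the helper name and the target ids are placeholders.

async def demo_send_text(api):
    # 'api' is any object that mixes in MessageAPI.
    ok_group = await api.send_text_to_group("Hello everyone!", group_id="123456789", platform="qq")
    ok_user = await api.send_text_to_user("Hello!", user_id="987654321", platform="qq")
    return ok_group and ok_user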

View File

@@ -1,147 +1,142 @@
import hashlib
from typing import Optional, List, Dict, Any
from src.common.logger_manager import get_logger
from src.chat.message_receive.chat_stream import ChatManager, ChatStream
from maim_message import GroupInfo, UserInfo
logger = get_logger("stream_api")
class StreamAPI:
"""聊天流API模块
提供了获取聊天流、通过群ID查找聊天流等功能
"""
def get_chat_stream_by_group_id(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""通过QQ群ID获取聊天流
Args:
group_id: QQ群ID
platform: 平台标识,默认为"qq"
Returns:
Optional[ChatStream]: 找到的聊天流对象如果未找到则返回None
"""
try:
chat_manager = ChatManager()
# 遍历所有已加载的聊天流查找匹配的群ID
for stream_id, stream in chat_manager.streams.items():
if (stream.group_info and
str(stream.group_info.group_id) == str(group_id) and
stream.platform == platform):
if (
stream.group_info
and str(stream.group_info.group_id) == str(group_id)
and stream.platform == platform
):
logger.info(f"{self.log_prefix} 通过群ID {group_id} 找到聊天流: {stream_id}")
return stream
logger.warning(f"{self.log_prefix} 未找到群ID为 {group_id} 的聊天流")
return None
except Exception as e:
logger.error(f"{self.log_prefix} 通过群ID获取聊天流时出错: {e}")
return None
def get_all_group_chat_streams(self, platform: str = "qq") -> List[ChatStream]:
"""获取所有群聊的聊天流
Args:
platform: 平台标识,默认为"qq"
Returns:
List[ChatStream]: 所有群聊的聊天流列表
"""
try:
chat_manager = ChatManager()
group_streams = []
for stream in chat_manager.streams.values():
if (stream.group_info and
stream.platform == platform):
if stream.group_info and stream.platform == platform:
group_streams.append(stream)
logger.info(f"{self.log_prefix} 找到 {len(group_streams)} 个群聊聊天流")
return group_streams
except Exception as e:
logger.error(f"{self.log_prefix} 获取所有群聊聊天流时出错: {e}")
return []
def get_chat_stream_by_user_id(self, user_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""通过用户ID获取私聊聊天流
Args:
user_id: 用户ID
platform: 平台标识,默认为"qq"
Returns:
Optional[ChatStream]: 找到的私聊聊天流对象如果未找到则返回None
"""
try:
chat_manager = ChatManager()
# 遍历所有已加载的聊天流查找匹配的用户ID私聊
for stream_id, stream in chat_manager.streams.items():
if (not stream.group_info and # 私聊没有群信息
stream.user_info and
str(stream.user_info.user_id) == str(user_id) and
stream.platform == platform):
if (
not stream.group_info # 私聊没有群信息
and stream.user_info
and str(stream.user_info.user_id) == str(user_id)
and stream.platform == platform
):
logger.info(f"{self.log_prefix} 通过用户ID {user_id} 找到私聊聊天流: {stream_id}")
return stream
logger.warning(f"{self.log_prefix} 未找到用户ID为 {user_id} 的私聊聊天流")
return None
except Exception as e:
logger.error(f"{self.log_prefix} 通过用户ID获取私聊聊天流时出错: {e}")
return None
def get_chat_streams_info(self) -> List[Dict[str, Any]]:
"""获取所有聊天流的基本信息
Returns:
List[Dict[str, Any]]: 包含聊天流基本信息的字典列表
"""
try:
chat_manager = ChatManager()
streams_info = []
for stream_id, stream in chat_manager.streams.items():
info = {
"stream_id": stream_id,
"platform": stream.platform,
"chat_type": "group" if stream.group_info else "private",
"create_time": stream.create_time,
"last_active_time": stream.last_active_time
"last_active_time": stream.last_active_time,
}
if stream.group_info:
info.update({
"group_id": stream.group_info.group_id,
"group_name": stream.group_info.group_name
})
info.update({"group_id": stream.group_info.group_id, "group_name": stream.group_info.group_name})
if stream.user_info:
info.update({
"user_id": stream.user_info.user_id,
"user_nickname": stream.user_info.user_nickname
})
info.update({"user_id": stream.user_info.user_id, "user_nickname": stream.user_info.user_nickname})
streams_info.append(info)
logger.info(f"{self.log_prefix} 获取到 {len(streams_info)} 个聊天流信息")
return streams_info
except Exception as e:
logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}")
return []
async def get_chat_stream_by_group_id_async(self, group_id: str, platform: str = "qq") -> Optional[ChatStream]:
"""异步通过QQ群ID获取聊天流包括从数据库搜索
Args:
group_id: QQ群ID
platform: 平台标识,默认为"qq"
Returns:
Optional[ChatStream]: 找到的聊天流对象如果未找到则返回None
"""
@@ -150,15 +145,15 @@ class StreamAPI:
stream = self.get_chat_stream_by_group_id(group_id, platform)
if stream:
return stream
# 如果内存中没有,尝试从数据库加载所有聊天流后再查找
chat_manager = ChatManager()
await chat_manager.load_all_streams()
# 再次尝试从内存中查找
stream = self.get_chat_stream_by_group_id(group_id, platform)
return stream
except Exception as e:
logger.error(f"{self.log_prefix} 异步通过群ID获取聊天流时出错: {e}")
return None
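A brief usage sketch of the StreamAPI mixin; the helper name and group id are placeholders.

def demo_stream_lookup(api):
    # 'api' is any object that mixes in StreamAPI.
    stream = api.get_chat_stream_by_group_id("123456789", platform="qq")  # in-memory lookup
    if stream is None:
        # Fall back to a summary of every known stream (group and private).
        for info in api.get_chat_streams_info():
            print(info["stream_id"], info["chat_type"], info["platform"])
    return stream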

View File

@@ -1,35 +1,37 @@
import os
import json
import time
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from src.common.logger_manager import get_logger
logger = get_logger("utils_api")
class UtilsAPI:
"""工具类API模块
提供了各种辅助功能
"""
def get_plugin_path(self) -> str:
"""获取当前插件的路径
Returns:
str: 插件目录的绝对路径
"""
import inspect
plugin_module_path = inspect.getfile(self.__class__)
plugin_dir = os.path.dirname(plugin_module_path)
return plugin_dir
def read_json_file(self, file_path: str, default: Any = None) -> Any:
"""读取JSON文件
Args:
file_path: 文件路径,可以是相对于插件目录的路径
default: 如果文件不存在或读取失败时返回的默认值
Returns:
Any: JSON数据或默认值
"""
@@ -37,25 +39,25 @@ class UtilsAPI:
# 如果是相对路径,则相对于插件目录
if not os.path.isabs(file_path):
file_path = os.path.join(self.get_plugin_path(), file_path)
if not os.path.exists(file_path):
logger.warning(f"{self.log_prefix} 文件不存在: {file_path}")
return default
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"{self.log_prefix} 读取JSON文件出错: {e}")
return default
def write_json_file(self, file_path: str, data: Any, indent: int = 2) -> bool:
"""写入JSON文件
Args:
file_path: 文件路径,可以是相对于插件目录的路径
data: 要写入的数据
indent: JSON缩进
Returns:
bool: 是否写入成功
"""
@@ -63,59 +65,62 @@ class UtilsAPI:
# 如果是相对路径,则相对于插件目录
if not os.path.isabs(file_path):
file_path = os.path.join(self.get_plugin_path(), file_path)
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w', encoding='utf-8') as f:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=indent)
return True
except Exception as e:
logger.error(f"{self.log_prefix} 写入JSON文件出错: {e}")
return False
def get_timestamp(self) -> int:
"""获取当前时间戳
Returns:
int: 当前时间戳(秒)
"""
return int(time.time())
def format_time(self, timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""格式化时间
Args:
timestamp: 时间戳如果为None则使用当前时间
format_str: 时间格式字符串
Returns:
str: 格式化后的时间字符串
"""
import datetime
if timestamp is None:
timestamp = time.time()
return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)
def parse_time(self, time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
"""解析时间字符串为时间戳
Args:
time_str: 时间字符串
format_str: 时间格式字符串
Returns:
int: 时间戳(秒)
"""
import datetime
dt = datetime.datetime.strptime(time_str, format_str)
return int(dt.timestamp())
def generate_unique_id(self) -> str:
"""生成唯一ID
Returns:
str: 唯一ID
"""
import uuid
return str(uuid.uuid4())
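A short usage sketch of the UtilsAPI helpers above; the helper name and file name are placeholders. Relative paths are resolved against the plugin directory returned by get_plugin_path().

def demo_utils(api):
    # 'api' is any object that mixes in UtilsAPI.
    state = api.read_json_file("state.json", default={})  # placeholder file name
    state["last_run"] = api.format_time()                 # "%Y-%m-%d %H:%M:%S" by default
    state["run_id"] = api.generate_unique_id()
    return api.write_json_file("state.json", state, indent=2)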

View File

@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Type, Optional, Tuple, Pattern
from src.common.logger_manager import get_logger
from src.chat.message_receive.message import MessageRecv
from src.chat.actions.plugin_api.message_api import MessageAPI
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
@@ -13,9 +12,10 @@ logger = get_logger("command_handler")
_COMMAND_REGISTRY: Dict[str, Type["BaseCommand"]] = {}
_COMMAND_PATTERNS: Dict[Pattern, Type["BaseCommand"]] = {}
class BaseCommand(ABC):
"""命令基类,所有自定义命令都应该继承这个类"""
# 命令的基本属性
command_name: str = "" # 命令名称
command_description: str = "" # 命令描述
@@ -23,43 +23,43 @@ class BaseCommand(ABC):
command_help: str = "" # 命令帮助信息
command_examples: List[str] = [] # 命令使用示例
enable_command: bool = True # 是否启用命令
def __init__(self, message: MessageRecv):
"""初始化命令处理器
Args:
message: 接收到的消息对象
"""
self.message = message
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
self._services = {} # 存储内部服务
# 设置服务
self._services["chat_stream"] = message.chat_stream
# 日志前缀
self.log_prefix = f"[Command:{self.command_name}]"
@abstractmethod
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行命令的抽象方法,需要被子类实现
Returns:
Tuple[bool, Optional[str]]: (是否执行成功, 可选的回复消息)
"""
pass
def set_matched_groups(self, groups: Dict[str, str]) -> None:
"""设置正则表达式匹配的命名组
Args:
groups: 正则表达式匹配的命名组
"""
self.matched_groups = groups
async def send_reply(self, content: str) -> None:
"""发送回复消息
Args:
content: 回复内容
"""
@@ -69,43 +69,42 @@ class BaseCommand(ABC):
if not chat_stream:
logger.error(f"{self.log_prefix} 无法发送消息缺少chat_stream")
return
# 创建空的锚定消息
anchor_message = await create_empty_anchor_message(
chat_stream.platform,
chat_stream.group_info,
chat_stream
chat_stream.platform, chat_stream.group_info, chat_stream
)
# 创建表达器传入chat_stream参数
expressor = DefaultExpressor(chat_stream)
# 设置服务
self._services["expressor"] = expressor
# 发送消息
response_set = [
("text", content),
]
# 调用表达器发送消息
await expressor.send_response_messages(
anchor_message=anchor_message,
response_set=response_set,
display_message="",
)
logger.info(f"{self.log_prefix} 命令回复消息发送成功: {content[:30]}...")
except Exception as e:
logger.error(f"{self.log_prefix} 发送命令回复消息失败: {e}")
import traceback
logger.error(traceback.format_exc())
def register_command(cls):
"""
命令注册装饰器
用法:
@register_command
class MyCommand(BaseCommand):
@@ -115,21 +114,25 @@ def register_command(cls):
...
"""
# 检查类是否有必要的属性
if not hasattr(cls, "command_name") or not hasattr(cls, "command_description") or not hasattr(cls, "command_pattern"):
if (
not hasattr(cls, "command_name")
or not hasattr(cls, "command_description")
or not hasattr(cls, "command_pattern")
):
logger.error(f"命令类 {cls.__name__} 缺少必要的属性: command_name, command_description 或 command_pattern")
return cls
command_name = cls.command_name
command_pattern = cls.command_pattern
is_enabled = getattr(cls, "enable_command", True) # 默认启用命令
if not command_name or not command_pattern:
logger.error(f"命令类 {cls.__name__} 的 command_name 或 command_pattern 为空")
return cls
# 将命令类注册到全局注册表
_COMMAND_REGISTRY[command_name] = cls
# 编译正则表达式并注册
try:
pattern = re.compile(command_pattern, re.IGNORECASE | re.DOTALL)
@@ -137,47 +140,47 @@ def register_command(cls):
logger.info(f"已注册命令: {command_name} -> {cls.__name__},命令启用: {is_enabled}")
except re.error as e:
logger.error(f"命令 {command_name} 的正则表达式编译失败: {e}")
return cls
class CommandManager:
"""命令管理器,负责处理命令(不再负责加载,加载由统一的插件加载器处理)"""
def __init__(self):
"""初始化命令管理器"""
# 命令加载现在由统一的插件加载器处理,这里只需要初始化
logger.info("命令管理器初始化完成")
async def process_command(self, message: MessageRecv) -> Tuple[bool, Optional[str], bool]:
"""处理消息中的命令
Args:
message: 接收到的消息对象
Returns:
Tuple[bool, Optional[str], bool]: (是否找到并执行了命令, 命令执行结果, 是否继续处理消息)
"""
if not message.processed_plain_text:
await message.process()
text = message.processed_plain_text
# 检查是否匹配任何命令模式
for pattern, command_cls in _COMMAND_PATTERNS.items():
match = pattern.match(text)
if match and getattr(command_cls, "enable_command", True):
# 创建命令实例
command_instance = command_cls(message)
# 提取命名组并设置
groups = match.groupdict()
command_instance.set_matched_groups(groups)
try:
# 执行命令
success, response = await command_instance.execute()
# 记录命令执行结果
if success:
logger.info(f"命令 {command_cls.command_name} 执行成功")
@@ -189,27 +192,28 @@ class CommandManager:
if response:
# 使用命令实例的send_reply方法发送错误信息
await command_instance.send_reply(f"命令执行失败: {response}")
# 命令执行后不再继续处理消息
return True, response, False
except Exception as e:
logger.error(f"执行命令 {command_cls.command_name} 时出错: {e}")
import traceback
logger.error(traceback.format_exc())
try:
# 使用命令实例的send_reply方法发送错误信息
await command_instance.send_reply(f"命令执行出错: {str(e)}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
# 命令执行出错后不再继续处理消息
return True, str(e), False
# 没有匹配到任何命令,继续处理消息
return False, None, True
# 创建全局命令管理器实例
command_manager = CommandManager()
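A hypothetical command illustrating the registration flow above; the name, pattern, help text and reply are invented, and only attributes and methods visible in this diff (command_name, command_pattern, matched_groups, send_reply, execute) are used. register_command and BaseCommand are the ones defined in this module.

from typing import Optional, Tuple

@register_command
class PingCommand(BaseCommand):  # hypothetical command, not part of this commit
    command_name = "ping"
    command_description = "Reply with pong to check that the bot is responsive"
    command_pattern = r"^/ping(?:\s+(?P<note>.+))?$"  # the named group lands in matched_groups
    command_help = "Usage: /ping [note]"
    command_examples = ["/ping", "/ping are you there"]
    enable_command = True

    async def execute(self) -> Tuple[bool, Optional[str]]:
        note = self.matched_groups.get("note") or ""
        await self.send_reply(f"pong {note}".strip())
        return True, "pong"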

View File

@@ -227,8 +227,6 @@ class DefaultExpressor:
logger.info(f"想要表达:{in_mind_reply}||理由:{reason}")
logger.info(f"最终回复: {content}\n")
except Exception as llm_e:
# 精简报错信息
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")

View File

@@ -113,25 +113,25 @@ class ExpressionLearner:
同时对所有已存储的表达方式进行全局衰减
"""
current_time = time.time()
# 全局衰减所有已存储的表达方式
for type in ["style", "grammar"]:
base_dir = os.path.join("data", "expression", f"learnt_{type}")
if not os.path.exists(base_dir):
continue
for chat_id in os.listdir(base_dir):
file_path = os.path.join(base_dir, chat_id, "expressions.json")
if not os.path.exists(file_path):
continue
try:
with open(file_path, "r", encoding="utf-8") as f:
expressions = json.load(f)
# 应用全局衰减
decayed_expressions = self.apply_decay_to_expressions(expressions, current_time)
# 保存衰减后的结果
with open(file_path, "w", encoding="utf-8") as f:
json.dump(decayed_expressions, f, ensure_ascii=False, indent=2)
@@ -162,23 +162,25 @@ class ExpressionLearner:
"""
if time_diff_days <= 0 or time_diff_days >= DECAY_DAYS:
return 0.001
# 使用二次函数进行插值
# 将7天作为顶点0天和30天作为两个端点
# 使用顶点式y = a(x-h)^2 + k其中(h,k)为顶点
h = 7.0 # 顶点x坐标
k = 0.001 # 顶点y坐标
# 计算a值使得x=0和x=30时y=0.001
# 0.001 = a(0-7)^2 + 0.001
# 解得a = 0
a = 0
# 计算衰减值
decay = a * (time_diff_days - h) ** 2 + k
return min(0.001, decay)
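# Worked check of the interpolation above: requiring y = 0.001 both at the
# vertex (x = 7) and at x = 0 gives 0.001 = a * (0 - 7) ** 2 + 0.001, hence
# a = 0, so the "quadratic" collapses to a constant decay of k = 0.001 for
# every 0 < time_diff_days < DECAY_DAYS (the early-return endpoints above
# also yield 0.001).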
def apply_decay_to_expressions(self, expressions: List[Dict[str, Any]], current_time: float) -> List[Dict[str, Any]]:
def apply_decay_to_expressions(
self, expressions: List[Dict[str, Any]], current_time: float
) -> List[Dict[str, Any]]:
"""
对表达式列表应用衰减
返回衰减后的表达式列表移除count小于0的项
@@ -188,16 +190,16 @@ class ExpressionLearner:
# 确保last_active_time存在如果不存在则使用current_time
if "last_active_time" not in expr:
expr["last_active_time"] = current_time
last_active = expr["last_active_time"]
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
decay_value = self.calculate_decay_factor(time_diff_days)
expr["count"] = max(0.01, expr.get("count", 1) - decay_value)
if expr["count"] > 0:
result.append(expr)
return result
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
@@ -211,7 +213,7 @@ class ExpressionLearner:
type_str = "句法特点"
else:
raise ValueError(f"Invalid type: {type}")
res = await self.learn_expression(type, num)
if res is None:
@@ -238,15 +240,15 @@ class ExpressionLearner:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({"situation": situation, "style": style})
current_time = time.time()
# 存储到/data/expression/对应chat_id/expressions.json
for chat_id, expr_list in chat_dict.items():
dir_path = os.path.join("data", "expression", f"learnt_{type}", str(chat_id))
os.makedirs(dir_path, exist_ok=True)
file_path = os.path.join(dir_path, "expressions.json")
# 若已存在,先读出合并
old_data: List[Dict[str, Any]] = []
if os.path.exists(file_path):
@@ -255,10 +257,10 @@ class ExpressionLearner:
old_data = json.load(f)
except Exception:
old_data = []
# 应用衰减
# old_data = self.apply_decay_to_expressions(old_data, current_time)
# 合并逻辑
for new_expr in expr_list:
found = False
@@ -278,43 +280,43 @@ class ExpressionLearner:
new_expr["count"] = 1
new_expr["last_active_time"] = current_time
old_data.append(new_expr)
# 处理超限问题
if len(old_data) > MAX_EXPRESSION_COUNT:
# 计算每个表达方式的权重count的倒数这样count越小的越容易被选中
weights = [1 / (expr.get("count", 1) + 0.1) for expr in old_data]
# 随机选择要移除的表达方式,避免重复索引
remove_count = len(old_data) - MAX_EXPRESSION_COUNT
# 使用一种不会选到重复索引的方法
indices = list(range(len(old_data)))
# 方法1使用numpy.random.choice
# 把列表转成一个映射字典,保证不会有重复
remove_set = set()
total_attempts = 0
# 尝试按权重随机选择,直到选够数量
while len(remove_set) < remove_count and total_attempts < len(old_data) * 2:
idx = random.choices(indices, weights=weights, k=1)[0]
remove_set.add(idx)
total_attempts += 1
# 如果没选够,随机补充
if len(remove_set) < remove_count:
remaining = set(indices) - remove_set
remove_set.update(random.sample(list(remaining), remove_count - len(remove_set)))
remove_indices = list(remove_set)
# 从后往前删除,避免索引变化
for idx in sorted(remove_indices, reverse=True):
old_data.pop(idx)
with open(file_path, "w", encoding="utf-8") as f:
json.dump(old_data, f, ensure_ascii=False, indent=2)
return learnt_expressions
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:

View File

@@ -97,7 +97,7 @@ class CycleDetail:
)
# current_time_minute = time.strftime("%Y%m%d_%H%M", time.localtime())
# try:
# self.log_cycle_to_file(
# log_dir + self.prefix + f"/{current_time_minute}_cycle_" + str(self.cycle_id) + ".json"
@@ -117,7 +117,6 @@ class CycleDetail:
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
# 写入文件
file_path = os.path.join(dir_name, os.path.basename(file_path))
# print("file_path:", file_path)

View File

@@ -99,22 +99,23 @@ class HeartFChatting:
self.stream_id: str = chat_id # 聊天流ID
self.chat_stream = chat_manager.get_stream(self.stream_id)
self.log_prefix = f"[{chat_manager.get_stream_name(self.stream_id) or self.stream_id}]"
self.memory_activator = MemoryActivator()
# 初始化观察器
self.observations: List[Observation] = []
self._register_observations()
# 根据配置文件和默认规则确定启用的处理器
config_processor_settings = global_config.focus_chat_processor
self.enabled_processor_names = []
for proc_name, (_proc_class, config_key) in PROCESSOR_CLASSES.items():
# 对于关系处理器,需要同时检查两个配置项
if proc_name == "RelationshipProcessor":
if (global_config.relationship.enable_relationship and
getattr(config_processor_settings, config_key, True)):
if global_config.relationship.enable_relationship and getattr(
config_processor_settings, config_key, True
):
self.enabled_processor_names.append(proc_name)
else:
# 其他处理器的原有逻辑
@@ -122,14 +123,13 @@ class HeartFChatting:
self.enabled_processor_names.append(proc_name)
# logger.info(f"{self.log_prefix} 将启用的处理器: {self.enabled_processor_names}")
self.processors: List[BaseProcessor] = []
self._register_default_processors()
self.expressor = DefaultExpressor(chat_stream=self.chat_stream)
self.replyer = DefaultReplyer(chat_stream=self.chat_stream)
self.action_manager = ActionManager()
self.action_planner = PlannerFactory.create_planner(
log_prefix=self.log_prefix, action_manager=self.action_manager
@@ -138,7 +138,6 @@ class HeartFChatting:
self.action_observation = ActionObservation(observe_id=self.stream_id)
self.action_observation.set_action_manager(self.action_manager)
self._processing_lock = asyncio.Lock()
# 循环控制内部状态
@@ -182,7 +181,13 @@ class HeartFChatting:
if processor_info:
processor_actual_class = processor_info[0] # 获取实际的类定义
# 根据处理器类名判断是否需要 subheartflow_id
if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor", "RelationshipProcessor"]:
if name in [
"MindProcessor",
"ToolProcessor",
"WorkingMemoryProcessor",
"SelfProcessor",
"RelationshipProcessor",
]:
self.processors.append(processor_actual_class(subheartflow_id=self.stream_id))
elif name == "ChattingInfoProcessor":
self.processors.append(processor_actual_class())
@@ -203,9 +208,7 @@ class HeartFChatting:
)
if self.processors:
logger.info(
f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}"
)
logger.info(f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}")
else:
logger.warning(f"{self.log_prefix} 没有注册任何处理器。这可能是由于配置错误或所有处理器都被禁用了。")
@@ -292,7 +295,9 @@ class HeartFChatting:
self._current_cycle_detail.set_loop_info(loop_info)
# 从observations列表中获取HFCloopObservation
hfcloop_observation = next((obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None)
hfcloop_observation = next(
(obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None
)
if hfcloop_observation:
hfcloop_observation.add_loop_info(self._current_cycle_detail)
else:
@@ -451,19 +456,19 @@ class HeartFChatting:
# 根据配置决定是否并行执行调整动作、回忆和处理器阶段
# 并行执行调整动作、回忆和处理器阶段
with Timer("并行调整动作、处理", cycle_timers):
# 创建并行任务
async def modify_actions_task():
# 调用完整的动作修改流程
await self.action_modifier.modify_actions(
observations=self.observations,
)
await self.action_observation.observe()
self.observations.append(self.action_observation)
return True
# 创建三个并行任务
action_modify_task = asyncio.create_task(modify_actions_task())
memory_task = asyncio.create_task(self.memory_activator.activate_memory(self.observations))
@@ -474,9 +479,6 @@ class HeartFChatting:
action_modify_task, memory_task, processor_task
)
loop_processor_info = {
"all_plan_info": all_plan_info,
"processor_time_costs": processor_time_costs,
@@ -594,9 +596,7 @@ class HeartFChatting:
else:
success, reply_text = result
command = ""
logger.debug(
f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'"
)
logger.debug(f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'")
return success, reply_text, command

View File

@@ -51,8 +51,8 @@ async def _process_relationship(message: MessageRecv) -> None:
logger.info(f"首次认识用户: {nickname}")
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname)
# elif not await relationship_manager.is_qved_name(platform, user_id):
# logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
# await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
# logger.info(f"给用户({nickname},{cardname})取名: {nickname}")
# await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname, "")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
@@ -74,7 +74,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
fast_retrieval=True,
)
logger.trace(f"记忆激活率: {interested_rate:.2f}")
text_len = len(message.processed_plain_text)
# 根据文本长度调整兴趣度长度越大兴趣度越高但增长率递减最低0.01最高0.05
# 采用对数函数实现递减增长
@@ -181,7 +181,6 @@ class HeartFCMessageReceiver:
userinfo = message.message_info.user_info
messageinfo = message.message_info
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform,
user_info=userinfo,

View File

@@ -11,7 +11,6 @@ from datetime import datetime
from typing import Dict
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import asyncio
logger = get_logger("processor")

View File

@@ -22,7 +22,7 @@ from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_ch
# 配置常量:是否启用小模型即时信息提取
# 开启时:使用小模型并行即时提取,速度更快,但精度可能略低
# 关闭时:使用原来的异步模式,精度更高但速度较慢
ENABLE_INSTANT_INFO_EXTRACTION = True
logger = get_logger("processor")
@@ -63,7 +63,7 @@ def init_prompt():
"""
Prompt(relationship_prompt, "relationship_prompt")
fetch_info_prompt = """
{name_block}
@@ -84,7 +84,6 @@ def init_prompt():
Prompt(fetch_info_prompt, "fetch_info_prompt")
class RelationshipProcessor(BaseProcessor):
log_prefix = "关系"
@@ -92,8 +91,10 @@ class RelationshipProcessor(BaseProcessor):
super().__init__()
self.subheartflow_id = subheartflow_id
self.info_fetching_cache: List[Dict[str, any]] = []
self.info_fetched_cache: Dict[str, Dict[str, any]] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}}
self.info_fetched_cache: Dict[
str, Dict[str, any]
] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}}
self.person_engaged_cache: List[Dict[str, any]] = [] # [{person_id: str, start_time: float, rounds: int}]
self.grace_period_rounds = 5
@@ -101,7 +102,7 @@ class RelationshipProcessor(BaseProcessor):
model=global_config.model.relation,
request_type="focus.relationship",
)
# 小模型用于即时信息提取
if ENABLE_INSTANT_INFO_EXTRACTION:
self.instant_llm_model = LLMRequest(
@@ -156,26 +157,27 @@ class RelationshipProcessor(BaseProcessor):
for record in list(self.person_engaged_cache):
record["rounds"] += 1
time_elapsed = current_time - record["start_time"]
message_count = len(get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time))
message_count = len(
get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time)
)
print(record)
# 根据消息数量和时间设置不同的触发条件
should_trigger = (
message_count >= 50 or # 50条消息必定满足
(message_count >= 35 and time_elapsed >= 300) or # 35条且10分钟
(message_count >= 25 and time_elapsed >= 900) or # 25条且30分钟
(message_count >= 10 and time_elapsed >= 2000) # 10条且1小时
message_count >= 50 # 50条消息必定满足
or (message_count >= 35 and time_elapsed >= 300) # 35条且10分钟
or (message_count >= 25 and time_elapsed >= 900) # 25条且30分钟
or (message_count >= 10 and time_elapsed >= 2000) # 10条且1小时
)
if should_trigger:
logger.info(f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。消息数:{message_count},时长:{time_elapsed:.0f}")
logger.info(
f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。消息数:{message_count},时长:{time_elapsed:.0f}"
)
asyncio.create_task(
self.update_impression_on_cache_expiry(
record["person_id"],
self.subheartflow_id,
record["start_time"],
current_time
record["person_id"], self.subheartflow_id, record["start_time"], current_time
)
)
self.person_engaged_cache.remove(record)
@@ -187,20 +189,24 @@ class RelationshipProcessor(BaseProcessor):
if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0:
# 在删除前查找匹配的info_fetching_cache记录
matched_record = None
min_time_diff = float('inf')
min_time_diff = float("inf")
for record in self.info_fetching_cache:
if (record["person_id"] == person_id and
record["info_type"] == info_type and
not record["forget"]):
time_diff = abs(record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"])
if (
record["person_id"] == person_id
and record["info_type"] == info_type
and not record["forget"]
):
time_diff = abs(
record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"]
)
if time_diff < min_time_diff:
min_time_diff = time_diff
matched_record = record
if matched_record:
matched_record["forget"] = True
logger.info(f"{self.log_prefix} 用户 {person_id}{info_type} 信息已过期,标记为遗忘。")
del self.info_fetched_cache[person_id][info_type]
if not self.info_fetched_cache[person_id]:
del self.info_fetched_cache[person_id]
@@ -208,7 +214,7 @@ class RelationshipProcessor(BaseProcessor):
# 5. 为需要处理的人员准备LLM prompt
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
info_cache_block = ""
if self.info_fetching_cache:
for info_fetching in self.info_fetching_cache:
@@ -223,7 +229,7 @@ class RelationshipProcessor(BaseProcessor):
chat_observe_info=chat_observe_info,
info_cache_block=info_cache_block,
)
try:
logger.debug(f"{self.log_prefix} 人物信息prompt: \n{prompt}\n")
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
@@ -234,45 +240,47 @@ class RelationshipProcessor(BaseProcessor):
# 收集即时提取任务
instant_tasks = []
async_tasks = []
for person_name, info_type in content_json.items():
person_id = person_info_manager.get_person_id_by_person_name(person_name)
if person_id:
self.info_fetching_cache.append({
"person_id": person_id,
"person_name": person_name,
"info_type": info_type,
"start_time": time.time(),
"forget": False,
})
self.info_fetching_cache.append(
{
"person_id": person_id,
"person_name": person_name,
"info_type": info_type,
"start_time": time.time(),
"forget": False,
}
)
if len(self.info_fetching_cache) > 20:
self.info_fetching_cache.pop(0)
else:
logger.warning(f"{self.log_prefix} 未找到用户 {person_name} 的ID跳过调取信息。")
continue
logger.info(f"{self.log_prefix} 调取用户 {person_name}{info_type} 信息。")
# 检查person_engaged_cache中是否已存在该person_id
person_exists = any(record["person_id"] == person_id for record in self.person_engaged_cache)
if not person_exists:
self.person_engaged_cache.append({
"person_id": person_id,
"start_time": time.time(),
"rounds": 0
})
self.person_engaged_cache.append(
{"person_id": person_id, "start_time": time.time(), "rounds": 0}
)
if ENABLE_INSTANT_INFO_EXTRACTION:
# 收集即时提取任务
instant_tasks.append((person_id, info_type, time.time()))
else:
# 使用原来的异步模式
async_tasks.append(asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time())))
async_tasks.append(
asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time()))
)
# 执行即时提取任务
if ENABLE_INSTANT_INFO_EXTRACTION and instant_tasks:
await self._execute_instant_extraction_batch(instant_tasks)
# 启动异步任务(如果不是即时模式)
if async_tasks:
# 异步任务不需要等待完成
@@ -300,7 +308,7 @@ class RelationshipProcessor(BaseProcessor):
person_infos_str += f"你不了解{person_name}有关[{info_type}]的信息,不要胡乱回答,你可以直接说你不知道,或者你忘记了;"
if person_infos_str:
persons_infos_str += f"你对 {person_name} 的了解:{person_infos_str}\n"
# 处理正在调取但还没有结果的项目(只在非即时提取模式下显示)
if not ENABLE_INSTANT_INFO_EXTRACTION:
pending_info_dict = {}
@@ -312,50 +320,47 @@ class RelationshipProcessor(BaseProcessor):
person_id = record["person_id"]
person_name = record["person_name"]
info_type = record["info_type"]
# 检查是否已经在info_fetched_cache中有结果
if (person_id in self.info_fetched_cache and
info_type in self.info_fetched_cache[person_id]):
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
continue
# 按人物组织正在调取的信息
if person_name not in pending_info_dict:
pending_info_dict[person_name] = []
pending_info_dict[person_name].append(info_type)
# 添加正在调取的信息到返回字符串
for person_name, info_types in pending_info_dict.items():
info_types_str = "".join(info_types)
persons_infos_str += f"你正在识图回忆有关 {person_name}{info_types_str} 信息,稍等一下再回答...\n"
return persons_infos_str
async def _execute_instant_extraction_batch(self, instant_tasks: list):
"""
批量执行即时提取任务
"""
if not instant_tasks:
return
logger.info(f"{self.log_prefix} [即时提取] 开始批量提取 {len(instant_tasks)} 个信息")
# 创建所有提取任务
extraction_tasks = []
for person_id, info_type, start_time in instant_tasks:
# 检查缓存中是否已存在且未过期的信息
if (person_id in self.info_fetched_cache and
info_type in self.info_fetched_cache[person_id]):
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
logger.info(f"{self.log_prefix} 用户 {person_id}{info_type} 信息已存在且未过期,跳过调取。")
continue
task = asyncio.create_task(self._fetch_single_info_instant(person_id, info_type, start_time))
extraction_tasks.append(task)
# 并行执行所有提取任务并等待完成
if extraction_tasks:
await asyncio.gather(*extraction_tasks, return_exceptions=True)
logger.info(f"{self.log_prefix} [即时提取] 批量提取完成")
async def _fetch_single_info_instant(self, person_id: str, info_type: str, start_time: float):
"""
@@ -363,24 +368,21 @@ class RelationshipProcessor(BaseProcessor):
"""
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
person_name = await person_info_manager.get_value(person_id, "person_name")
person_impression = await person_info_manager.get_value(person_id, "impression")
if not person_impression:
impression_block = "你对ta没有什么深刻的印象"
else:
impression_block = f"{person_impression}"
points = await person_info_manager.get_value(person_id, "points")
if points:
points_text = "\n".join([
f"{point[2]}:{point[0]}"
for point in points
])
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
else:
points_text = "你不记得ta最近发生了什么"
prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format(
name_block=name_block,
info_type=info_type,
@@ -393,9 +395,9 @@ class RelationshipProcessor(BaseProcessor):
try:
# 使用小模型进行即时提取
content, _ = await self.instant_llm_model.generate_response_async(prompt=prompt)
logger.info(f"{self.log_prefix} [即时提取] {person_name}{info_type} 结果: {content}")
if content:
content_json = json.loads(repair_json(content))
if info_type in content_json:
@@ -410,7 +412,9 @@ class RelationshipProcessor(BaseProcessor):
"person_name": person_name,
"unknow": False,
}
logger.info(f"{self.log_prefix} [即时提取] 成功获取 {person_name}{info_type}: {info_content}")
logger.info(
f"{self.log_prefix} [即时提取] 成功获取 {person_name}{info_type}: {info_content}"
)
else:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
@@ -423,59 +427,55 @@ class RelationshipProcessor(BaseProcessor):
}
logger.info(f"{self.log_prefix} [即时提取] {person_name}{info_type} 信息不明确")
else:
logger.warning(f"{self.log_prefix} [即时提取] 小模型返回空结果,获取 {person_name}{info_type} 信息失败。")
logger.warning(
f"{self.log_prefix} [即时提取] 小模型返回空结果,获取 {person_name}{info_type} 信息失败。"
)
except Exception as e:
logger.error(f"{self.log_prefix} [即时提取] 执行小模型请求获取用户信息时出错: {e}")
logger.error(traceback.format_exc())
async def fetch_person_info(self, person_id: str, info_types: list[str], start_time: float):
"""
获取某个人的信息
"""
# 检查缓存中是否已存在且未过期的信息
info_types_to_fetch = []
for info_type in info_types:
if (person_id in self.info_fetched_cache and
info_type in self.info_fetched_cache[person_id]):
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
logger.info(f"{self.log_prefix} 用户 {person_id}{info_type} 信息已存在且未过期,跳过调取。")
continue
info_types_to_fetch.append(info_type)
if not info_types_to_fetch:
return
nickname_str = ",".join(global_config.bot.alias_names)
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
person_name = await person_info_manager.get_value(person_id, "person_name")
info_type_str = ""
info_json_str = ""
for info_type in info_types_to_fetch:
info_type_str += f"{info_type},"
info_json_str += f"\"{info_type}\": \"信息内容\","
info_json_str += f'"{info_type}": "信息内容",'
info_type_str = info_type_str[:-1]
info_json_str = info_json_str[:-1]
person_impression = await person_info_manager.get_value(person_id, "impression")
if not person_impression:
impression_block = "你对ta没有什么深刻的印象"
else:
impression_block = f"{person_impression}"
points = await person_info_manager.get_value(person_id, "points")
if points:
points_text = "\n".join([
f"{point[2]}:{point[0]}"
for point in points
])
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
else:
points_text = "你不记得ta最近发生了什么"
prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format(
name_block=name_block,
info_type=info_type_str,
@@ -487,10 +487,10 @@ class RelationshipProcessor(BaseProcessor):
try:
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix} fetch_person_info prompt: \n{prompt}\n")
logger.info(f"{self.log_prefix} fetch_person_info 结果: {content}")
if content:
try:
content_json = json.loads(repair_json(content))
@@ -508,9 +508,9 @@ class RelationshipProcessor(BaseProcessor):
else:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}
self.info_fetched_cache[person_id][info_type] = {
"info":"unknow",
"info": "unknow",
"ttl": 10,
"start_time": start_time,
"person_name": person_name,
@@ -525,16 +525,12 @@ class RelationshipProcessor(BaseProcessor):
logger.error(f"{self.log_prefix} 执行LLM请求获取用户信息时出错: {e}")
logger.error(traceback.format_exc())
async def update_impression_on_cache_expiry(
self, person_id: str, chat_id: str, start_time: float, end_time: float
):
async def update_impression_on_cache_expiry(self, person_id: str, chat_id: str, start_time: float, end_time: float):
"""
在缓存过期时,获取聊天记录并更新用户印象
"""
logger.info(f"缓存过期,开始为 {person_id} 更新印象。时间范围:{start_time} -> {end_time}")
try:
impression_messages = get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time)
if impression_messages:
logger.info(f"{person_id} 获取到 {len(impression_messages)} 条消息用于印象更新。")

View File

@@ -122,9 +122,7 @@ class SelfProcessor(BaseProcessor):
)
# 获取聊天内容
chat_observe_info = observation.get_observe_info()
person_list = observation.person_list
if isinstance(observation, HFCloopObservation):
# hfcloop_observe_info = observation.get_observe_info()
pass
nickname_str = ""
@@ -133,9 +131,7 @@ class SelfProcessor(BaseProcessor):
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
personality_block = individuality.get_personality_prompt(x_person=2, level=2)
identity_block = individuality.get_identity_prompt(x_person=2, level=2)
prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(

View File

@@ -118,7 +118,7 @@ class ToolProcessor(BaseProcessor):
is_group_chat = observation.is_group_chat
chat_observe_info = observation.get_observe_info()
person_list = observation.person_list
# person_list = observation.person_list
memory_str = ""
if running_memorys:
@@ -141,9 +141,7 @@ class ToolProcessor(BaseProcessor):
# 调用LLM专注于工具使用
# logger.info(f"开始执行工具调用{prompt}")
response, other_info = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools
)
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
if len(other_info) == 3:
reasoning_content, model_name, tool_calls = other_info

View File

@@ -118,9 +118,7 @@ class WorkingMemoryProcessor(BaseProcessor):
memory_str=memory_choose_str,
)
# print(f"prompt: {prompt}")
# 调用LLM处理记忆
content = ""

View File

@@ -90,7 +90,7 @@ class MemoryActivator:
# 如果记忆系统被禁用,直接返回空列表
if not global_config.memory.enable_memory:
return []
obs_info_text = ""
for observation in observations:
if isinstance(observation, ChattingObservation):

View File

@@ -5,9 +5,6 @@ from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger_manager import get_logger
import importlib
import pkgutil
import os
# 不再需要导入动作类因为已经在main.py中导入
# import src.chat.actions.default_actions # noqa
@@ -41,7 +38,7 @@ class ActionManager:
# 初始化时将默认动作加载到使用中的动作
self._using_actions = self._default_actions.copy()
# 添加系统核心动作
self._add_system_core_actions()
@@ -63,19 +60,19 @@ class ActionManager:
action_require: list[str] = getattr(action_class, "action_require", [])
associated_types: list[str] = getattr(action_class, "associated_types", [])
is_enabled: bool = getattr(action_class, "enable_plugin", True)
# 获取激活类型相关属性
focus_activation_type: str = getattr(action_class, "focus_activation_type", "always")
normal_activation_type: str = getattr(action_class, "normal_activation_type", "always")
random_probability: float = getattr(action_class, "random_activation_probability", 0.3)
llm_judge_prompt: str = getattr(action_class, "llm_judge_prompt", "")
activation_keywords: list[str] = getattr(action_class, "activation_keywords", [])
keyword_case_sensitive: bool = getattr(action_class, "keyword_case_sensitive", False)
# 获取模式启用属性
mode_enable: str = getattr(action_class, "mode_enable", "all")
# 获取并行执行属性
parallel_action: bool = getattr(action_class, "parallel_action", False)
@@ -114,13 +111,13 @@ class ActionManager:
def _load_plugin_actions(self) -> None:
"""
加载所有插件目录中的动作
注意插件动作的实际导入已经在main.py中完成这里只需要从_ACTION_REGISTRY获取
"""
try:
# 插件动作已在main.py中加载这里只需要从_ACTION_REGISTRY获取
self._load_registered_actions()
logger.info(f"从注册表加载插件动作成功")
logger.info("从注册表加载插件动作成功")
except Exception as e:
logger.error(f"加载插件动作失败: {e}")
@@ -203,25 +200,25 @@ class ActionManager:
def get_using_actions_for_mode(self, mode: str) -> Dict[str, ActionInfo]:
"""
根据聊天模式获取可用的动作集合
Args:
mode: 聊天模式 ("focus", "normal", "all")
Returns:
Dict[str, ActionInfo]: 在指定模式下可用的动作集合
"""
filtered_actions = {}
for action_name, action_info in self._using_actions.items():
action_mode = action_info.get("mode_enable", "all")
# 检查动作是否在当前模式下启用
if action_mode == "all" or action_mode == mode:
filtered_actions[action_name] = action_info
logger.debug(f"动作 {action_name} 在模式 {mode} 下可用 (mode_enable: {action_mode})")
else:
logger.debug(f"动作 {action_name} 在模式 {mode} 下不可用 (mode_enable: {action_mode})")
logger.debug(f"模式 {mode} 下可用动作: {list(filtered_actions.keys())}")
return filtered_actions
@@ -325,7 +322,7 @@ class ActionManager:
系统核心动作是那些enable_plugin为False但是系统必需的动作
"""
system_core_actions = ["exit_focus_chat"] # 可以根据需要扩展
for action_name in system_core_actions:
if action_name in self._registered_actions and action_name not in self._using_actions:
self._using_actions[action_name] = self._registered_actions[action_name]
@@ -334,10 +331,10 @@ class ActionManager:
def add_system_action_if_needed(self, action_name: str) -> bool:
"""
根据需要添加系统动作到使用集
Args:
action_name: 动作名称
Returns:
bool: 是否成功添加
"""

View File

@@ -30,13 +30,13 @@ class ActionModifier:
"""初始化动作处理器"""
self.action_manager = action_manager
self.all_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
# 用于LLM判定的小模型
self.llm_judge = LLMRequest(
model=global_config.model.utils_small,
request_type="action.judge",
)
# 缓存相关属性
self._llm_judge_cache = {} # 缓存LLM判定结果
self._cache_expiry_time = 30 # 缓存过期时间(秒)
@@ -49,15 +49,15 @@ class ActionModifier:
):
"""
完整的动作修改流程,整合传统观察处理和新的激活类型判定
这个方法处理完整的动作管理流程:
1. 基于观察的传统动作修改(循环历史分析、类型匹配等)
2. 基于激活类型的智能动作判定,最终确定可用动作集
处理后ActionManager 将包含最终的可用动作集,供规划器直接使用
"""
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
# === 第一阶段:传统观察处理 ===
if observations:
hfc_obs = None
@@ -86,7 +86,7 @@ class ActionModifier:
merged_action_changes["add"].extend(action_changes["add"])
merged_action_changes["remove"].extend(action_changes["remove"])
reasons.append("基于循环历史分析")
# 详细记录循环历史分析的变更原因
for action_name in action_changes["add"]:
logger.info(f"{self.log_prefix}添加动作: {action_name},原因: 循环历史分析建议添加")
@@ -106,7 +106,9 @@ class ActionModifier:
if not chat_context.check_types(data["associated_types"]):
type_mismatched_actions.append(action_name)
associated_types_str = ", ".join(data["associated_types"])
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str}")
logger.info(
f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str}"
)
if type_mismatched_actions:
# 合并到移除列表中
@@ -123,17 +125,19 @@ class ActionModifier:
self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}应用移除动作: {action_name},原因集合: {reasons}")
logger.info(f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}")
logger.info(
f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}"
)
# === 第二阶段:激活类型判定 ===
# 如果提供了聊天上下文,则进行激活类型判定
if chat_content is not None:
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
# 获取当前使用的动作集经过第一阶段处理且适用于FOCUS模式
current_using_actions = self.action_manager.get_using_actions()
all_registered_actions = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
# 构建完整的动作信息
current_actions_with_info = {}
for action_name in current_using_actions.keys():
@@ -141,17 +145,17 @@ class ActionModifier:
current_actions_with_info[action_name] = all_registered_actions[action_name]
else:
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
# 应用激活类型判定
final_activated_actions = await self._apply_activation_type_filtering(
current_actions_with_info,
chat_content,
)
# 更新ActionManager移除未激活的动作
actions_to_remove = []
removal_reasons = {}
for action_name in current_using_actions.keys():
if action_name not in final_activated_actions:
actions_to_remove.append(action_name)
@@ -159,7 +163,7 @@ class ActionModifier:
if action_name in all_registered_actions:
action_info = all_registered_actions[action_name]
activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS)
if activation_type == ActionActivationType.RANDOM:
probability = action_info.get("random_probability", 0.3)
removal_reasons[action_name] = f"RANDOM类型未触发概率{probability}"
@@ -172,15 +176,17 @@ class ActionModifier:
removal_reasons[action_name] = "激活判定未通过"
else:
removal_reasons[action_name] = "动作信息不完整"
for action_name in actions_to_remove:
self.action_manager.remove_action_from_using(action_name)
reason = removal_reasons.get(action_name, "未知原因")
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: {reason}")
logger.info(f"{self.log_prefix}激活类型判定完成,最终可用动作: {list(final_activated_actions.keys())}")
logger.info(f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}")
logger.info(
f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}"
)
async def _apply_activation_type_filtering(
self,
@@ -189,27 +195,27 @@ class ActionModifier:
) -> Dict[str, Any]:
"""
应用激活类型过滤逻辑,支持四种激活类型的并行处理
Args:
actions_with_info: 带完整信息的动作字典
            chat_content: 观察到的聊天消息内容
Returns:
Dict[str, Any]: 过滤后激活的actions字典
"""
activated_actions = {}
# 分类处理不同激活类型的actions
always_actions = {}
random_actions = {}
llm_judge_actions = {}
keyword_actions = {}
for action_name, action_info in actions_with_info.items():
activation_type = action_info.get("focus_activation_type", ActionActivationType.ALWAYS)
if activation_type == ActionActivationType.ALWAYS:
always_actions[action_name] = action_info
elif activation_type == ActionActivationType.RANDOM:
@@ -220,12 +226,12 @@ class ActionModifier:
keyword_actions[action_name] = action_info
else:
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
# 1. 处理ALWAYS类型直接激活
for action_name, action_info in always_actions.items():
activated_actions[action_name] = action_info
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活")
# 2. 处理RANDOM类型
for action_name, action_info in random_actions.items():
probability = action_info.get("random_probability", 0.3)
@@ -235,7 +241,7 @@ class ActionModifier:
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发概率{probability}")
else:
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发概率{probability}")
# 3. 处理KEYWORD类型快速判定
for action_name, action_info in keyword_actions.items():
should_activate = self._check_keyword_activation(
@@ -250,7 +256,7 @@ class ActionModifier:
else:
keywords = action_info.get("activation_keywords", [])
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词{keywords}")
# 4. 处理LLM_JUDGE类型并行判定
if llm_judge_actions:
# 直接并行处理所有LLM判定actions
@@ -258,7 +264,7 @@ class ActionModifier:
llm_judge_actions,
chat_content,
)
# 添加激活的LLM判定actions
for action_name, should_activate in llm_results.items():
if should_activate:
@@ -266,46 +272,43 @@ class ActionModifier:
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: LLM_JUDGE类型判定通过")
else:
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: LLM_JUDGE类型判定未通过")
logger.debug(f"{self.log_prefix}激活类型过滤完成: {list(activated_actions.keys())}")
return activated_actions
async def process_actions_for_planner(
self,
observed_messages_str: str = "",
chat_context: Optional[str] = None,
extra_context: Optional[str] = None
self, observed_messages_str: str = "", chat_context: Optional[str] = None, extra_context: Optional[str] = None
) -> Dict[str, Any]:
"""
[已废弃] 此方法现在已被整合到 modify_actions() 中
为了保持向后兼容性而保留,但建议直接使用 ActionManager.get_using_actions()
规划器应该直接从 ActionManager 获取最终的可用动作集,而不是调用此方法
新的架构:
1. 主循环调用 modify_actions() 处理完整的动作管理流程
2. 规划器直接使用 ActionManager.get_using_actions() 获取最终动作集
"""
logger.warning(f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()")
logger.warning(
f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()"
)
# 为了向后兼容,仍然返回当前使用的动作集
current_using_actions = self.action_manager.get_using_actions()
all_registered_actions = self.action_manager.get_registered_actions()
# 构建完整的动作信息
result = {}
for action_name in current_using_actions.keys():
if action_name in all_registered_actions:
result[action_name] = all_registered_actions[action_name]
return result
def _generate_context_hash(self, chat_content: str) -> str:
"""生成上下文的哈希值用于缓存"""
context_content = f"{chat_content}"
return hashlib.md5(context_content.encode('utf-8')).hexdigest()
return hashlib.md5(context_content.encode("utf-8")).hexdigest()
async def _process_llm_judge_actions_parallel(
self,
@@ -314,85 +317,85 @@ class ActionModifier:
) -> Dict[str, bool]:
"""
并行处理LLM判定actions支持智能缓存
Args:
llm_judge_actions: 需要LLM判定的actions
            chat_content: 观察到的聊天消息内容
Returns:
Dict[str, bool]: action名称到激活结果的映射
"""
# 生成当前上下文的哈希值
current_context_hash = self._generate_context_hash(chat_content)
current_time = time.time()
results = {}
tasks_to_run = {}
# 检查缓存
for action_name, action_info in llm_judge_actions.items():
cache_key = f"{action_name}_{current_context_hash}"
# 检查是否有有效的缓存
if (cache_key in self._llm_judge_cache and
current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time):
if (
cache_key in self._llm_judge_cache
and current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time
):
results[action_name] = self._llm_judge_cache[cache_key]["result"]
logger.debug(f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}")
logger.debug(
f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}"
)
else:
# 需要进行LLM判定
tasks_to_run[action_name] = action_info
# 如果有需要运行的任务,并行执行
if tasks_to_run:
logger.debug(f"{self.log_prefix}并行执行LLM判定任务数: {len(tasks_to_run)}")
# 创建并行任务
tasks = []
task_names = []
for action_name, action_info in tasks_to_run.items():
task = self._llm_judge_action(
action_name,
action_info,
chat_content,
action_name,
action_info,
chat_content,
)
tasks.append(task)
task_names.append(action_name)
# 并行执行所有任务
try:
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果并更新缓存
for i, (action_name, result) in enumerate(zip(task_names, task_results)):
for _, (action_name, result) in enumerate(zip(task_names, task_results)):
if isinstance(result, Exception):
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
results[action_name] = False
else:
results[action_name] = result
# 更新缓存
cache_key = f"{action_name}_{current_context_hash}"
self._llm_judge_cache[cache_key] = {
"result": result,
"timestamp": current_time
}
self._llm_judge_cache[cache_key] = {"result": result, "timestamp": current_time}
logger.debug(f"{self.log_prefix}并行LLM判定完成耗时: {time.time() - current_time:.2f}s")
except Exception as e:
logger.error(f"{self.log_prefix}并行LLM判定失败: {e}")
# 如果并行执行失败为所有任务返回False
for action_name in tasks_to_run.keys():
results[action_name] = False
# 清理过期缓存
self._cleanup_expired_cache(current_time)
return results
def _cleanup_expired_cache(self, current_time: float):
@@ -401,40 +404,39 @@ class ActionModifier:
for cache_key, cache_data in self._llm_judge_cache.items():
if current_time - cache_data["timestamp"] > self._cache_expiry_time:
expired_keys.append(cache_key)
for key in expired_keys:
del self._llm_judge_cache[key]
if expired_keys:
logger.debug(f"{self.log_prefix}清理了 {len(expired_keys)} 个过期缓存条目")
async def _llm_judge_action(
self,
action_name: str,
self,
action_name: str,
action_info: Dict[str, Any],
chat_content: str = "",
) -> bool:
"""
使用LLM判定是否应该激活某个action
Args:
action_name: 动作名称
action_info: 动作信息
            chat_content: 观察到的聊天消息内容
Returns:
bool: 是否应该激活此action
"""
try:
# 构建判定提示词
action_description = action_info.get("description", "")
action_require = action_info.get("require", [])
custom_prompt = action_info.get("llm_judge_prompt", "")
# 构建基础判定提示词
base_prompt = f"""
你需要判断在当前聊天情况下,是否应该激活名为"{action_name}"的动作。
@@ -445,34 +447,34 @@ class ActionModifier:
"""
for req in action_require:
base_prompt += f"- {req}\n"
if custom_prompt:
base_prompt += f"\n额外判定条件:\n{custom_prompt}\n"
if chat_content:
base_prompt += f"\n当前聊天记录:\n{chat_content}\n"
base_prompt += """
请根据以上信息判断是否应该激活这个动作。
只需要回答"""",不要有其他内容。
"""
# 调用LLM进行判定
response, _ = await self.llm_judge.generate_response_async(prompt=base_prompt)
# 解析响应
response = response.strip().lower()
# print(base_prompt)
print(f"LLM判定动作 {action_name}:响应='{response}'")
should_activate = "" in response or "yes" in response or "true" in response
logger.debug(f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}")
logger.debug(
f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}"
)
return should_activate
except Exception as e:
logger.error(f"{self.log_prefix}LLM判定动作 {action_name} 时出错: {e}")
# 出错时默认不激活
@@ -486,45 +488,45 @@ class ActionModifier:
) -> bool:
"""
检查是否匹配关键词触发条件
Args:
action_name: 动作名称
action_info: 动作信息
            chat_content: 观察到的聊天消息内容
Returns:
bool: 是否应该激活此action
"""
activation_keywords = action_info.get("activation_keywords", [])
case_sensitive = action_info.get("keyword_case_sensitive", False)
if not activation_keywords:
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
return False
# 构建检索文本
search_text = ""
if chat_content:
search_text += chat_content
# if chat_context:
# search_text += f" {chat_context}"
# search_text += f" {chat_context}"
# if extra_context:
# search_text += f" {extra_context}"
# search_text += f" {extra_context}"
# 如果不区分大小写,转换为小写
if not case_sensitive:
search_text = search_text.lower()
# 检查每个关键词
matched_keywords = []
for keyword in activation_keywords:
check_keyword = keyword if case_sensitive else keyword.lower()
if check_keyword in search_text:
matched_keywords.append(keyword)
if matched_keywords:
logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
return True
@@ -568,7 +570,9 @@ class ActionModifier:
result["remove"].append("no_reply")
result["remove"].append("reply")
no_reply_ratio = no_reply_count / len(recent_cycles)
logger.info(f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f}达到退出聊天阈值将添加exit_focus_chat并移除no_reply/reply动作")
logger.info(
f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f}达到退出聊天阈值将添加exit_focus_chat并移除no_reply/reply动作"
)
# 计算连续回复的相关阈值
@@ -593,7 +597,7 @@ class ActionModifier:
if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num):
# 如果最近max_reply_num次都是reply直接移除
result["remove"].append("reply")
reply_count = len(last_max_reply_num) - no_reply_count
# reply_count = len(last_max_reply_num) - no_reply_count
logger.info(
f"{self.log_prefix}移除reply动作原因: 连续回复过多(最近{len(last_max_reply_num)}次全是reply超过阈值{max_reply_num}"
)
@@ -622,8 +626,6 @@ class ActionModifier:
f"{self.log_prefix}连续回复检测:最近{one_thres_reply_num}次全是reply{removal_probability:.2f}概率移除,未触发"
)
else:
logger.debug(
f"{self.log_prefix}连续回复检测无需移除reply动作最近回复模式正常"
)
logger.debug(f"{self.log_prefix}连续回复检测无需移除reply动作最近回复模式正常")
return result
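
The focus-mode activation flow implemented above dispatches on four activation types and caches LLM judgements under an md5 hash of the chat content. A condensed, self-contained sketch of that dispatch and cache, assuming a synchronous stand-in for the LLM call (judge_fn) and simplified action dicts:

import hashlib
import random
import time

def context_hash(chat_content: str) -> str:
    # Same idea as _generate_context_hash: key the judge cache by the chat content.
    return hashlib.md5(chat_content.encode("utf-8")).hexdigest()

_judge_cache: dict[str, dict] = {}
CACHE_TTL = 30  # seconds, mirrors _cache_expiry_time

def judge_with_cache(action_name: str, chat_content: str, judge_fn) -> bool:
    key = f"{action_name}_{context_hash(chat_content)}"
    entry = _judge_cache.get(key)
    if entry and time.time() - entry["timestamp"] < CACHE_TTL:
        return entry["result"]
    result = judge_fn(chat_content)  # stand-in for the small-model yes/no call
    _judge_cache[key] = {"result": result, "timestamp": time.time()}
    return result

def filter_actions(actions: dict, chat_content: str) -> dict:
    activated = {}
    for name, info in actions.items():
        kind = info.get("focus_activation_type", "always")
        if kind == "always":
            activated[name] = info
        elif kind == "random" and random.random() < info.get("random_probability", 0.3):
            activated[name] = info
        elif kind == "keyword":
            text = chat_content.lower()
            if any(k.lower() in text for k in info.get("activation_keywords", [])):
                activated[name] = info
        elif kind == "llm_judge" and judge_with_cache(name, chat_content, info["judge_fn"]):
            activated[name] = info
    return activated

actions = {
    "reply": {"focus_activation_type": "always"},
    "emoji": {"focus_activation_type": "random", "random_probability": 0.5},
    "search": {"focus_activation_type": "keyword", "activation_keywords": ["搜索"]},
    "deep_think": {"focus_activation_type": "llm_judge", "judge_fn": lambda _: True},
}
print(list(filter_actions(actions, "帮我搜索一下这个报错").keys()))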

View File

@@ -146,7 +146,7 @@ class ActionPlanner(BasePlanner):
# 注意动作的激活判定现在在主循环的modify_actions中完成
# 使用Focus模式过滤动作
current_available_actions_dict = self.action_manager.get_using_actions_for_mode(ChatMode.FOCUS)
# 获取完整的动作信息
all_registered_actions = self.action_manager.get_registered_actions()
current_available_actions = {}
@@ -192,12 +192,11 @@ class ActionPlanner(BasePlanner):
try:
prompt = f"{prompt}"
llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
reasoning = f"LLM 请求失败,你的模型出现问题: {req_e}"
@@ -237,10 +236,10 @@ class ActionPlanner(BasePlanner):
extra_info_block = ""
action_data["extra_info_block"] = extra_info_block
if relation_info:
action_data["relation_info_block"] = relation_info
# 对于reply动作不需要额外处理因为相关字段已经在上面的循环中添加到action_data
if extracted_action not in current_available_actions:
@@ -303,12 +302,11 @@ class ActionPlanner(BasePlanner):
) -> str:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
if relation_info_block:
relation_info_block = f"以下是你和别人的关系描述:\n{relation_info_block}"
else:
relation_info_block = ""
memory_str = ""
if running_memorys:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
@@ -331,9 +329,9 @@ class ActionPlanner(BasePlanner):
# mind_info_block = ""
# if current_mind:
# mind_info_block = f"对聊天的规划:{current_mind}"
# mind_info_block = f"对聊天的规划:{current_mind}"
# else:
# mind_info_block = "你刚参与聊天"
# mind_info_block = "你刚参与聊天"
personality_block = individuality.get_prompt(x_person=2, level=2)
@@ -351,16 +349,14 @@ class ActionPlanner(BasePlanner):
param_text = "\n"
for param_name, param_description in using_actions_info["parameters"].items():
param_text += f' "{param_name}":"{param_description}"\n'
param_text = param_text.rstrip('\n')
param_text = param_text.rstrip("\n")
else:
param_text = ""
require_text = ""
for require_item in using_actions_info["require"]:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip('\n')
require_text = require_text.rstrip("\n")
using_action_prompt = using_action_prompt.format(
action_name=using_actions_name,

View File

@@ -93,7 +93,7 @@ class DefaultReplyer:
self.chat_id = chat_stream.stream_id
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
async def _create_thinking_message(self, anchor_message: Optional[MessageRecv], thinking_id: str):
"""创建思考消息 (尝试锚定到 anchor_message)"""
@@ -141,7 +141,7 @@ class DefaultReplyer:
# text_part = action_data.get("text", [])
# if text_part:
sent_msg_list = []
with Timer("生成回复", cycle_timers):
# 可以保留原有的文本处理逻辑或进行适当调整
reply = await self.reply(
@@ -240,22 +240,21 @@ class DefaultReplyer:
# current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
# self.express_model.params["temperature"] = current_temp # 动态调整温度
reply_to = action_data.get("reply_to", "none")
sender = ""
targer = ""
if ":" in reply_to or "" in reply_to:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r'[:]', string=reply_to, maxsplit=1)
parts = re.split(pattern=r"[:]", string=reply_to, maxsplit=1)
if len(parts) == 2:
sender = parts[0].strip()
targer = parts[1].strip()
identity = action_data.get("identity", "")
extra_info_block = action_data.get("extra_info_block", "")
relation_info_block = action_data.get("relation_info_block", "")
# 3. 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_focus(
@@ -374,8 +373,6 @@ class DefaultReplyer:
style_habbits_str = "\n".join(style_habbits)
grammar_habbits_str = "\n".join(grammar_habbits)
# 关键词检测与反应
keywords_reaction_prompt = ""
@@ -407,16 +404,15 @@ class DefaultReplyer:
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# logger.debug("开始构建 focus prompt")
if sender_name:
reply_target_block = f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
reply_target_block = (
f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif target_message:
reply_target_block = f"现在{target_message}引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
reply_target_block = "现在,你想要在群里发言或者回复消息。"
# --- Choose template based on chat type ---
if is_group_chat:
@@ -665,30 +661,30 @@ def find_similar_expressions(input_text: str, expressions: List[Dict], top_k: in
"""使用TF-IDF和余弦相似度找出与输入文本最相似的top_k个表达方式"""
if not expressions:
return []
# 准备文本数据
texts = [expr['situation'] for expr in expressions]
texts = [expr["situation"] for expr in expressions]
texts.append(input_text) # 添加输入文本
# 使用TF-IDF向量化
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(texts)
# 计算余弦相似度
similarity_matrix = cosine_similarity(tfidf_matrix)
# 获取输入文本的相似度分数(最后一行)
scores = similarity_matrix[-1][:-1] # 排除与自身的相似度
# 获取top_k的索引
top_indices = np.argsort(scores)[::-1][:top_k]
# 获取相似表达
similar_exprs = []
for idx in top_indices:
if scores[idx] > 0: # 只保留有相似度的
similar_exprs.append(expressions[idx])
return similar_exprs
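
find_similar_expressions above ranks stored expressions against the input using TF-IDF vectors and cosine similarity. A stripped-down sketch of the same ranking, assuming scikit-learn and numpy are installed; the candidate strings are made up for illustration:

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def find_similar(input_text: str, candidates: list[str], top_k: int = 2) -> list[str]:
    if not candidates:
        return []
    texts = candidates + [input_text]                 # input text goes last
    tfidf = TfidfVectorizer().fit_transform(texts)
    scores = cosine_similarity(tfidf)[-1][:-1]        # input vs each candidate, self excluded
    order = np.argsort(scores)[::-1][:top_k]          # best matches first
    return [candidates[i] for i in order if scores[i] > 0]

print(find_similar(
    "someone asks for help with an error message",
    ["a user reports a bug", "casual small talk about food", "someone shares an error screenshot"],
))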

View File

@@ -62,13 +62,12 @@ class ChattingObservation(Observation):
self.oldest_messages = []
self.oldest_messages_str = ""
self.compressor_prompt = ""
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
self.talking_message = initial_messages
self.talking_message_str = build_readable_messages(self.talking_message, show_actions=True)
def to_dict(self) -> dict:
"""将观察对象转换为可序列化的字典"""
return {
@@ -283,7 +282,7 @@ class ChattingObservation(Observation):
show_actions=True,
)
# print(f"构建中self.talking_message_str_truncate: {self.talking_message_str_truncate}")
self.person_list = await get_person_id_list(self.talking_message)
# print(f"构建中self.person_list: {self.person_list}")

View File

@@ -42,9 +42,7 @@ class SubHeartflow:
self.history_chat_state: List[Tuple[ChatState, float]] = []
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
self.log_prefix = (
chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id
)
self.log_prefix = chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id
# 兴趣消息集合
self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {}
@@ -199,7 +197,6 @@ class SubHeartflow:
# 如果实例不存在,则创建并启动
logger.info(f"{log_prefix} 麦麦准备开始专注聊天...")
try:
self.heart_fc_instance = HeartFChatting(
chat_id=self.subheartflow_id,
# observations=self.observations,

View File

@@ -23,7 +23,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
chat_target_info = None
try:
chat_stream = chat_manager.get_stream(chat_id)
chat_stream = chat_manager.get_stream(chat_id)
if chat_stream:
if chat_stream.group_info:

View File

@@ -3,7 +3,7 @@ import os
from .global_logger import logger
from .lpmmconfig import global_config
from src.chat.knowledge.utils import get_sha256
from src.chat.knowledge.utils.hash import get_sha256
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:

View File

@@ -346,7 +346,9 @@ class Hippocampus:
# 使用LLM提取关键词
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
# logger.info(f"提取关键词数量: {topic_num}")
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
self.find_topic_llm(text, topic_num)
)
# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)
@@ -701,7 +703,9 @@ class Hippocampus:
# 使用LLM提取关键词
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
# logger.info(f"提取关键词数量: {topic_num}")
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
self.find_topic_llm(text, topic_num)
)
# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)
@@ -893,7 +897,7 @@ class EntorhinalCortex:
# 获取数据库中所有节点和内存中所有节点
db_nodes = {node.concept: node for node in GraphNodes.select()}
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 批量准备节点数据
nodes_to_create = []
nodes_to_update = []
@@ -929,22 +933,26 @@ class EntorhinalCortex:
continue
if concept not in db_nodes:
nodes_to_create.append({
"concept": concept,
"memory_items": memory_items_json,
"hash": memory_hash,
"created_time": created_time,
"last_modified": last_modified,
})
else:
db_node = db_nodes[concept]
if db_node.hash != memory_hash:
nodes_to_update.append({
nodes_to_create.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": memory_hash,
"created_time": created_time,
"last_modified": last_modified,
})
}
)
else:
db_node = db_nodes[concept]
if db_node.hash != memory_hash:
nodes_to_update.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": memory_hash,
"last_modified": last_modified,
}
)
# 计算需要删除的节点
memory_concepts = {concept for concept, _ in memory_nodes}
@@ -954,13 +962,13 @@ class EntorhinalCortex:
if nodes_to_create:
batch_size = 100
for i in range(0, len(nodes_to_create), batch_size):
batch = nodes_to_create[i:i + batch_size]
batch = nodes_to_create[i : i + batch_size]
GraphNodes.insert_many(batch).execute()
if nodes_to_update:
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i:i + batch_size]
batch = nodes_to_update[i : i + batch_size]
for node_data in batch:
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
GraphNodes.concept == node_data["concept"]
@@ -992,22 +1000,26 @@ class EntorhinalCortex:
last_modified = data.get("last_modified", current_time)
if edge_key not in db_edge_dict:
edges_to_create.append({
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"created_time": created_time,
"last_modified": last_modified,
})
edges_to_create.append(
{
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"created_time": created_time,
"last_modified": last_modified,
}
)
elif db_edge_dict[edge_key]["hash"] != edge_hash:
edges_to_update.append({
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"last_modified": last_modified,
})
edges_to_update.append(
{
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"last_modified": last_modified,
}
)
# 计算需要删除的边
memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
@@ -1017,13 +1029,13 @@ class EntorhinalCortex:
if edges_to_create:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i:i + batch_size]
batch = edges_to_create[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i:i + batch_size]
batch = edges_to_update[i : i + batch_size]
for edge_data in batch:
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
@@ -1031,9 +1043,7 @@ class EntorhinalCortex:
if edges_to_delete:
for source, target in edges_to_delete:
GraphEdges.delete().where(
(GraphEdges.source == source) & (GraphEdges.target == target)
).execute()
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
end_time = time.time()
logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}")
@@ -1069,13 +1079,15 @@ class EntorhinalCortex:
if not memory_items_json:
continue
nodes_data.append({
"concept": concept,
"memory_items": memory_items_json,
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
"created_time": data.get("created_time", current_time),
"last_modified": data.get("last_modified", current_time),
})
nodes_data.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
"created_time": data.get("created_time", current_time),
"last_modified": data.get("last_modified", current_time),
}
)
except Exception as e:
logger.error(f"准备节点 {concept} 数据时发生错误: {e}")
continue
@@ -1084,14 +1096,16 @@ class EntorhinalCortex:
edges_data = []
for source, target, data in memory_edges:
try:
edges_data.append({
"source": source,
"target": target,
"strength": data.get("strength", 1),
"hash": self.hippocampus.calculate_edge_hash(source, target),
"created_time": data.get("created_time", current_time),
"last_modified": data.get("last_modified", current_time),
})
edges_data.append(
{
"source": source,
"target": target,
"strength": data.get("strength", 1),
"hash": self.hippocampus.calculate_edge_hash(source, target),
"created_time": data.get("created_time", current_time),
"last_modified": data.get("last_modified", current_time),
}
)
except Exception as e:
logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
continue
@@ -1102,7 +1116,7 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
with GraphNodes._meta.database.atomic():
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i:i + batch_size]
batch = nodes_data[i : i + batch_size]
GraphNodes.insert_many(batch).execute()
node_end = time.time()
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}")
@@ -1113,7 +1127,7 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
with GraphEdges._meta.database.atomic():
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i:i + batch_size]
batch = edges_data[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}")

View File

@@ -79,7 +79,7 @@ class ChatBot:
group_info = message.message_info.group_info
user_info = message.message_info.user_info
chat_manager.register_message(message)
# 创建聊天流
chat = await chat_manager.get_or_create_stream(
platform=message.message_info.platform,
@@ -87,13 +87,13 @@ class ChatBot:
group_info=group_info,
)
message.update_chat_stream(chat)
# 处理消息内容,生成纯文本
await message.process()
# 命令处理 - 在消息处理的早期阶段检查并处理命令
is_command, cmd_result, continue_process = await command_manager.process_command(message)
# 如果是命令且不需要继续处理,则直接返回
if is_command and not continue_process:
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")

View File

@@ -24,7 +24,6 @@ class MessageStorage:
else:
filtered_processed_plain_text = ""
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:

View File

@@ -13,8 +13,6 @@ from src.chat.utils.prompt_builder import global_prompt_manager
from .normal_chat_generator import NormalChatGenerator
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
from src.chat.message_receive.message_sender import message_manager
from src.chat.utils.utils_image import image_path_to_base64
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.normal_chat.willing.willing_manager import willing_manager
from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats
from src.config.config import global_config
@@ -69,7 +67,7 @@ class NormalChat:
self.on_switch_to_focus_callback = on_switch_to_focus_callback
self._disabled = False # 增加停用标志
logger.debug(f"[{self.stream_name}] NormalChat 初始化完成 (异步部分)。")
# 改为实例方法
@@ -193,7 +191,9 @@ class NormalChat:
return
timing_results = {}
reply_probability = 1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0 # 如果被提及且开启了提及必回复则基础概率为1否则需要意愿判断
reply_probability = (
1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0
) # 如果被提及且开启了提及必回复则基础概率为1否则需要意愿判断
# 意愿管理器设置当前message信息
willing_manager.setup(message, self.chat_stream, is_mentioned, interested_rate)
@@ -267,13 +267,17 @@ class NormalChat:
try:
# 获取发送者名称(动作修改已在并行执行前完成)
sender_name = self._get_sender_name(message)
no_action = {
"action_result": {"action_type": "no_action", "action_data": {}, "reasoning": "规划器初始化默认", "is_parallel": True},
"action_result": {
"action_type": "no_action",
"action_data": {},
"reasoning": "规划器初始化默认",
"is_parallel": True,
},
"chat_context": "",
"action_prompt": "",
}
# 检查是否应该跳过规划
if self.action_modifier.should_skip_planning():
@@ -288,7 +292,9 @@ class NormalChat:
reasoning = plan_result["action_result"]["reasoning"]
is_parallel = plan_result["action_result"].get("is_parallel", False)
logger.info(f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}")
logger.info(
f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}"
)
self.action_type = action_type # 更新实例属性
self.is_parallel_action = is_parallel # 新增:保存并行执行标志
@@ -307,7 +313,12 @@ class NormalChat:
else:
logger.warning(f"[{self.stream_name}] 额外动作 {action_type} 执行失败")
return {"action_type": action_type, "action_data": action_data, "reasoning": reasoning, "is_parallel": is_parallel}
return {
"action_type": action_type,
"action_data": action_data,
"reasoning": reasoning,
"is_parallel": is_parallel,
}
except Exception as e:
logger.error(f"[{self.stream_name}] Planner执行失败: {e}")
@@ -331,13 +342,19 @@ class NormalChat:
logger.error(f"[{self.stream_name}] 动作规划异常: {plan_result}")
elif plan_result:
logger.debug(f"[{self.stream_name}] 额外动作处理完成: {self.action_type}")
if not response_set or (
self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action
self.enable_planner
and self.action_type not in ["no_action", "change_to_focus_chat"]
and not self.is_parallel_action
):
if not response_set:
logger.info(f"[{self.stream_name}] 模型未生成回复内容")
elif self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action:
elif (
self.enable_planner
and self.action_type not in ["no_action", "change_to_focus_chat"]
and not self.is_parallel_action
):
logger.info(f"[{self.stream_name}] 模型选择其他动作(非并行动作)")
# 如果模型未生成回复,移除思考消息
container = await message_manager.get_container(self.stream_id) # 使用 self.stream_id
@@ -364,7 +381,6 @@ class NormalChat:
# 检查 first_bot_msg 是否为 None (例如思考消息已被移除的情况)
if first_bot_msg:
# 记录回复信息到最近回复列表中
reply_info = {
"time": time.time(),
@@ -396,7 +412,6 @@ class NormalChat:
await self._check_switch_to_focus()
pass
# with Timer("关系更新", timing_results):
# await self._update_relationship(message, response_set)
@@ -605,7 +620,7 @@ class NormalChat:
# 执行动作
result = await action_handler.handle_action()
success = False
if result and isinstance(result, tuple) and len(result) >= 2:
# handle_action返回 (success: bool, message: str)
success = result[0]

View File

@@ -35,7 +35,7 @@ class NormalChatActionModifier:
**kwargs: Any,
):
"""为Normal Chat修改可用动作集合
实现动作激活策略:
1. 基于关联类型的动态过滤
2. 基于激活类型的智能判定LLM_JUDGE转为概率激活
@@ -49,7 +49,7 @@ class NormalChatActionModifier:
reasons = []
merged_action_changes = {"add": [], "remove": []}
type_mismatched_actions = [] # 在外层定义避免作用域问题
self.action_manager.restore_default_actions()
# 第一阶段:基于关联类型的动态过滤
@@ -74,7 +74,7 @@ class NormalChatActionModifier:
# 第二阶段:应用激活类型判定
# 构建聊天内容 - 使用与planner一致的方式
chat_content = ""
if chat_stream and hasattr(chat_stream, 'stream_id'):
if chat_stream and hasattr(chat_stream, "stream_id"):
try:
# 获取消息历史使用与normal_chat_planner相同的方法
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@@ -82,7 +82,7 @@ class NormalChatActionModifier:
timestamp=time.time(),
limit=global_config.focus_chat.observation_context_size, # 使用相同的配置
)
# 构建可读的聊天上下文
chat_content = build_readable_messages(
message_list_before_now,
@@ -92,39 +92,41 @@ class NormalChatActionModifier:
read_mark=0.0,
show_actions=True,
)
logger.debug(f"{self.log_prefix} 成功构建聊天内容,长度: {len(chat_content)}")
except Exception as e:
logger.warning(f"{self.log_prefix} 构建聊天内容失败: {e}")
chat_content = ""
# 获取当前Normal模式下的动作集进行激活判定
current_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
# print(f"current_actions: {current_actions}")
# print(f"chat_content: {chat_content}")
final_activated_actions = await self._apply_normal_activation_filtering(
current_actions,
chat_content,
message_content
current_actions, chat_content, message_content
)
# print(f"final_activated_actions: {final_activated_actions}")
# 统一处理所有需要移除的动作,避免重复移除
all_actions_to_remove = set() # 使用set避免重复
# 添加关联类型不匹配的动作
if type_mismatched_actions:
all_actions_to_remove.update(type_mismatched_actions)
# 添加激活类型判定未通过的动作
for action_name in current_actions.keys():
if action_name not in final_activated_actions:
all_actions_to_remove.add(action_name)
# 统计移除原因(避免重复)
activation_failed_actions = [name for name in current_actions.keys() if name not in final_activated_actions and name not in type_mismatched_actions]
activation_failed_actions = [
name
for name in current_actions.keys()
if name not in final_activated_actions and name not in type_mismatched_actions
]
if activation_failed_actions:
reasons.append(f"移除{activation_failed_actions}(激活类型判定未通过)")
@@ -146,7 +148,7 @@ class NormalChatActionModifier:
# 记录变更原因
if reasons:
logger.info(f"{self.log_prefix} 动作调整完成: {' | '.join(reasons)}")
# 获取最终的Normal模式可用动作并记录
final_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
logger.debug(f"{self.log_prefix} 当前Normal模式可用动作: {list(final_actions.keys())}")
@@ -159,31 +161,31 @@ class NormalChatActionModifier:
) -> Dict[str, Any]:
"""
应用Normal模式的激活类型过滤逻辑
与Focus模式的区别
1. LLM_JUDGE类型转换为概率激活避免LLM调用
2. RANDOM类型保持概率激活
3. KEYWORD类型保持关键词匹配
4. ALWAYS类型直接激活
Args:
actions_with_info: 带完整信息的动作字典
chat_content: 聊天内容
Returns:
Dict[str, Any]: 过滤后激活的actions字典
"""
activated_actions = {}
# 分类处理不同激活类型的actions
always_actions = {}
random_actions = {}
keyword_actions = {}
for action_name, action_info in actions_with_info.items():
# 使用normal_activation_type
activation_type = action_info.get("normal_activation_type", ActionActivationType.ALWAYS)
if activation_type == ActionActivationType.ALWAYS:
always_actions[action_name] = action_info
elif activation_type == ActionActivationType.RANDOM or activation_type == ActionActivationType.LLM_JUDGE:
@@ -192,12 +194,12 @@ class NormalChatActionModifier:
keyword_actions[action_name] = action_info
else:
logger.warning(f"{self.log_prefix}未知的激活类型: {activation_type},跳过处理")
# 1. 处理ALWAYS类型直接激活
for action_name, action_info in always_actions.items():
activated_actions[action_name] = action_info
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: ALWAYS类型直接激活")
# 2. 处理RANDOM类型概率激活
for action_name, action_info in random_actions.items():
probability = action_info.get("random_probability", 0.3)
@@ -207,15 +209,10 @@ class NormalChatActionModifier:
logger.debug(f"{self.log_prefix}激活动作: {action_name},原因: RANDOM类型触发概率{probability}")
else:
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: RANDOM类型未触发概率{probability}")
# 3. 处理KEYWORD类型关键词匹配
for action_name, action_info in keyword_actions.items():
should_activate = self._check_keyword_activation(
action_name,
action_info,
chat_content,
message_content
)
should_activate = self._check_keyword_activation(action_name, action_info, chat_content, message_content)
if should_activate:
activated_actions[action_name] = action_info
keywords = action_info.get("activation_keywords", [])
@@ -225,7 +222,7 @@ class NormalChatActionModifier:
logger.debug(f"{self.log_prefix}未激活动作: {action_name},原因: KEYWORD类型未匹配关键词{keywords}")
# print(f"keywords: {keywords}")
# print(f"chat_content: {chat_content}")
logger.debug(f"{self.log_prefix}Normal模式激活类型过滤完成: {list(activated_actions.keys())}")
return activated_actions
@@ -238,41 +235,40 @@ class NormalChatActionModifier:
) -> bool:
"""
检查是否匹配关键词触发条件
Args:
action_name: 动作名称
action_info: 动作信息
            chat_content: 聊天内容(已经是格式化后的可读消息)
            message_content: 当前这条消息的文本内容
Returns:
bool: 是否应该激活此action
"""
activation_keywords = action_info.get("activation_keywords", [])
case_sensitive = action_info.get("keyword_case_sensitive", False)
if not activation_keywords:
logger.warning(f"{self.log_prefix}动作 {action_name} 设置为关键词触发但未配置关键词")
return False
# 使用构建好的聊天内容作为检索文本
search_text = chat_content +message_content
search_text = chat_content + message_content
# 如果不区分大小写,转换为小写
if not case_sensitive:
search_text = search_text.lower()
# 检查每个关键词
matched_keywords = []
for keyword in activation_keywords:
check_keyword = keyword if case_sensitive else keyword.lower()
if check_keyword in search_text:
matched_keywords.append(keyword)
# print(f"search_text: {search_text}")
# print(f"activation_keywords: {activation_keywords}")
if matched_keywords:
logger.debug(f"{self.log_prefix}动作 {action_name} 匹配到关键词: {matched_keywords}")
return True
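
In Normal mode the modifier above deliberately avoids extra LLM calls: LLM_JUDGE actions are treated like RANDOM ones, so only ALWAYS, probability, and keyword paths remain. A condensed sketch of that decision table; field names mirror the action info dict, values are illustrative.

import random

def normal_mode_activate(action_info: dict, search_text: str) -> bool:
    kind = action_info.get("normal_activation_type", "always")
    if kind == "always":
        return True
    if kind in ("random", "llm_judge"):  # llm_judge degrades to a probability roll in Normal mode
        return random.random() < action_info.get("random_probability", 0.3)
    if kind == "keyword":
        keywords = action_info.get("activation_keywords", [])
        if not action_info.get("keyword_case_sensitive", False):
            search_text = search_text.lower()
            keywords = [k.lower() for k in keywords]
        return any(k in search_text for k in keywords)
    return False

print(normal_mode_activate(
    {"normal_activation_type": "keyword", "activation_keywords": ["表情包"]},
    "来发个表情包吧",
))  # True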

View File

@@ -9,7 +9,7 @@ import time
from typing import List, Optional, Tuple, Dict, Any
from src.chat.message_receive.message import MessageRecv, MessageSending, MessageThinking, Seg
from src.chat.message_receive.message import UserInfo
from src.chat.message_receive.chat_stream import ChatStream,chat_manager
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
from src.chat.message_receive.message_sender import message_manager
from src.config.config import global_config
from src.common.logger_manager import get_logger
@@ -37,7 +37,7 @@ class NormalChatExpressor:
self.chat_stream = chat_stream
self.stream_name = chat_manager.get_stream_name(self.chat_stream.stream_id) or self.chat_stream.stream_id
self.log_prefix = f"[{self.stream_name}]Normal表达器"
logger.debug(f"{self.log_prefix} 初始化完成")
async def create_thinking_message(

View File

@@ -25,9 +25,7 @@ class NormalChatGenerator:
request_type="normal.chat_2",
)
self.model_sum = LLMRequest(
model=global_config.model.memory_summary, temperature=0.7, request_type="relation"
)
self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation")
self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model"
@@ -68,7 +66,6 @@ class NormalChatGenerator:
enable_planner: bool = False,
available_actions=None,
):
person_id = person_info_manager.get_person_id(
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
)
@@ -103,7 +100,6 @@ class NormalChatGenerator:
logger.info(f"{message.processed_plain_text} 的回复:{content}")
except Exception:
logger.exception("生成回复时出错")
return None

View File

@@ -101,7 +101,7 @@ class NormalChatPlanner:
# 获取当前可用的动作使用Normal模式过滤
current_available_actions = self.action_manager.get_using_actions_for_mode(ChatMode.NORMAL)
# 注意:动作的激活判定现在在 normal_chat_action_modifier 中完成
# 这里直接使用经过 action_modifier 处理后的最终动作集
# 符合职责分离原则ActionModifier负责动作管理Planner专注于决策
@@ -110,7 +110,12 @@ class NormalChatPlanner:
if not current_available_actions:
logger.debug(f"{self.log_prefix}规划器: 没有可用动作返回no_action")
return {
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": True},
"action_result": {
"action_type": action,
"action_data": action_data,
"reasoning": reasoning,
"is_parallel": True,
},
"chat_context": "",
"action_prompt": "",
}
@@ -121,7 +126,7 @@ class NormalChatPlanner:
timestamp=time.time(),
limit=global_config.focus_chat.observation_context_size,
)
chat_context = build_readable_messages(
message_list_before_now,
replace_bot_name=True,
@@ -130,7 +135,7 @@ class NormalChatPlanner:
read_mark=0.0,
show_actions=True,
)
# 构建planner的prompt
prompt = await self.build_planner_prompt(
self_info_block=self_info,
@@ -141,7 +146,12 @@ class NormalChatPlanner:
if not prompt:
logger.warning(f"{self.log_prefix}规划器: 构建提示词失败")
return {
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": False},
"action_result": {
"action_type": action,
"action_data": action_data,
"reasoning": reasoning,
"is_parallel": False,
},
"chat_context": chat_context,
"action_prompt": "",
}
@@ -149,7 +159,7 @@ class NormalChatPlanner:
# 使用LLM生成动作决策
try:
content, (reasoning_content, model_name) = await self.planner_llm.generate_response_async(prompt)
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.info(f"{self.log_prefix}规划器原始响应: {content}")
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")
@@ -201,8 +211,10 @@ class NormalChatPlanner:
if action in current_available_actions:
action_info = current_available_actions[action]
is_parallel = action_info.get("parallel_action", False)
logger.debug(f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}")
logger.debug(
f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}"
)
# 恢复到默认动作集
self.action_manager.restore_actions()
@@ -216,15 +228,15 @@ class NormalChatPlanner:
"action_data": action_data,
"reasoning": reasoning,
"timestamp": time.time(),
"model_name": model_name if 'model_name' in locals() else None
"model_name": model_name if "model_name" in locals() else None,
}
action_result = {
"action_type": action,
"action_data": action_data,
"action_type": action,
"action_data": action_data,
"reasoning": reasoning,
"is_parallel": is_parallel,
"action_record": json.dumps(action_record, ensure_ascii=False)
"action_record": json.dumps(action_record, ensure_ascii=False),
}
plan_result = {
@@ -248,24 +260,19 @@ class NormalChatPlanner:
# 添加特殊的change_to_focus_chat动作
action_options_text += "动作change_to_focus_chat\n"
action_options_text += (
"该动作的描述当聊天变得热烈、自己回复条数很多或需要深入交流时使用正常回复消息并切换到focus_chat模式\n"
)
action_options_text += "该动作的描述当聊天变得热烈、自己回复条数很多或需要深入交流时使用正常回复消息并切换到focus_chat模式\n"
action_options_text += "使用该动作的场景:\n"
action_options_text += "- 聊天上下文中自己的回复条数较多超过3-4条\n"
action_options_text += "- 对话进行得非常热烈活跃\n"
action_options_text += "- 用户表现出深入交流的意图\n"
action_options_text += "- 话题需要更专注和深入的讨论\n\n"
action_options_text += "输出要求:\n"
action_options_text += "{{"
action_options_text += " \"action\": \"change_to_focus_chat\""
action_options_text += ' "action": "change_to_focus_chat"'
action_options_text += "}}\n\n"
for action_name, action_info in current_available_actions.items():
action_description = action_info.get("description", "")
action_parameters = action_info.get("parameters", {})
@@ -276,15 +283,14 @@ class NormalChatPlanner:
print(action_parameters)
for param_name, param_description in action_parameters.items():
param_text += f' "{param_name}":"{param_description}"\n'
param_text = param_text.rstrip('\n')
param_text = param_text.rstrip("\n")
else:
param_text = ""
require_text = ""
for require_item in action_require:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip('\n')
require_text = require_text.rstrip("\n")
# 构建单个动作的提示
action_prompt = await global_prompt_manager.format_prompt(
@@ -316,6 +322,4 @@ class NormalChatPlanner:
return ""
init_prompt()
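
The planner prompt above is assembled by concatenating, for every available action, its description, parameter hints, and require items, then trimming trailing newlines. A stripped-down sketch of that assembly step; the wording is illustrative and not the project's real prompt template.

def build_action_options(actions: dict) -> str:
    blocks = []
    for name, info in actions.items():
        lines = [f"动作:{name}", f"该动作的描述:{info.get('description', '')}"]
        for param, desc in info.get("parameters", {}).items():
            lines.append(f'    "{param}":"{desc}"')
        for req in info.get("require", []):
            lines.append(f"- {req}")
        blocks.append("\n".join(lines))
    return "\n\n".join(blocks)

print(build_action_options({
    "reply": {
        "description": "回复消息",
        "parameters": {"reply_to": "你要回复的对象"},
        "require": ["有人提到你"],
    },
}))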

View File

@@ -214,7 +214,6 @@ class PromptBuilder:
except Exception as e:
logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True)
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
# 构建action描述 (如果启用planner)

View File

@@ -42,9 +42,7 @@ class ClassicalWillingManager(BaseWillingManager):
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(
max((current_willing - 0.5), 0.01) * 2, 1
)
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
# 检查群组权限(如果是群聊)
if (

View File

@@ -286,7 +286,7 @@ def _build_readable_messages_internal(
message_details_with_flags.append((timestamp, name, content, is_action))
# print(f"content:{content}")
# print(f"is_action:{is_action}")
# print(f"message_details_with_flags:{message_details_with_flags}")
# 应用截断逻辑 (如果 truncate 为 True)
@@ -324,7 +324,7 @@ def _build_readable_messages_internal(
else:
# 如果不截断,直接使用原始列表
message_details = message_details_with_flags
# print(f"message_details:{message_details}")
# 3: 合并连续消息 (如果 merge_messages 为 True)
@@ -336,12 +336,12 @@ def _build_readable_messages_internal(
"start_time": message_details[0][0],
"end_time": message_details[0][0],
"content": [message_details[0][2]],
"is_action": message_details[0][3]
"is_action": message_details[0][3],
}
for i in range(1, len(message_details)):
timestamp, name, content, is_action = message_details[i]
# 对于动作记录,不进行合并
if is_action or current_merge["is_action"]:
# 保存当前的合并块
@@ -352,7 +352,7 @@ def _build_readable_messages_internal(
"start_time": timestamp,
"end_time": timestamp,
"content": [content],
"is_action": is_action
"is_action": is_action,
}
continue
@@ -365,11 +365,11 @@ def _build_readable_messages_internal(
merged_messages.append(current_merge)
# 开始新的合并块
current_merge = {
"name": name,
"start_time": timestamp,
"end_time": timestamp,
"name": name,
"start_time": timestamp,
"end_time": timestamp,
"content": [content],
"is_action": is_action
"is_action": is_action,
}
# 添加最后一个合并块
merged_messages.append(current_merge)
@@ -381,10 +381,9 @@ def _build_readable_messages_internal(
"start_time": timestamp, # 起始和结束时间相同
"end_time": timestamp,
"content": [content], # 内容只有一个元素
"is_action": is_action
"is_action": is_action,
}
)
# 4 & 5: 格式化为字符串
output_lines = []
@@ -451,7 +450,7 @@ def build_readable_messages(
将消息列表转换为可读的文本格式。
如果提供了 read_mark则在相应位置插入已读标记。
允许通过参数控制格式化行为。
Args:
messages: 消息列表
        replace_bot_name: 是否替换机器人名称为"你"
@@ -463,22 +462,24 @@ def build_readable_messages(
"""
# 创建messages的深拷贝避免修改原始列表
copy_messages = [msg.copy() for msg in messages]
if show_actions and copy_messages:
# 获取所有消息的时间范围
min_time = min(msg.get("time", 0) for msg in copy_messages)
max_time = max(msg.get("time", 0) for msg in copy_messages)
# 从第一条消息中获取chat_id
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
# 获取这个时间范围内的动作记录并匹配chat_id
actions = ActionRecords.select().where(
(ActionRecords.time >= min_time) &
(ActionRecords.time <= max_time) &
(ActionRecords.chat_id == chat_id)
).order_by(ActionRecords.time)
actions = (
ActionRecords.select()
.where(
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
)
.order_by(ActionRecords.time)
)
# 将动作记录转换为消息格式
for action in actions:
# 只有当build_into_prompt为True时才添加动作记录
@@ -495,25 +496,22 @@ def build_readable_messages(
"action_name": action.action_name, # 保存动作名称
}
copy_messages.append(action_msg)
# 重新按时间排序
copy_messages.sort(key=lambda x: x.get("time", 0))
if read_mark <= 0:
# 没有有效的 read_mark直接格式化所有消息
# for message in messages:
# print(f"message:{message}")
# print(f"message:{message}")
formatted_string, _ = _build_readable_messages_internal(
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate
)
# print(f"formatted_string:{formatted_string}")
return formatted_string
else:
# 按 read_mark 分割消息
@@ -521,10 +519,10 @@ def build_readable_messages(
messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark]
# for message in messages_before_mark:
# print(f"message:{message}")
# print(f"message:{message}")
# for message in messages_after_mark:
# print(f"message:{message}")
# print(f"message:{message}")
# 分别格式化
formatted_before, _ = _build_readable_messages_internal(
@@ -536,7 +534,7 @@ def build_readable_messages(
merge_messages,
timestamp_mode,
)
# print(f"formatted_before:{formatted_before}")
# print(f"formatted_after:{formatted_after}")

View File

@@ -154,7 +154,8 @@ class Messages(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = "messages"
class ActionRecords(BaseModel):
"""
用于存储动作记录数据的模型。
@@ -162,11 +163,11 @@ class ActionRecords(BaseModel):
action_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
time = DoubleField() # 消息时间戳
action_name = TextField()
action_data = TextField()
action_done = BooleanField(default=False)
action_build_into_prompt = BooleanField(default=False)
action_prompt_display = TextField()
@@ -241,11 +242,10 @@ class PersonInfo(BaseModel):
points = TextField(null=True) # 个人印象的点
forgotten_points = TextField(null=True) # 被遗忘的点
info_list = TextField(null=True) # 与Bot的互动
know_times = FloatField(null=True) # 认识时间 (时间戳)
know_since = FloatField(null=True) # 首次印象总结时间
last_know = FloatField(null=True) # 最后一次印象总结时间
class Meta:
# database = db # 继承自 BaseModel
@@ -403,20 +403,20 @@ def initialize_database():
logger.info(f"'{table_name}' 缺失字段 '{field_name}',正在添加...")
field_type = field_obj.__class__.__name__
sql_type = {
'TextField': 'TEXT',
'IntegerField': 'INTEGER',
'FloatField': 'FLOAT',
'DoubleField': 'DOUBLE',
'BooleanField': 'INTEGER',
'DateTimeField': 'DATETIME'
}.get(field_type, 'TEXT')
alter_sql = f'ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}'
"TextField": "TEXT",
"IntegerField": "INTEGER",
"FloatField": "FLOAT",
"DoubleField": "DOUBLE",
"BooleanField": "INTEGER",
"DateTimeField": "DATETIME",
}.get(field_type, "TEXT")
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
if field_obj.null:
alter_sql += ' NULL'
alter_sql += " NULL"
else:
alter_sql += ' NOT NULL'
if hasattr(field_obj, 'default') and field_obj.default is not None:
alter_sql += f' DEFAULT {field_obj.default}'
alter_sql += " NOT NULL"
if hasattr(field_obj, "default") and field_obj.default is not None:
alter_sql += f" DEFAULT {field_obj.default}"
db.execute_sql(alter_sql)
logger.info(f"字段 '{field_name}' 添加成功")

View File

@@ -84,7 +84,7 @@ def update_config():
contains_regex = False
if value and isinstance(value[0], dict) and "regex" in value[0]:
contains_regex = True
if contains_regex:
target[key] = value
else:

View File

@@ -49,7 +49,7 @@ class IdentityConfig(ConfigBase):
@dataclass
class RelationshipConfig(ConfigBase):
"""关系配置类"""
enable_relationship: bool = True
give_name: bool = False
@@ -58,6 +58,7 @@ class RelationshipConfig(ConfigBase):
build_relationship_interval: int = 600
"""构建关系间隔 单位秒如果为0则不构建关系"""
@dataclass
class ChatConfig(ConfigBase):
"""聊天配置类"""
@@ -222,7 +223,7 @@ class EmojiConfig(ConfigBase):
@dataclass
class MemoryConfig(ConfigBase):
"""记忆配置类"""
enable_memory: bool = True
memory_build_interval: int = 600
@@ -329,6 +330,7 @@ class KeywordReactionConfig(ConfigBase):
if not isinstance(rule, KeywordRuleConfig):
raise ValueError(f"规则必须是KeywordRuleConfig类型而不是{type(rule).__name__}")
@dataclass
class ResponsePostProcessConfig(ConfigBase):
"""回复后处理配置类"""
@@ -461,7 +463,7 @@ class LPMMKnowledgeConfig(ConfigBase):
qa_res_top_k: int = 10
"""QA最终结果的Top K数量"""
@dataclass
class ModelConfig(ConfigBase):

View File

@@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from typing import List, Dict, Any, Callable
# from src.common.database.database import db # Peewee db 导入
from playhouse import shortcuts
# from src.common.database.database import db # Peewee db 导入
from src.common.database.database_model import Messages # Peewee Messages 模型导入
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典
model_to_dict: Callable[..., dict] = shortcuts.model_to_dict # Peewee 模型转换为字典的快捷函数
class MessageStorage(ABC):

View File

@@ -90,7 +90,7 @@ class PersonalityExpression:
current_style_text = global_config.expression.expression_style
current_personality = global_config.personality.personality_core
meta_data = self._read_meta_data()
last_style_text = meta_data.get("last_style_text")
@@ -98,9 +98,10 @@ class PersonalityExpression:
count = meta_data.get("count", 0)
# 检查是否有任何变化
if (current_style_text != last_style_text or
current_personality != last_personality):
logger.info(f"检测到变化:\n风格: '{last_style_text}' -> '{current_style_text}'\n人格: '{last_personality}' -> '{current_personality}'")
if current_style_text != last_style_text or current_personality != last_personality:
logger.info(
f"检测到变化:\n风格: '{last_style_text}' -> '{current_style_text}'\n人格: '{last_personality}' -> '{current_personality}'"
)
count = 0
if os.path.exists(self.expressions_file_path):
try:
@@ -196,7 +197,7 @@ class PersonalityExpression:
"last_style_text": current_style_text,
"last_personality": current_personality,
"count": count,
"last_update_time": current_time
"last_update_time": current_time,
}
)
logger.info(f"成功处理。当前配置的计数现在是 {count},最后更新时间:{current_time}")

View File

@@ -19,6 +19,7 @@ from .common.server import global_server, Server
from rich.traceback import install
from .chat.focus_chat.expressors.exprssion_learner import expression_learner
from .api.main import start_api_server
# 导入actions模块确保装饰器被执行
import src.chat.actions.default_actions # noqa
@@ -40,7 +41,7 @@ class MainSystem:
self.hippocampus_manager = hippocampus_manager
else:
self.hippocampus_manager = None
self.individuality: Individuality = individuality
# 使用消息API替代直接的FastAPI实例
@@ -74,11 +75,11 @@ class MainSystem:
# 启动API服务器
start_api_server()
logger.success("API服务器启动成功")
# 加载所有actions包括默认的和插件的
self._load_all_actions()
logger.success("动作系统加载成功")
# 初始化表情管理器
emoji_manager.initialize()
logger.success("表情包管理器初始化成功")
@@ -137,23 +138,25 @@ class MainSystem:
try:
# 导入统一的插件加载器
from src.plugins.plugin_loader import plugin_loader
# 使用统一的插件加载器加载所有插件组件
loaded_actions, loaded_commands = plugin_loader.load_all_plugins()
# 加载命令处理系统
try:
# 导入命令处理系统
from src.chat.command.command_handler import command_manager
logger.success("命令处理系统加载成功")
except Exception as e:
logger.error(f"加载命令处理系统失败: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"加载插件失败: {e}")
import traceback
logger.error(traceback.format_exc())
async def schedule_tasks(self):
@@ -165,17 +168,19 @@ class MainSystem:
self.app.run(),
self.server.run(),
]
# 根据配置条件性地添加记忆系统相关任务
if global_config.memory.enable_memory and self.hippocampus_manager:
tasks.extend([
self.build_memory_task(),
self.forget_memory_task(),
self.consolidate_memory_task(),
])
tasks.extend(
[
self.build_memory_task(),
self.forget_memory_task(),
self.consolidate_memory_task(),
]
)
tasks.append(self.learn_and_store_expression_task())
await asyncio.gather(*tasks)
async def build_memory_task(self):
@@ -190,9 +195,7 @@ class MainSystem:
while True:
await asyncio.sleep(global_config.memory.forget_memory_interval)
logger.info("[记忆遗忘] 开始遗忘记忆...")
await self.hippocampus_manager.forget_memory(
percentage=global_config.memory.memory_forget_percentage
)
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage)
logger.info("[记忆遗忘] 记忆遗忘完成")
async def consolidate_memory_task(self):
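For orientation on the schedule_tasks hunk above (conditionally extending the task list, then awaiting asyncio.gather), a minimal runnable sketch; the ENABLE_MEMORY flag and the two toy coroutines are assumptions for illustration only:
import asyncio

ENABLE_MEMORY = True  # 假设的开关,对应类似 global_config.memory.enable_memory 的配置

async def base_task():
    await asyncio.sleep(0.1)
    print("base task done")

async def build_memory_task():
    await asyncio.sleep(0.1)
    print("memory task done")

async def schedule_tasks():
    tasks = [base_task()]
    # 仅在开关打开时追加记忆相关任务
    if ENABLE_MEMORY:
        tasks.extend(
            [
                build_memory_task(),
            ]
        )
    await asyncio.gather(*tasks)

asyncio.run(schedule_tasks())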

View File

@@ -11,6 +11,7 @@ from collections import defaultdict
logger = get_logger("relation")
# 暂时弃用,改为实时更新
class ImpressionUpdateTask(AsyncTask):
def __init__(self):
@@ -25,10 +26,10 @@ class ImpressionUpdateTask(AsyncTask):
# 获取最近的消息
current_time = int(time.time())
start_time = current_time - global_config.relationship.build_relationship_interval # 100分钟前
# 获取所有消息
messages = get_raw_msg_by_timestamp(timestamp_start=start_time, timestamp_end=current_time)
if not messages:
logger.info("没有找到需要处理的消息")
return
@@ -48,7 +49,7 @@ class ImpressionUpdateTask(AsyncTask):
if len(msgs) < 30:
logger.info(f"聊天组 {chat_id} 消息数小于30跳过处理")
continue
chat_stream = chat_manager.get_stream(chat_id)
if not chat_stream:
logger.warning(f"未找到聊天组 {chat_id} 的chat_stream跳过处理")
@@ -56,26 +57,26 @@ class ImpressionUpdateTask(AsyncTask):
# 找到bot的消息
bot_messages = [msg for msg in msgs if msg["user_nickname"] == global_config.bot.nickname]
if not bot_messages:
logger.info(f"聊天组 {chat_id} 没有bot消息跳过处理")
continue
# 按时间排序所有消息
sorted_messages = sorted(msgs, key=lambda x: x["time"])
# 找到第一条和最后一条bot消息
first_bot_msg = bot_messages[0]
last_bot_msg = bot_messages[-1]
# 获取第一条bot消息前15条消息
first_bot_index = sorted_messages.index(first_bot_msg)
start_index = max(0, first_bot_index - 25)
# 获取最后一条bot消息后15条消息
last_bot_index = sorted_messages.index(last_bot_msg)
end_index = min(len(sorted_messages), last_bot_index + 26)
# 获取相关消息
relevant_messages = sorted_messages[start_index:end_index]
@@ -85,7 +86,9 @@ class ImpressionUpdateTask(AsyncTask):
# 计算权重
for bot_msg in bot_messages:
bot_time = bot_msg["time"]
context_messages = [msg for msg in relevant_messages if abs(msg["time"] - bot_time) <= 600] # 前后10分钟
context_messages = [
msg for msg in relevant_messages if abs(msg["time"] - bot_time) <= 600
] # 前后10分钟
logger.debug(f"Bot消息 {bot_time} 的上下文消息数: {len(context_messages)}")
for msg in context_messages:
@@ -121,7 +124,7 @@ class ImpressionUpdateTask(AsyncTask):
weights = [user[1]["weight"] for user in sorted_users]
total_weight = sum(weights)
# 计算每个用户的概率
probabilities = [w/total_weight for w in weights]
probabilities = [w / total_weight for w in weights]
# 使用累积概率进行选择
selected_indices = []
remaining_indices = list(range(len(sorted_users)))
@@ -131,12 +134,12 @@ class ImpressionUpdateTask(AsyncTask):
# 计算剩余索引的累积概率
remaining_probs = [probabilities[i] for i in remaining_indices]
# 归一化概率
remaining_probs = [p/sum(remaining_probs) for p in remaining_probs]
remaining_probs = [p / sum(remaining_probs) for p in remaining_probs]
# 选择索引
chosen_idx = random.choices(remaining_indices, weights=remaining_probs, k=1)[0]
selected_indices.append(chosen_idx)
remaining_indices.remove(chosen_idx)
selected_users = [sorted_users[i] for i in selected_indices]
logger.info(
f"开始进一步了解这些用户: {[msg[1]['messages'][0]['user_nickname'] for msg in selected_users]}"
@@ -153,19 +156,16 @@ class ImpressionUpdateTask(AsyncTask):
platform = data["messages"][0]["chat_info_platform"]
user_id = data["messages"][0]["user_id"]
cardname = data["messages"][0]["user_cardname"]
is_known = await relationship_manager.is_known_some_one(platform, user_id)
if not is_known:
logger.info(f"首次认识用户: {user_nickname}")
await relationship_manager.first_knowing_some_one(platform, user_id, user_nickname, cardname)
logger.info(f"开始更新用户 {user_nickname} 的印象")
await relationship_manager.update_person_impression(
person_id=person_id,
timestamp=last_bot_msg["time"],
bot_engaged_messages=relevant_messages
person_id=person_id, timestamp=last_bot_msg["time"], bot_engaged_messages=relevant_messages
)
logger.debug("印象更新任务执行完成")

View File

@@ -33,7 +33,7 @@ JSON_SERIALIZED_FIELDS = ["points", "forgotten_points", "info_list"]
person_info_default = {
"person_id": None,
"person_name": None,
"name_reason": None, # Corrected from person_name_reason to match common usage if intended
"name_reason": None, # Corrected from person_name_reason to match common usage if intended
"platform": "unknown",
"user_id": "unknown",
"nickname": "Unknown",
@@ -42,11 +42,10 @@ person_info_default = {
"last_know": None,
# "user_cardname": None, # This field is not in Peewee model PersonInfo
# "user_avatar": None, # This field is not in Peewee model PersonInfo
"impression": None, # Corrected from persion_impression
"impression": None, # Corrected from persion_impression
"info_list": None,
"points": None,
"forgotten_points": None,
}
@@ -126,7 +125,7 @@ class PersonInfoManager:
for key, default_value in _person_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
@@ -141,7 +140,7 @@ class PersonInfoManager:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
# If it's already a string, assume it's valid JSON or a non-JSON string field
@@ -165,12 +164,12 @@ class PersonInfoManager:
return
print(f"更新字段: {field_name},值: {value}")
processed_value = value
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(value, (list, dict)):
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
elif value is None: # Store None as "[]" for JSON list fields
elif value is None: # Store None as "[]" for JSON list fields
processed_value = json.dumps([], ensure_ascii=False, indent=None)
# If value is already a string, assume it's pre-serialized or a non-JSON string.
@@ -180,7 +179,7 @@ class PersonInfoManager:
setattr(record, f_name, val_to_set)
record.save()
return True, False # Found and updated, no creation needed
return False, True # Not found, needs creation
return False, True # Not found, needs creation
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
@@ -190,15 +189,14 @@ class PersonInfoManager:
# Ensure platform and user_id are present for context if available from 'data'
# but primarily, set the field that triggered the update.
# The create_person_info will handle defaults and serialization.
creation_data[field_name] = value # Pass original value to create_person_info
creation_data[field_name] = value # Pass original value to create_person_info
# Ensure platform and user_id are in creation_data if available,
# otherwise create_person_info will use defaults.
if data and "platform" in data:
creation_data["platform"] = data["platform"]
creation_data["platform"] = data["platform"]
if data and "user_id" in data:
creation_data["user_id"] = data["user_id"]
creation_data["user_id"] = data["user_id"]
await self.create_person_info(person_id, creation_data)
@@ -233,7 +231,7 @@ class PersonInfoManager:
if isinstance(parsed_json, list) and parsed_json:
parsed_json = parsed_json[0]
if isinstance(parsed_json, dict):
return parsed_json
@@ -249,11 +247,11 @@ class PersonInfoManager:
# 处理空昵称的情况
if not base_name or base_name.isspace():
base_name = "空格"
# 检查基础名称是否已存在
if base_name not in self.person_name_list.values():
return base_name
# 如果存在,添加数字后缀
counter = 1
while True:
@@ -331,9 +329,11 @@ class PersonInfoManager:
if not is_duplicate:
await self.update_one_field(person_id, "person_name", generated_nickname)
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
logger.info(f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}")
logger.info(
f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}"
)
self.person_name_list[person_id] = generated_nickname
return result
else:
@@ -379,7 +379,7 @@ class PersonInfoManager:
"""获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
def _db_get_value_sync(p_id: str, f_name: str):
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
@@ -391,32 +391,32 @@ class PersonInfoManager:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
return [] # Default for JSON fields on error
elif val is None: # Field exists in DB but is None
return [] # Default for JSON fields
return [] # Default for JSON fields on error
elif val is None: # Field exists in DB but is None
return [] # Default for JSON fields
# If val is already a list/dict (e.g. if somehow set without serialization)
return val # Should ideally not happen if update_one_field is always used
return val # Should ideally not happen if update_one_field is always used
return val
return None # Record not found
return None # Record not found
try:
value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
if value_from_db is not None:
return value_from_db
if field_name in person_info_default:
return default_value_for_field
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
return None # Ultimate fallback
return None # Ultimate fallback
except Exception as e:
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
if field_name in person_info_default:
return default_value_for_field
return None
@staticmethod
def get_value_sync(person_id: str, field_name: str):
""" 同步获取指定用户指定字段的值 """
"""同步获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
@@ -430,12 +430,12 @@ class PersonInfoManager:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
return []
return []
elif val is None:
return []
return val
return val
return val
if field_name in person_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
@@ -534,7 +534,7 @@ class PersonInfoManager:
"last_know": int(datetime.datetime.now().timestamp()),
"impression": None,
"points": [],
"forgotten_points": []
"forgotten_points": [],
}
model_fields = PersonInfo._meta.fields.keys()
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
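PersonInfoManager keeps list-valued fields (points, forgotten_points, info_list) as JSON strings in TEXT columns and falls back to [] when the stored value is missing or not valid JSON. A minimal round-trip sketch of that convention; the helper names and sample data below are illustrative, not the project's API:
import json

def dump_list_field(value):
    # 列表/字典序列化为 JSON 字符串None 统一存成 "[]"
    if isinstance(value, (list, dict)):
        return json.dumps(value, ensure_ascii=False)
    if value is None:
        return json.dumps([], ensure_ascii=False)
    return value  # 已经是字符串时按原样存储

def load_list_field(raw):
    # 读取时解析 JSON出错或为空则回退为空列表
    if raw is None:
        return []
    try:
        parsed = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return []
    return parsed if isinstance(parsed, list) else []

stored = dump_list_field([["喜欢猫", 3, "2025-06-10 16:00:00"]])
print(load_list_field(stored))      # [['喜欢猫', 3, '2025-06-10 16:00:00']]
print(load_list_field("not json"))  # []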

View File

@@ -7,7 +7,6 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.manager.mood_manager import mood_manager
from src.individuality.individuality import individuality
import json
from json_repair import repair_json
from datetime import datetime
@@ -90,9 +89,7 @@ class RelationshipManager:
return is_known
@staticmethod
async def first_knowing_some_one(
platform: str, user_id: str, user_nickname: str, user_cardname: str
):
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
"""判断是否认识某人"""
person_id = person_info_manager.get_person_id(platform, user_id)
# 生成唯一的 person_name
@@ -112,7 +109,7 @@ class RelationshipManager:
)
# 尝试生成更好的名字
# await person_info_manager.qv_person_name(
# person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
# person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
# )
async def build_relationship_info(self, person, is_id: bool = False) -> str:
@@ -124,26 +121,24 @@ class RelationshipManager:
person_name = await person_info_manager.get_value(person_id, "person_name")
if not person_name or person_name == "none":
return ""
impression = await person_info_manager.get_value(person_id, "impression")
# impression = await person_info_manager.get_value(person_id, "impression")
points = await person_info_manager.get_value(person_id, "points") or []
if isinstance(points, str):
try:
points = ast.literal_eval(points)
except (SyntaxError, ValueError):
points = []
random_points = random.sample(points, min(5, len(points))) if points else []
nickname_str = await person_info_manager.get_value(person_id, "nickname")
platform = await person_info_manager.get_value(person_id, "platform")
relation_prompt = f"'{person_name}' ta在{platform}上的昵称是{nickname_str}"
# if impression:
# relation_prompt += f"你对ta的印象是{impression}。"
# relation_prompt += f"你对ta的印象是{impression}。"
if random_points:
for point in random_points:
# print(f"point: {point}")
@@ -151,13 +146,12 @@ class RelationshipManager:
# print(f"point[0]: {point[0]}")
point_str = f"时间:{point[2]}。内容:{point[0]}"
relation_prompt += f"你记得{person_name}最近的点是:{point_str}"
return relation_prompt
async def _update_list_field(self, person_id: str, field_name: str, new_items: list) -> None:
"""更新列表类型的字段,将新项目添加到现有列表中
Args:
person_id: 用户ID
field_name: 字段名称
@@ -179,21 +173,21 @@ class RelationshipManager:
"""
person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname")
alias_str = ", ".join(global_config.bot.alias_names)
personality_block = individuality.get_personality_prompt(x_person=2, level=2)
identity_block = individuality.get_identity_prompt(x_person=2, level=2)
# personality_block = individuality.get_personality_prompt(x_person=2, level=2)
# identity_block = individuality.get_identity_prompt(x_person=2, level=2)
user_messages = bot_engaged_messages
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
# 匿名化消息
# 创建用户名称映射
name_mapping = {}
current_user = "A"
user_count = 1
# 遍历消息,构建映射
for msg in user_messages:
await person_info_manager.get_or_create_person(
@@ -206,37 +200,31 @@ class RelationshipManager:
replace_platform = msg.get("chat_info_platform")
replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id)
replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
# 跳过机器人自己
if replace_user_id == global_config.bot.qq_account:
name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
continue
# 跳过目标用户
if replace_person_name == person_name:
name_mapping[replace_person_name] = f"{person_name}"
continue
# 其他用户映射
if replace_person_name not in name_mapping:
if current_user > 'Z':
current_user = 'A'
if current_user > "Z":
current_user = "A"
user_count += 1
name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}"
current_user = chr(ord(current_user) + 1)
readable_messages = self.build_focus_readable_messages(
messages=user_messages,
target_person_id=person_id
)
readable_messages = self.build_focus_readable_messages(messages=user_messages, target_person_id=person_id)
for original_name, mapped_name in name_mapping.items():
# print(f"original_name: {original_name}, mapped_name: {mapped_name}")
readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}")
prompt = f"""
你的名字是{global_config.bot.nickname}{global_config.bot.nickname}的别名是{alias_str}
请不要混淆你自己和{global_config.bot.nickname}{person_name}
@@ -271,22 +259,22 @@ class RelationshipManager:
"weight": 0
}}
"""
# 调用LLM生成印象
points, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
points = points.strip()
# 还原用户名称
for original_name, mapped_name in name_mapping.items():
points = points.replace(mapped_name, original_name)
# logger.info(f"prompt: {prompt}")
# logger.info(f"points: {points}")
if not points:
logger.warning(f"未能从LLM获取 {person_name} 的新印象")
return
# 解析JSON并转换为元组列表
try:
points = repair_json(points)
@@ -307,7 +295,7 @@ class RelationshipManager:
except (KeyError, TypeError) as e:
logger.error(f"处理points数据失败: {e}, points: {points}")
return
current_points = await person_info_manager.get_value(person_id, "points") or []
if isinstance(current_points, str):
try:
@@ -318,7 +306,9 @@ class RelationshipManager:
elif not isinstance(current_points, list):
current_points = []
current_points.extend(points_list)
await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
await person_info_manager.update_one_field(
person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
)
# 将新记录添加到现有记录中
if isinstance(current_points, list):
@@ -326,14 +316,14 @@ class RelationshipManager:
for new_point in points_list:
similar_points = []
similar_indices = []
# 在现有points中查找相似的点
for i, existing_point in enumerate(current_points):
# 使用组合的相似度检查方法
if self.check_similarity(new_point[0], existing_point[0]):
similar_points.append(existing_point)
similar_indices.append(i)
if similar_points:
# 合并相似的点
all_points = [new_point] + similar_points
@@ -343,14 +333,14 @@ class RelationshipManager:
total_weight = sum(p[1] for p in all_points)
# 使用最长的描述
longest_desc = max(all_points, key=lambda x: len(x[0]))[0]
# 创建合并后的点
merged_point = (longest_desc, total_weight, latest_time)
# 从现有points中移除已合并的点
for idx in sorted(similar_indices, reverse=True):
current_points.pop(idx)
# 添加合并后的点
current_points.append(merged_point)
else:
@@ -359,7 +349,7 @@ class RelationshipManager:
else:
current_points = points_list
# 如果points超过10条按权重随机选择多余的条目移动到forgotten_points
# 如果points超过10条按权重随机选择多余的条目移动到forgotten_points
if len(current_points) > 10:
# 获取现有forgotten_points
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
@@ -371,29 +361,29 @@ class RelationshipManager:
forgotten_points = []
elif not isinstance(forgotten_points, list):
forgotten_points = []
# 计算当前时间
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
# 计算每个点的最终权重(原始权重 * 时间权重)
weighted_points = []
for point in current_points:
time_weight = self.calculate_time_weight(point[2], current_time)
final_weight = point[1] * time_weight
weighted_points.append((point, final_weight))
# 计算总权重
total_weight = sum(w for _, w in weighted_points)
# 按权重随机选择要保留的点
remaining_points = []
points_to_move = []
# 对每个点进行随机选择
for point, weight in weighted_points:
# 计算保留概率(权重越高越可能保留)
keep_probability = weight / total_weight
if len(remaining_points) < 10:
# 如果还没达到30条直接保留
remaining_points.append(point)
@@ -407,28 +397,26 @@ class RelationshipManager:
else:
# 不保留这个点
points_to_move.append(point)
# 更新points和forgotten_points
current_points = remaining_points
forgotten_points.extend(points_to_move)
# 检查forgotten_points是否达到5条
if len(forgotten_points) >= 10:
# 构建压缩总结提示词
alias_str = ", ".join(global_config.bot.alias_names)
# 按时间排序forgotten_points
forgotten_points.sort(key=lambda x: x[2])
# 构建points文本
points_text = "\n".join([
f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}"
for point in forgotten_points
])
points_text = "\n".join(
[f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" for point in forgotten_points]
)
impression = await person_info_manager.get_value(person_id, "impression") or ""
compress_prompt = f"""
你的名字是{global_config.bot.nickname}{global_config.bot.nickname}的别名是{alias_str}
请不要混淆你自己和{global_config.bot.nickname}{person_name}
@@ -449,88 +437,85 @@ class RelationshipManager:
"""
# 调用LLM生成压缩总结
compressed_summary, _ = await self.relationship_llm.generate_response_async(prompt=compress_prompt)
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
compressed_summary = f"截至{current_time},你对{person_name}的了解:{compressed_summary}"
await person_info_manager.update_one_field(person_id, "impression", compressed_summary)
forgotten_points = []
# 这句代码的作用是:将更新后的 forgotten_points遗忘的记忆点列表序列化为 JSON 字符串后,写回到数据库中的 forgotten_points 字段
await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None))
await person_info_manager.update_one_field(
person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)
)
# 更新数据库
await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
await person_info_manager.update_one_field(
person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
)
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
await person_info_manager.update_one_field(person_id, "last_know", timestamp)
logger.info(f"印象更新完成 for {person_name}")
def build_focus_readable_messages(self, messages: list, target_person_id: str = None) -> str:
"""格式化消息只保留目标用户和bot消息附近的内容"""
# 找到目标用户和bot的消息索引
target_indices = []
for i, msg in enumerate(messages):
user_id = msg.get("user_id")
platform = msg.get("chat_info_platform")
person_id = person_info_manager.get_person_id(platform, user_id)
if person_id == target_person_id:
target_indices.append(i)
if not target_indices:
return ""
# 获取需要保留的消息索引
keep_indices = set()
for idx in target_indices:
# 获取前后5条消息的索引
start_idx = max(0, idx - 5)
end_idx = min(len(messages), idx + 6)
keep_indices.update(range(start_idx, end_idx))
print(keep_indices)
# 将索引排序
keep_indices = sorted(list(keep_indices))
# 按顺序构建消息组
message_groups = []
current_group = []
for i in range(len(messages)):
if i in keep_indices:
current_group.append(messages[i])
elif current_group:
# 如果当前组不为空,且遇到不保留的消息,则结束当前组
if current_group:
message_groups.append(current_group)
current_group = []
# 添加最后一组
if current_group:
message_groups.append(current_group)
# 构建最终的消息文本
result = []
for i, group in enumerate(message_groups):
if i > 0:
result.append("...")
group_text = build_readable_messages(
messages=group,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
truncate=False
)
result.append(group_text)
return "\n".join(result)
"""格式化消息只保留目标用户和bot消息附近的内容"""
# 找到目标用户和bot的消息索引
target_indices = []
for i, msg in enumerate(messages):
user_id = msg.get("user_id")
platform = msg.get("chat_info_platform")
person_id = person_info_manager.get_person_id(platform, user_id)
if person_id == target_person_id:
target_indices.append(i)
if not target_indices:
return ""
# 获取需要保留的消息索引
keep_indices = set()
for idx in target_indices:
# 获取前后5条消息的索引
start_idx = max(0, idx - 5)
end_idx = min(len(messages), idx + 6)
keep_indices.update(range(start_idx, end_idx))
print(keep_indices)
# 将索引排序
keep_indices = sorted(list(keep_indices))
# 按顺序构建消息组
message_groups = []
current_group = []
for i in range(len(messages)):
if i in keep_indices:
current_group.append(messages[i])
elif current_group:
# 如果当前组不为空,且遇到不保留的消息,则结束当前组
if current_group:
message_groups.append(current_group)
current_group = []
# 添加最后一组
if current_group:
message_groups.append(current_group)
# 构建最终的消息文本
result = []
for i, group in enumerate(message_groups):
if i > 0:
result.append("...")
group_text = build_readable_messages(
messages=group, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=False
)
result.append(group_text)
return "\n".join(result)
def calculate_time_weight(self, point_time: str, current_time: str) -> float:
"""计算基于时间的权重系数"""
try:
@@ -538,7 +523,7 @@ class RelationshipManager:
current_timestamp = datetime.strptime(current_time, "%Y-%m-%d %H:%M:%S")
time_diff = current_timestamp - point_timestamp
hours_diff = time_diff.total_seconds() / 3600
if hours_diff <= 1: # 1小时内
return 1.0
elif hours_diff <= 24: # 1-24小时
@@ -564,18 +549,18 @@ class RelationshipManager:
s1 = " ".join(str(x) for x in s1)
if isinstance(s2, list):
s2 = " ".join(str(x) for x in s2)
# 转换为字符串类型
s1 = str(s1)
s2 = str(s2)
# 1. 使用 jieba 进行分词
s1_words = " ".join(jieba.cut(s1))
s2_words = " ".join(jieba.cut(s2))
# 2. 将两句话放入一个列表中
corpus = [s1_words, s2_words]
# 3. 创建 TF-IDF 向量化器并进行计算
try:
vectorizer = TfidfVectorizer()
@@ -586,7 +571,7 @@ class RelationshipManager:
# 4. 计算余弦相似度
similarity_matrix = cosine_similarity(tfidf_matrix)
# 返回 s1 和 s2 的相似度
return similarity_matrix[0, 1]
@@ -599,20 +584,20 @@ class RelationshipManager:
def check_similarity(self, text1, text2, tfidf_threshold=0.5, seq_threshold=0.6):
"""
使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的。
Args:
text1: 第一个文本
text2: 第二个文本
tfidf_threshold: TF-IDF相似度阈值
seq_threshold: SequenceMatcher相似度阈值
Returns:
bool: 如果任一方法达到阈值则返回True
"""
# 计算两种相似度
tfidf_sim = self.tfidf_similarity(text1, text2)
seq_sim = self.sequence_similarity(text1, text2)
# 只要其中一种方法达到阈值就认为是相似的
return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold
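check_similarity above treats two texts as similar when either the jieba + TF-IDF cosine score or a sequence score passes its threshold. A compact sketch of that combination; sequence_similarity is not shown in this hunk, so the difflib.SequenceMatcher ratio used here is an assumption, and the thresholds and sample sentences are only illustrative:
import jieba
from difflib import SequenceMatcher
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def tfidf_similarity(s1: str, s2: str) -> float:
    # 先用 jieba 分词,再向量化并计算余弦相似度
    corpus = [" ".join(jieba.cut(s1)), " ".join(jieba.cut(s2))]
    matrix = TfidfVectorizer().fit_transform(corpus)
    return cosine_similarity(matrix)[0, 1]

def sequence_similarity(s1: str, s2: str) -> float:
    return SequenceMatcher(None, s1, s2).ratio()

def check_similarity(s1: str, s2: str, tfidf_threshold=0.5, seq_threshold=0.6) -> bool:
    # 任一相似度超过阈值即认为相似
    return tfidf_similarity(s1, s2) > tfidf_threshold or sequence_similarity(s1, s2) > seq_threshold

print(check_similarity("他喜欢打羽毛球", "他很喜欢打羽毛球"))  # True序列相似度约 0.93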

View File

@@ -40,7 +40,7 @@ DEFAULT_CONFIG = {
"default_guidance_scale": 2.5,
"default_seed": 42,
"cache_enabled": True,
"cache_max_size": 10
"cache_max_size": 10,
}
@@ -49,37 +49,37 @@ def validate_and_fix_config(config_path: str) -> bool:
try:
with open(config_path, "r", encoding="utf-8") as f:
config = toml.load(f)
# 检查缺失的配置项
missing_keys = []
fixed = False
for key, default_value in DEFAULT_CONFIG.items():
if key not in config:
missing_keys.append(key)
config[key] = default_value
fixed = True
logger.info(f"添加缺失的配置项: {key} = {default_value}")
# 验证配置值的类型和范围
if isinstance(config.get("default_guidance_scale"), (int, float)):
if not 0.1 <= config["default_guidance_scale"] <= 20.0:
config["default_guidance_scale"] = 2.5
fixed = True
logger.info("修复无效的 default_guidance_scale 值")
if isinstance(config.get("default_seed"), (int, float)):
config["default_seed"] = int(config["default_seed"])
else:
config["default_seed"] = 42
fixed = True
logger.info("修复无效的 default_seed 值")
if config.get("cache_max_size") and not isinstance(config["cache_max_size"], int):
config["cache_max_size"] = 10
fixed = True
logger.info("修复无效的 cache_max_size 值")
# 如果有修复,写回文件
if fixed:
# 创建备份
@@ -87,14 +87,14 @@ def validate_and_fix_config(config_path: str) -> bool:
if os.path.exists(config_path):
os.rename(config_path, backup_path)
logger.info(f"已创建配置备份: {backup_path}")
# 写入修复后的配置
with open(config_path, "w", encoding="utf-8") as f:
toml.dump(config, f)
logger.info(f"配置文件已修复: {config_path}")
return True
except Exception as e:
logger.error(f"验证配置文件时出错: {e}")
return False
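A minimal standalone sketch of the validate-and-fix flow above: load the TOML file, backfill missing keys, clamp an out-of-range value, and only back up and rewrite when something actually changed. The DEFAULTS dict and the file handling are simplified assumptions, not the plugin's full logic:
import os
import toml

DEFAULTS = {"default_guidance_scale": 2.5, "default_seed": 42, "cache_max_size": 10}

def validate_and_fix(config_path: str) -> bool:
    with open(config_path, "r", encoding="utf-8") as f:
        config = toml.load(f)

    fixed = False
    # 补齐缺失的配置项
    for key, default_value in DEFAULTS.items():
        if key not in config:
            config[key] = default_value
            fixed = True

    # 修正越界的 guidance scale
    scale = config.get("default_guidance_scale")
    if isinstance(scale, (int, float)) and not 0.1 <= scale <= 20.0:
        config["default_guidance_scale"] = 2.5
        fixed = True

    if fixed:
        # 先备份原文件,再写回修复后的配置
        backup_path = config_path + ".bak"
        if os.path.exists(config_path):
            os.replace(config_path, backup_path)
        with open(config_path, "w", encoding="utf-8") as f:
            toml.dump(config, f)
    return fixed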

View File

@@ -37,15 +37,15 @@ class PicAction(PluginAction):
]
enable_plugin = False
action_config_file_name = "pic_action_config.toml"
# 激活类型设置
focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定精确理解需求
normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活快速响应
normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活快速响应
# 关键词设置用于Normal模式
activation_keywords = ["画", "绘制", "生成图片", "画图", "draw", "paint", "图片生成"]
keyword_case_sensitive = False
# LLM判定提示词用于Focus模式
llm_judge_prompt = """
判定是否需要使用图片生成动作的条件:
@@ -67,31 +67,31 @@ class PicAction(PluginAction):
4. 技术讨论中提到绘图概念但无生成需求
5. 用户明确表示不需要图片时
"""
# Random激活概率备用
random_activation_probability = 0.15 # 适中概率,图片生成比较有趣
# 简单的请求缓存,避免短时间内重复请求
_request_cache = {}
_cache_max_size = 10
# 模式启用设置 - 图片生成在所有模式下可用
mode_enable = ChatMode.ALL
# 并行执行设置 - 图片生成可以与回复并行执行,不覆盖回复内容
parallel_action = False
@classmethod
def _get_cache_key(cls, description: str, model: str, size: str) -> str:
"""生成缓存键"""
return f"{description[:100]}|{model}|{size}" # 限制描述长度避免键过长
@classmethod
def _cleanup_cache(cls):
"""清理缓存,保持大小在限制内"""
if len(cls._request_cache) > cls._cache_max_size:
# 简单的FIFO策略移除最旧的条目
keys_to_remove = list(cls._request_cache.keys())[:-cls._cache_max_size//2]
keys_to_remove = list(cls._request_cache.keys())[: -cls._cache_max_size // 2]
for key in keys_to_remove:
del cls._request_cache[key]
@@ -169,7 +169,7 @@ class PicAction(PluginAction):
cached_result = self._request_cache[cache_key]
logger.info(f"{self.log_prefix} 使用缓存的图片结果")
await self.send_message_by_expressor("我之前画过类似的图片,用之前的结果~")
# 直接发送缓存的结果
send_success = await self.send_message(type="image", data=cached_result)
if send_success:
@@ -258,7 +258,7 @@ class PicAction(PluginAction):
# 缓存成功的结果
self._request_cache[cache_key] = base64_image_string
self._cleanup_cache()
await self.send_message_by_expressor("图片表情已发送!")
return True, "图片表情已发送"
else:
@@ -370,7 +370,7 @@ class PicAction(PluginAction):
def _validate_image_size(self, image_size: str) -> bool:
"""验证图片尺寸格式"""
try:
width, height = map(int, image_size.split('x'))
width, height = map(int, image_size.split("x"))
return 100 <= width <= 10000 and 100 <= height <= 10000
except (ValueError, TypeError):
return False
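The reformatted slice in _cleanup_cache keeps only the newest cache_max_size // 2 entries, relying on dict insertion order (Python 3.7+). A small worked example of that FIFO trim with made-up cache contents:
cache_max_size = 10
# 假设缓存里已经有 12 条记录,键按插入顺序排列
request_cache = {f"desc{i}|model|1024x1024": f"img{i}" for i in range(12)}

if len(request_cache) > cache_max_size:
    # 切片取出"除最新 cache_max_size // 2 条以外"的键,即最旧的那部分
    keys_to_remove = list(request_cache.keys())[: -cache_max_size // 2]
    for key in keys_to_remove:
        del request_cache[key]

print(len(request_cache))         # 5
print(next(iter(request_cache)))  # desc7|model|1024x1024最旧的 7 条已被移除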

View File

@@ -11,4 +11,4 @@
- 用户输入特定格式的命令时触发
- 通过命令前缀(如/)快速执行特定功能
- 提供快速响应的交互方式
"""
"""

View File

@@ -1,4 +1,4 @@
"""示例命令包
包含示例命令的实现
"""
"""

View File

@@ -5,27 +5,28 @@ import random
logger = get_logger("custom_prefix_command")
@register_command
class DiceCommand(BaseCommand):
"""骰子命令,使用!前缀而不是/前缀"""
command_name = "dice"
command_description = "骰子命令随机生成1-6的数字"
command_pattern = r"^[!](?:dice|骰子)(?:\s+(?P<count>\d+))?$" # 匹配 !dice 或 !骰子,可选参数为骰子数量
command_help = "使用方法: !dice [数量] 或 !骰子 [数量] - 掷骰子默认掷1个"
command_examples = ["!dice", "!骰子", "!dice 3", "!骰子 5"]
enable_command = True
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行骰子命令
Returns:
Tuple[bool, Optional[str]]: (是否执行成功, 回复消息)
"""
try:
# 获取骰子数量默认为1
count_str = self.matched_groups.get("count")
# 确保count_str不为None
if count_str is None:
count = 1 # 默认值
@@ -38,10 +39,10 @@ class DiceCommand(BaseCommand):
return False, "一次最多只能掷10个骰子"
except ValueError:
return False, "骰子数量必须是整数"
# 生成随机数
results = [random.randint(1, 6) for _ in range(count)]
# 构建回复消息
if count == 1:
message = f"🎲 掷出了 {results[0]}"
@@ -49,10 +50,10 @@ class DiceCommand(BaseCommand):
dice_results = ", ".join(map(str, results))
total = sum(results)
message = f"🎲 掷出了 {count} 个骰子: [{dice_results}],总点数: {total}"
logger.info(f"{self.log_prefix} 执行骰子命令: {message}")
return True, message
except Exception as e:
logger.error(f"{self.log_prefix} 执行骰子命令时出错: {e}")
return False, f"执行命令时出错: {str(e)}"
return False, f"执行命令时出错: {str(e)}"

View File

@@ -4,90 +4,86 @@ from typing import Tuple, Optional
logger = get_logger("help_command")
@register_command
class HelpCommand(BaseCommand):
"""帮助命令,显示所有可用命令的帮助信息"""
command_name = "help"
command_description = "显示所有可用命令的帮助信息"
command_pattern = r"^/help(?:\s+(?P<command>\w+))?$" # 匹配 /help 或 /help 命令名
command_help = "使用方法: /help [命令名] - 显示所有命令或特定命令的帮助信息"
command_examples = ["/help", "/help echo"]
enable_command = True
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行帮助命令
Returns:
Tuple[bool, Optional[str]]: (是否执行成功, 回复消息)
"""
try:
# 获取匹配到的命令名(如果有)
command_name = self.matched_groups.get("command")
# 如果指定了命令名,显示该命令的详细帮助
if command_name:
logger.info(f"{self.log_prefix} 查询命令帮助: {command_name}")
return self._show_command_help(command_name)
# 否则,显示所有命令的简要帮助
logger.info(f"{self.log_prefix} 查询所有命令帮助")
return self._show_all_commands()
except Exception as e:
logger.error(f"{self.log_prefix} 执行帮助命令时出错: {e}")
return False, f"执行命令时出错: {str(e)}"
def _show_command_help(self, command_name: str) -> Tuple[bool, str]:
"""显示特定命令的详细帮助信息
Args:
command_name: 命令名称
Returns:
Tuple[bool, str]: (是否执行成功, 回复消息)
"""
# 查找命令
command_cls = _COMMAND_REGISTRY.get(command_name)
if not command_cls:
return False, f"未找到命令: {command_name}"
# 获取命令信息
description = getattr(command_cls, "command_description", "无描述")
help_text = getattr(command_cls, "command_help", "无帮助信息")
examples = getattr(command_cls, "command_examples", [])
# 构建帮助信息
help_info = [
f"【命令】: {command_name}",
f"【描述】: {description}",
f"【用法】: {help_text}"
]
help_info = [f"【命令】: {command_name}", f"【描述】: {description}", f"【用法】: {help_text}"]
# 添加示例
if examples:
help_info.append("【示例】:")
for example in examples:
help_info.append(f" {example}")
return True, "\n".join(help_info)
def _show_all_commands(self) -> Tuple[bool, str]:
"""显示所有可用命令的简要帮助信息
Returns:
Tuple[bool, str]: (是否执行成功, 回复消息)
"""
# 获取所有已启用的命令
enabled_commands = {
name: cls for name, cls in _COMMAND_REGISTRY.items()
if getattr(cls, "enable_command", True)
name: cls for name, cls in _COMMAND_REGISTRY.items() if getattr(cls, "enable_command", True)
}
if not enabled_commands:
return True, "当前没有可用的命令"
# 构建命令列表
command_list = ["可用命令列表:"]
for name, cls in sorted(enabled_commands.items()):
@@ -107,9 +103,9 @@ class HelpCommand(BaseCommand):
else:
# 默认使用/name作为前缀
prefix = f"/{name}"
command_list.append(f"{prefix} - {description}")
command_list.append("\n使用 /help <命令名> 获取特定命令的详细帮助")
return True, "\n".join(command_list)
return True, "\n".join(command_list)

View File

@@ -1,43 +1,43 @@
from src.common.logger_manager import get_logger
from src.chat.command.command_handler import BaseCommand, register_command
from typing import Tuple, Optional
import json
logger = get_logger("message_info_command")
@register_command
class MessageInfoCommand(BaseCommand):
"""消息信息查看命令,展示发送命令的原始消息和相关信息"""
command_name = "msginfo"
command_description = "查看发送命令的原始消息信息"
command_pattern = r"^/msginfo(?:\s+(?P<detail>full|simple))?$"
command_help = "使用方法: /msginfo [full|simple] - 查看当前消息的详细信息"
command_examples = ["/msginfo", "/msginfo full", "/msginfo simple"]
enable_command = True
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行消息信息查看命令"""
try:
detail_level = self.matched_groups.get("detail", "simple")
logger.info(f"{self.log_prefix} 查看消息信息,详细级别: {detail_level}")
if detail_level == "full":
info_text = self._get_full_message_info()
else:
info_text = self._get_simple_message_info()
return True, info_text
except Exception as e:
logger.error(f"{self.log_prefix} 获取消息信息时出错: {e}")
return False, f"获取消息信息失败: {str(e)}"
def _get_simple_message_info(self) -> str:
"""获取简化的消息信息"""
message = self.message
# 基础信息
info_lines = [
"📨 消息信息概览",
@@ -45,157 +45,181 @@ class MessageInfoCommand(BaseCommand):
f"⏰ 时间: {message.message_info.time}",
f"🌐 平台: {message.message_info.platform}",
]
# 发送者信息
user = message.message_info.user_info
info_lines.extend([
"",
"👤 发送者信息:",
f" 用户ID: {user.user_id}",
f" 昵称: {user.user_nickname}",
f" 群名片: {user.user_cardname or ''}",
])
info_lines.extend(
[
"",
"👤 发送者信息:",
f" 用户ID: {user.user_id}",
f" 昵称: {user.user_nickname}",
f" 群名片: {user.user_cardname or ''}",
]
)
# 群聊信息(如果是群聊)
if message.message_info.group_info:
group = message.message_info.group_info
info_lines.extend([
"",
"👥 群聊信息:",
f" ID: {group.group_id}",
f": {group.group_name or '未知'}",
])
info_lines.extend(
[
"",
"👥聊信息:",
f"ID: {group.group_id}",
f" 群名: {group.group_name or '未知'}",
]
)
else:
info_lines.extend([
"",
"💬 消息类型: 私聊消息",
])
info_lines.extend(
[
"",
"💬 消息类型: 私聊消息",
]
)
# 消息内容
info_lines.extend([
"",
"📝 消息内容:",
f" 原始文本: {message.processed_plain_text}",
f" 是否表情: {'是' if getattr(message, 'is_emoji', False) else '否'}",
])
# 聊天流信息
if hasattr(message, 'chat_stream') and message.chat_stream:
chat_stream = message.chat_stream
info_lines.extend([
"",
"🔄 聊天流信息:",
f" 流ID: {chat_stream.stream_id}",
f" 是否激活: {'是' if chat_stream.is_active else '否'}",
])
info_lines.extend(
[
"",
"📝 消息内容:",
f" 原始文本: {message.processed_plain_text}",
f" 是否表情: {'是' if getattr(message, 'is_emoji', False) else '否'}",
]
)
# 聊天流信息
if hasattr(message, "chat_stream") and message.chat_stream:
chat_stream = message.chat_stream
info_lines.extend(
[
"",
"🔄 聊天流信息:",
f" 流ID: {chat_stream.stream_id}",
f" 是否激活: {'是' if chat_stream.is_active else '否'}",
]
)
return "\n".join(info_lines)
def _get_full_message_info(self) -> str:
"""获取完整的消息信息(包含技术细节)"""
message = self.message
info_lines = [
"📨 完整消息信息",
"=" * 40,
]
# 消息基础信息
info_lines.extend([
"",
"🔍 基础消息信息:",
f" 消息ID: {message.message_info.message_id}",
f" 时间戳: {message.message_info.time}",
f" 平台: {message.message_info.platform}",
f" 处理后文本: {message.processed_plain_text}",
f" 详细文本: {message.detailed_plain_text[:100]}{'...' if len(message.detailed_plain_text) > 100 else ''}",
])
info_lines.extend(
[
"",
"🔍 基础消息信息:",
f" 消息ID: {message.message_info.message_id}",
f" 时间戳: {message.message_info.time}",
f" 平台: {message.message_info.platform}",
f" 处理后文本: {message.processed_plain_text}",
f" 详细文本: {message.detailed_plain_text[:100]}{'...' if len(message.detailed_plain_text) > 100 else ''}",
]
)
# 用户详细信息
user = message.message_info.user_info
info_lines.extend([
"",
"👤 发送者详细信息:",
f" 用户ID: {user.user_id}",
f" 昵称: {user.user_nickname}",
f" 群名片: {user.user_cardname or ''}",
f" 平台: {user.platform}",
])
info_lines.extend(
[
"",
"👤 发送者详细信息:",
f" 用户ID: {user.user_id}",
f" 昵称: {user.user_nickname}",
f" 群名片: {user.user_cardname or ''}",
f" 平台: {user.platform}",
]
)
# 群聊详细信息
if message.message_info.group_info:
group = message.message_info.group_info
info_lines.extend([
"",
"👥 群聊详细信息:",
f" ID: {group.group_id}",
f": {group.group_name or '未知'}",
f" 平台: {group.platform}",
])
info_lines.extend(
[
"",
"👥聊详细信息:",
f"ID: {group.group_id}",
f" 群名: {group.group_name or '未知'}",
f" 平台: {group.platform}",
]
)
else:
info_lines.append("\n💬 消息类型: 私聊消息")
# 消息段信息
if message.message_segment:
info_lines.extend([
"",
"📦 消息段信息:",
f" 类型: {message.message_segment.type}",
f" 数据类型: {type(message.message_segment.data).__name__}",
f" 数据预览: {str(message.message_segment.data)[:200]}{'...' if len(str(message.message_segment.data)) > 200 else ''}",
])
info_lines.extend(
[
"",
"📦 消息段信息:",
f" 类型: {message.message_segment.type}",
f" 数据类型: {type(message.message_segment.data).__name__}",
f" 数据预览: {str(message.message_segment.data)[:200]}{'...' if len(str(message.message_segment.data)) > 200 else ''}",
]
)
# 聊天流详细信息
if hasattr(message, 'chat_stream') and message.chat_stream:
if hasattr(message, "chat_stream") and message.chat_stream:
chat_stream = message.chat_stream
info_lines.extend([
"",
"🔄 聊天流详细信息:",
f" 流ID: {chat_stream.stream_id}",
f" 平台: {chat_stream.platform}",
f" 是否激活: {'' if chat_stream.is_active else ''}",
f" 用户信息: {chat_stream.user_info.user_nickname} ({chat_stream.user_info.user_id})",
f" 信息: {getattr(chat_stream.group_info, 'group_name', '私聊') if chat_stream.group_info else '私聊'}",
])
info_lines.extend(
[
"",
"🔄 聊天流详细信息:",
f" 流ID: {chat_stream.stream_id}",
f" 平台: {chat_stream.platform}",
f" 是否激活: {'' if chat_stream.is_active else ''}",
f" 用户信息: {chat_stream.user_info.user_nickname} ({chat_stream.user_info.user_id})",
f" 群信息: {getattr(chat_stream.group_info, 'group_name', '私聊') if chat_stream.group_info else '私聊'}",
]
)
# 回复信息
if hasattr(message, 'reply') and message.reply:
info_lines.extend([
"",
"↩️ 回复信息:",
f" 回复消息ID: {message.reply.message_info.message_id}",
f" 回复内容: {message.reply.processed_plain_text[:100]}{'...' if len(message.reply.processed_plain_text) > 100 else ''}",
])
if hasattr(message, "reply") and message.reply:
info_lines.extend(
[
"",
"↩️ 回复信息:",
f" 回复消息ID: {message.reply.message_info.message_id}",
f" 回复内容: {message.reply.processed_plain_text[:100]}{'...' if len(message.reply.processed_plain_text) > 100 else ''}",
]
)
# 原始消息数据(如果存在)
if hasattr(message, 'raw_message') and message.raw_message:
info_lines.extend([
"",
"🗂️ 原始消息数据:",
f" 数据类型: {type(message.raw_message).__name__}",
f" 数据大小: {len(str(message.raw_message))} 字符",
])
if hasattr(message, "raw_message") and message.raw_message:
info_lines.extend(
[
"",
"🗂️ 原始消息数据:",
f" 数据类型: {type(message.raw_message).__name__}",
f" 数据大小: {len(str(message.raw_message))} 字符",
]
)
return "\n".join(info_lines)
@register_command
class SenderInfoCommand(BaseCommand):
"""发送者信息命令,快速查看发送者信息"""
command_name = "whoami"
command_description = "查看发送命令的用户信息"
command_pattern = r"^/whoami$"
command_help = "使用方法: /whoami - 查看你的用户信息"
command_examples = ["/whoami"]
enable_command = True
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行发送者信息查看命令"""
try:
user = self.message.message_info.user_info
group = self.message.message_info.group_info
info_lines = [
"👤 你的身份信息",
f"🆔 用户ID: {user.user_id}",
@@ -203,19 +227,21 @@ class SenderInfoCommand(BaseCommand):
f"🏷️ 群名片: {user.user_cardname or ''}",
f"🌐 平台: {user.platform}",
]
if group:
info_lines.extend([
"",
"👥 当前群聊:",
f"🆔 群ID: {group.group_id}",
f"📝: {group.group_name or '未知'}",
])
info_lines.extend(
[
"",
"👥 当前群聊:",
f"🆔ID: {group.group_id}",
f"📝 群名: {group.group_name or '未知'}",
]
)
else:
info_lines.append("\n💬 当前在私聊中")
return True, "\n".join(info_lines)
except Exception as e:
logger.error(f"{self.log_prefix} 获取发送者信息时出错: {e}")
return False, f"获取发送者信息失败: {str(e)}"
@@ -224,59 +250,65 @@ class SenderInfoCommand(BaseCommand):
@register_command
class ChatStreamInfoCommand(BaseCommand):
"""聊天流信息命令"""
command_name = "streaminfo"
command_description = "查看当前聊天流的详细信息"
command_pattern = r"^/streaminfo$"
command_help = "使用方法: /streaminfo - 查看当前聊天流信息"
command_examples = ["/streaminfo"]
enable_command = True
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行聊天流信息查看命令"""
try:
if not hasattr(self.message, 'chat_stream') or not self.message.chat_stream:
if not hasattr(self.message, "chat_stream") or not self.message.chat_stream:
return False, "无法获取聊天流信息"
chat_stream = self.message.chat_stream
info_lines = [
"🔄 聊天流信息",
f"🆔 流ID: {chat_stream.stream_id}",
f"🌐 平台: {chat_stream.platform}",
f"⚡ 状态: {'激活' if chat_stream.is_active else '非激活'}",
]
# 用户信息
if chat_stream.user_info:
info_lines.extend([
"",
"👤 关联用户:",
f" ID: {chat_stream.user_info.user_id}",
f" 昵称: {chat_stream.user_info.user_nickname}",
])
info_lines.extend(
[
"",
"👤 关联用户:",
f" ID: {chat_stream.user_info.user_id}",
f" 昵称: {chat_stream.user_info.user_nickname}",
]
)
# 群信息
if chat_stream.group_info:
info_lines.extend([
"",
"👥 关联群聊:",
f" 群ID: {chat_stream.group_info.group_id}",
f": {chat_stream.group_info.group_name or '未知'}",
])
info_lines.extend(
[
"",
"👥 关联群聊:",
f"ID: {chat_stream.group_info.group_id}",
f" 群名: {chat_stream.group_info.group_name or '未知'}",
]
)
else:
info_lines.append("\n💬 类型: 私聊流")
# 最近消息统计
if hasattr(chat_stream, 'last_messages'):
if hasattr(chat_stream, "last_messages"):
msg_count = len(chat_stream.last_messages)
info_lines.extend([
"",
f"📈 消息统计: 记录了 {msg_count} 条最近消息",
])
info_lines.extend(
[
"",
f"📈 消息统计: 记录了 {msg_count} 条最近消息",
]
)
return True, "\n".join(info_lines)
except Exception as e:
logger.error(f"{self.log_prefix} 获取聊天流信息时出错: {e}")
return False, f"获取聊天流信息失败: {str(e)}"
return False, f"获取聊天流信息失败: {str(e)}"

View File

@@ -5,43 +5,41 @@ from typing import Tuple, Optional
logger = get_logger("send_msg_command")
@register_command
class SendMessageCommand(BaseCommand, MessageAPI):
"""发送消息命令,可以向指定群聊或私聊发送消息"""
command_name = "send"
command_description = "向指定群聊或私聊发送消息"
command_pattern = r"^/send\s+(?P<target_type>group|user)\s+(?P<target_id>\d+)\s+(?P<content>.+)$"
command_help = "使用方法: /send <group|user> <ID> <消息内容> - 发送消息到指定群聊或用户"
command_examples = [
"/send group 123456789 大家好!",
"/send user 987654321 私聊消息"
]
command_examples = ["/send group 123456789 大家好!", "/send user 987654321 私聊消息"]
enable_command = True
def __init__(self, message):
super().__init__(message)
# 初始化MessageAPI需要的服务虽然这里不会用到但保持一致性
self._services = {}
self.log_prefix = f"[Command:{self.command_name}]"
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行发送消息命令
Returns:
Tuple[bool, Optional[str]]: (是否执行成功, 回复消息)
"""
try:
# 获取匹配到的参数
target_type = self.matched_groups.get("target_type") # group 或 user
target_id = self.matched_groups.get("target_id") # 群ID或用户ID
content = self.matched_groups.get("content") # 消息内容
target_id = self.matched_groups.get("target_id") # 群ID或用户ID
content = self.matched_groups.get("content") # 消息内容
if not all([target_type, target_id, content]):
return False, "命令参数不完整,请检查格式"
logger.info(f"{self.log_prefix} 执行发送消息命令: {target_type}:{target_id} -> {content[:50]}...")
# 根据目标类型调用不同的发送方法
if target_type == "group":
success = await self._send_to_group(target_id, content)
@@ -51,24 +49,24 @@ class SendMessageCommand(BaseCommand, MessageAPI):
target_desc = f"用户 {target_id}"
else:
return False, f"不支持的目标类型: {target_type},只支持 group 或 user"
# 返回执行结果
if success:
return True, f"✅ 消息已成功发送到 {target_desc}"
else:
return False, f"❌ 消息发送失败,可能是目标 {target_desc} 不存在或没有权限"
except Exception as e:
logger.error(f"{self.log_prefix} 执行发送消息命令时出错: {e}")
return False, f"命令执行出错: {str(e)}"
async def _send_to_group(self, group_id: str, content: str) -> bool:
"""发送消息到群聊
Args:
group_id: 群聊ID
content: 消息内容
Returns:
bool: 是否发送成功
"""
@@ -76,27 +74,27 @@ class SendMessageCommand(BaseCommand, MessageAPI):
success = await self.send_text_to_group(
text=content,
group_id=group_id,
platform="qq" # 默认使用QQ平台
platform="qq", # 默认使用QQ平台
)
if success:
logger.info(f"{self.log_prefix} 成功发送消息到群聊 {group_id}")
else:
logger.warning(f"{self.log_prefix} 发送消息到群聊 {group_id} 失败")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送群聊消息时出错: {e}")
return False
async def _send_to_user(self, user_id: str, content: str) -> bool:
"""发送消息到私聊
Args:
user_id: 用户ID
content: 消息内容
Returns:
bool: 是否发送成功
"""
@@ -104,16 +102,16 @@ class SendMessageCommand(BaseCommand, MessageAPI):
success = await self.send_text_to_user(
text=content,
user_id=user_id,
platform="qq" # 默认使用QQ平台
platform="qq", # 默认使用QQ平台
)
if success:
logger.info(f"{self.log_prefix} 成功发送消息到用户 {user_id}")
else:
logger.warning(f"{self.log_prefix} 发送消息到用户 {user_id} 失败")
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送私聊消息时出错: {e}")
return False
return False

View File

@@ -5,10 +5,11 @@ from typing import Tuple, Optional
logger = get_logger("send_msg_enhanced")
@register_command
class SendMessageEnhancedCommand(BaseCommand, MessageAPI):
"""增强版发送消息命令,支持多种消息类型和平台"""
command_name = "sendfull"
command_description = "增强版消息发送命令,支持多种类型和平台"
command_pattern = r"^/sendfull\s+(?P<msg_type>text|image|emoji)\s+(?P<target_type>group|user)\s+(?P<target_id>\d+)(?:\s+(?P<platform>\w+))?\s+(?P<content>.+)$"
@@ -17,108 +18,93 @@ class SendMessageEnhancedCommand(BaseCommand, MessageAPI):
"/sendfull text group 123456789 qq 大家好!这是文本消息",
"/sendfull image user 987654321 https://example.com/image.jpg",
"/sendfull emoji group 123456789 😄",
"/sendfull text user 987654321 qq 私聊消息"
"/sendfull text user 987654321 qq 私聊消息",
]
enable_command = True
def __init__(self, message):
super().__init__(message)
self._services = {}
self.log_prefix = f"[Command:{self.command_name}]"
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行增强版发送消息命令"""
try:
# 获取匹配参数
msg_type = self.matched_groups.get("msg_type") # 消息类型: text/image/emoji
target_type = self.matched_groups.get("target_type") # 目标类型: group/user
target_id = self.matched_groups.get("target_id") # 目标ID
platform = self.matched_groups.get("platform") or "qq" # 平台默认qq
content = self.matched_groups.get("content") # 内容
msg_type = self.matched_groups.get("msg_type") # 消息类型: text/image/emoji
target_type = self.matched_groups.get("target_type") # 目标类型: group/user
target_id = self.matched_groups.get("target_id") # 目标ID
platform = self.matched_groups.get("platform") or "qq" # 平台默认qq
content = self.matched_groups.get("content") # 内容
if not all([msg_type, target_type, target_id, content]):
return False, "命令参数不完整,请检查格式"
# 验证消息类型
valid_types = ["text", "image", "emoji"]
if msg_type not in valid_types:
return False, f"不支持的消息类型: {msg_type},支持的类型: {', '.join(valid_types)}"
# 验证目标类型
if target_type not in ["group", "user"]:
return False, "目标类型只能是 group 或 user"
logger.info(f"{self.log_prefix} 执行发送命令: {msg_type} -> {target_type}:{target_id} (平台:{platform})")
# 根据消息类型和目标类型发送消息
is_group = (target_type == "group")
is_group = target_type == "group"
success = await self.send_message_to_target(
message_type=msg_type,
content=content,
platform=platform,
target_id=target_id,
is_group=is_group
message_type=msg_type, content=content, platform=platform, target_id=target_id, is_group=is_group
)
# 构建结果消息
target_desc = f"{'群聊' if is_group else '用户'} {target_id} (平台: {platform})"
msg_type_desc = {
"text": "文本",
"image": "图片",
"emoji": "表情"
}.get(msg_type, msg_type)
msg_type_desc = {"text": "文本", "image": "图片", "emoji": "表情"}.get(msg_type, msg_type)
if success:
return True, f"{msg_type_desc}消息已成功发送到 {target_desc}"
else:
return False, f"{msg_type_desc}消息发送失败,可能是目标 {target_desc} 不存在或没有权限"
except Exception as e:
logger.error(f"{self.log_prefix} 执行增强发送命令时出错: {e}")
return False, f"命令执行出错: {str(e)}"
@register_command
@register_command
class SendQuickCommand(BaseCommand, MessageAPI):
"""快速发送文本消息命令"""
command_name = "msg"
command_description = "快速发送文本消息到群聊"
command_pattern = r"^/msg\s+(?P<group_id>\d+)\s+(?P<content>.+)$"
command_help = "使用方法: /msg <群ID> <消息内容> - 快速发送文本到指定群聊"
command_examples = [
"/msg 123456789 大家好!",
"/msg 987654321 这是一条快速消息"
]
command_examples = ["/msg 123456789 大家好!", "/msg 987654321 这是一条快速消息"]
enable_command = True
def __init__(self, message):
super().__init__(message)
self._services = {}
self.log_prefix = f"[Command:{self.command_name}]"
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行快速发送消息命令"""
try:
group_id = self.matched_groups.get("group_id")
content = self.matched_groups.get("content")
if not all([group_id, content]):
return False, "命令参数不完整"
logger.info(f"{self.log_prefix} 快速发送到群 {group_id}: {content[:50]}...")
success = await self.send_text_to_group(
text=content,
group_id=group_id,
platform="qq"
)
success = await self.send_text_to_group(text=content, group_id=group_id, platform="qq")
if success:
return True, f"✅ 消息已发送到群 {group_id}"
else:
return False, f"❌ 发送到群 {group_id} 失败"
except Exception as e:
logger.error(f"{self.log_prefix} 快速发送命令出错: {e}")
return False, f"发送失败: {str(e)}"
@@ -127,44 +113,37 @@ class SendQuickCommand(BaseCommand, MessageAPI):
@register_command
class SendPrivateCommand(BaseCommand, MessageAPI):
"""发送私聊消息命令"""
command_name = "pm"
command_description = "发送私聊消息到指定用户"
command_pattern = r"^/pm\s+(?P<user_id>\d+)\s+(?P<content>.+)$"
command_help = "使用方法: /pm <用户ID> <消息内容> - 发送私聊消息"
command_examples = [
"/pm 123456789 你好!",
"/pm 987654321 这是私聊消息"
]
command_examples = ["/pm 123456789 你好!", "/pm 987654321 这是私聊消息"]
enable_command = True
def __init__(self, message):
super().__init__(message)
self._services = {}
self.log_prefix = f"[Command:{self.command_name}]"
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行私聊发送命令"""
try:
user_id = self.matched_groups.get("user_id")
content = self.matched_groups.get("content")
if not all([user_id, content]):
return False, "命令参数不完整"
logger.info(f"{self.log_prefix} 发送私聊到用户 {user_id}: {content[:50]}...")
success = await self.send_text_to_user(
text=content,
user_id=user_id,
platform="qq"
)
success = await self.send_text_to_user(text=content, user_id=user_id, platform="qq")
if success:
return True, f"✅ 私聊消息已发送到用户 {user_id}"
else:
return False, f"❌ 发送私聊到用户 {user_id} 失败"
except Exception as e:
logger.error(f"{self.log_prefix} 私聊发送命令出错: {e}")
return False, f"私聊发送失败: {str(e)}"
return False, f"私聊发送失败: {str(e)}"

View File

@@ -6,173 +6,153 @@ import time
logger = get_logger("send_msg_with_context")
@register_command
class ContextAwareSendCommand(BaseCommand, MessageAPI):
"""上下文感知的发送消息命令,展示如何利用原始消息信息"""
command_name = "csend"
command_description = "带上下文感知的发送消息命令"
command_pattern = r"^/csend\s+(?P<target_type>group|user|here|reply)\s+(?P<target_id_or_content>.*?)(?:\s+(?P<content>.*))?$"
command_pattern = (
r"^/csend\s+(?P<target_type>group|user|here|reply)\s+(?P<target_id_or_content>.*?)(?:\s+(?P<content>.*))?$"
)
command_help = "使用方法: /csend <target_type> <参数> [内容]"
command_examples = [
"/csend group 123456789 大家好!",
"/csend user 987654321 私聊消息",
"/csend user 987654321 私聊消息",
"/csend here 在当前聊天发送",
"/csend reply 回复当前群/私聊"
"/csend reply 回复当前群/私聊",
]
enable_command = True
# 管理员用户ID列表示例
ADMIN_USERS = ["123456789", "987654321"] # 可以从配置文件读取
def __init__(self, message):
super().__init__(message)
self._services = {}
self.log_prefix = f"[Command:{self.command_name}]"
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行上下文感知的发送命令"""
try:
# 获取命令发送者信息
sender = self.message.message_info.user_info
current_group = self.message.message_info.group_info
# 权限检查
if not self._check_permission(sender.user_id):
return False, f"❌ 权限不足,只有管理员可以使用此命令\n你的ID: {sender.user_id}"
# 解析命令参数
target_type = self.matched_groups.get("target_type")
target_id_or_content = self.matched_groups.get("target_id_or_content", "")
content = self.matched_groups.get("content", "")
# 根据目标类型处理不同情况
if target_type == "here":
# 发送到当前聊天
return await self._send_to_current_chat(target_id_or_content, sender, current_group)
elif target_type == "reply":
# 回复到当前聊天,带发送者信息
return await self._send_reply_with_context(target_id_or_content, sender, current_group)
elif target_type in ["group", "user"]:
# 发送到指定目标
if not content:
return False, "指定群聊或用户时需要提供消息内容"
return await self._send_to_target(target_type, target_id_or_content, content, sender)
else:
return False, f"不支持的目标类型: {target_type}"
except Exception as e:
logger.error(f"{self.log_prefix} 执行上下文感知发送命令时出错: {e}")
return False, f"命令执行出错: {str(e)}"
def _check_permission(self, user_id: str) -> bool:
"""检查用户权限"""
return user_id in self.ADMIN_USERS
async def _send_to_current_chat(self, content: str, sender, current_group) -> Tuple[bool, str]:
"""发送到当前聊天"""
if not content:
return False, "消息内容不能为空"
# 构建带发送者信息的消息
timestamp = time.strftime("%H:%M:%S", time.localtime())
if current_group:
# 群聊
formatted_content = f"[管理员转发 {timestamp}] {sender.user_nickname}({sender.user_id}): {content}"
success = await self.send_text_to_group(
text=formatted_content, group_id=current_group.group_id, platform="qq"
)
target_desc = f"当前群聊 {current_group.group_name}({current_group.group_id})"
else:
# 私聊
formatted_content = f"[管理员消息 {timestamp}]: {content}"
success = await self.send_text_to_user(text=formatted_content, user_id=sender.user_id, platform="qq")
target_desc = "当前私聊"
if success:
return True, f"✅ 消息已发送到{target_desc}"
else:
return False, f"❌ 发送到{target_desc}失败"
async def _send_reply_with_context(self, content: str, sender, current_group) -> Tuple[bool, str]:
"""发送回复,带完整上下文信息"""
if not content:
return False, "回复内容不能为空"
# 获取当前时间和环境信息
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
# 构建上下文信息
context_info = [
f"📢 管理员回复 [{timestamp}]",
f"👤 发送者: {sender.user_nickname}({sender.user_id})",
]
if current_group:
context_info.append(f"👥 当前群聊: {current_group.group_name}({current_group.group_id})")
target_desc = f"群聊 {current_group.group_name}"
else:
context_info.append("💬 当前环境: 私聊")
target_desc = "私聊"
context_info.extend([f"📝 回复内容: {content}", "" * 30])
formatted_content = "\n".join(context_info)
# 发送消息
if current_group:
success = await self.send_text_to_group(
text=formatted_content, group_id=current_group.group_id, platform="qq"
)
else:
success = await self.send_text_to_user(text=formatted_content, user_id=sender.user_id, platform="qq")
if success:
return True, f"✅ 带上下文的回复已发送到{target_desc}"
else:
return False, f"❌ 发送上下文回复到{target_desc}失败"
async def _send_to_target(self, target_type: str, target_id: str, content: str, sender) -> Tuple[bool, str]:
"""发送到指定目标,带发送者追踪信息"""
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
# 构建带追踪信息的消息
tracking_info = f"[管理转发 {timestamp}] 来自 {sender.user_nickname}({sender.user_id})"
formatted_content = f"{tracking_info}\n{content}"
if target_type == "group":
success = await self.send_text_to_group(text=formatted_content, group_id=target_id, platform="qq")
target_desc = f"群聊 {target_id}"
else: # user
success = await self.send_text_to_user(text=formatted_content, user_id=target_id, platform="qq")
target_desc = f"用户 {target_id}"
if success:
return True, f"✅ 带追踪信息的消息已发送到{target_desc}"
else:
@@ -182,21 +162,21 @@ class ContextAwareSendCommand(BaseCommand, MessageAPI):
@register_command
class MessageContextCommand(BaseCommand):
"""消息上下文命令,展示如何获取和利用上下文信息"""
command_name = "context"
command_description = "显示当前消息的完整上下文信息"
command_pattern = r"^/context$"
command_help = "使用方法: /context - 显示当前环境的上下文信息"
command_examples = ["/context"]
enable_command = True
async def execute(self) -> Tuple[bool, Optional[str]]:
"""显示上下文信息"""
try:
message = self.message
user = message.message_info.user_info
group = message.message_info.group_info
# 构建上下文信息
context_lines = [
"🌐 当前上下文信息",
@@ -212,42 +192,50 @@ class MessageContextCommand(BaseCommand):
f" 群名片: {user.user_cardname or ''}",
f" 平台: {user.platform}",
]
if group:
context_lines.extend(
[
"",
"👥 群聊环境:",
f" ID: {group.group_id}",
f" 群名: {group.group_name or '未知'}",
f" 平台: {group.platform}",
]
)
else:
context_lines.extend(
[
"",
"💬 私聊环境",
]
)
# 添加聊天流信息
if hasattr(message, "chat_stream") and message.chat_stream:
chat_stream = message.chat_stream
context_lines.extend(
[
"",
"🔄 聊天流:",
f" 流ID: {chat_stream.stream_id}",
f" 激活状态: {'激活' if chat_stream.is_active else '非激活'}",
]
)
# 添加消息内容信息
context_lines.extend(
[
"",
"📝 消息内容:",
f" 原始内容: {message.processed_plain_text}",
f" 消息长度: {len(message.processed_plain_text)} 字符",
f" 消息ID: {message.message_info.message_id}",
]
)
return True, "\n".join(context_lines)
except Exception as e:
logger.error(f"{self.log_prefix} 获取上下文信息时出错: {e}")
return False, f"获取上下文失败: {str(e)}"
return False, f"获取上下文失败: {str(e)}"

View File

@@ -1,2 +1,3 @@
"""测试插件动作模块"""
from . import mute_action # noqa

View File

@@ -22,21 +22,20 @@ class MuteAction(PluginAction):
"当有人刷屏时使用",
"当有人发了擦边,或者色情内容时使用",
"当有人要求禁言自己时使用",
"如果某人已经被禁言了,就不要再次禁言了,除非你想追加时间!!"
"如果某人已经被禁言了,就不要再次禁言了,除非你想追加时间!!",
]
enable_plugin = False # 是否启用插件(此处默认不启用)
associated_types = ["command", "text"]
action_config_file_name = "mute_action_config.toml"
# 激活类型设置
focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定确保谨慎
normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活快速响应
# 关键词设置用于Normal模式
activation_keywords = ["禁言", "mute", "ban", "silence"]
keyword_case_sensitive = False
# LLM判定提示词用于Focus模式
llm_judge_prompt = """
判定是否需要使用禁言动作的严格条件:
@@ -59,13 +58,13 @@ class MuteAction(PluginAction):
注意:禁言是严厉措施,只在明确违规或用户主动要求时使用。
宁可保守也不要误判,保护用户的发言权利。
"""
# Random激活概率备用
random_activation_probability = 0.05 # 设置很低的概率作为兜底
# 模式启用设置 - 禁言功能在所有模式下都可用
mode_enable = ChatMode.ALL
# 并行执行设置 - 禁言动作可以与回复并行执行,不覆盖回复内容
parallel_action = False
@@ -73,15 +72,15 @@ class MuteAction(PluginAction):
super().__init__(*args, **kwargs)
# 生成配置文件(如果不存在)
self._generate_config_if_needed()
def _generate_config_if_needed(self):
"""生成配置文件(如果不存在)"""
import os
# 获取动作文件所在目录
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "mute_action_config.toml")
if not os.path.exists(config_path):
config_content = """\
# 禁言动作配置文件
@@ -130,11 +129,10 @@ log_mute_history = true
def _get_template_message(self, target: str, duration_str: str, reason: str) -> str:
"""获取模板化的禁言消息"""
templates = self.config.get("templates", ["好的,禁言 {target} {duration},理由:{reason}"])
import random
template = random.choice(templates)
return template.format(target=target, duration=duration_str, reason=reason)
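# 编辑者补充的假设性示意(非本提交内容):与上面 self.config.get(...) 的读取逻辑对应,
# mute_action_config.toml 中的相关配置项推测大致形如:
#   templates = ["好的,禁言 {target} {duration},理由:{reason}"]
#   error_messages = [..., ..., "禁言时长必须是正数哦~", "禁言时长必须是数字哦~"]  # 下标 2、3 对应后文兜底文案
#   enable_duration_formatting = true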
@@ -162,7 +160,7 @@ log_mute_history = true
# 获取时长限制配置
min_duration, max_duration, default_duration = self._get_duration_limits()
# 验证时长格式并转换
try:
duration_int = int(duration)
@@ -170,9 +168,11 @@ log_mute_history = true
error_msg = "禁言时长必须大于0"
logger.error(f"{self.log_prefix} {error_msg}")
error_templates = self.config.get("error_messages", ["禁言时长必须是正数哦~"])
await self.send_message_by_expressor(
error_templates[2] if len(error_templates) > 2 else "禁言时长必须是正数哦~"
)
return False, error_msg
# 限制禁言时长范围
if duration_int < min_duration:
duration_int = min_duration
@@ -180,12 +180,14 @@ log_mute_history = true
elif duration_int > max_duration:
duration_int = max_duration
logger.info(f"{self.log_prefix} 禁言时长过长,调整为{max_duration}")
except (ValueError, TypeError):
error_msg = f"禁言时长格式无效: {duration}"
logger.error(f"{self.log_prefix} {error_msg}")
error_templates = self.config.get("error_messages", ["禁言时长必须是数字哦~"])
await self.send_message_by_expressor(
error_templates[3] if len(error_templates) > 3 else "禁言时长必须是数字哦~"
)
return False, error_msg
# 获取用户ID
@@ -206,7 +208,7 @@ log_mute_history = true
# 发送表达情绪的消息
enable_formatting = self.config.get("enable_duration_formatting", True)
time_str = self._format_duration(duration_int) if enable_formatting else f"{duration_int}"
# 使用模板化消息
message = self._get_template_message(target, time_str, reason)
await self.send_message_by_expressor(message)
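# —— 编辑者补充的假设性示例(非本提交代码)——
# 上文用 self._format_duration(duration_int) 把秒数转成易读文案,其实现不在本次 diff 中,
# 下面给出一种可能的写法,仅供理解:
def _format_duration_sketch(seconds: int) -> str:
    """把秒数粗略格式化为 "X小时Y分钟Z秒" 形式(示意实现)"""
    hours, rem = divmod(seconds, 3600)
    minutes, secs = divmod(rem, 60)
    parts = []
    if hours:
        parts.append(f"{hours}小时")
    if minutes:
        parts.append(f"{minutes}分钟")
    if secs or not parts:
        parts.append(f"{secs}秒")
    return "".join(parts)


print(_format_duration_sketch(3661))  # 1小时1分钟1秒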

View File

@@ -1,7 +1,7 @@
import importlib
import pkgutil
import os
from typing import Dict, Tuple
from src.common.logger_manager import get_logger
logger = get_logger("plugin_loader")
@@ -9,52 +9,53 @@ logger = get_logger("plugin_loader")
class PluginLoader:
"""统一的插件加载器负责加载插件的所有组件actions、commands等"""
def __init__(self):
self.loaded_actions = 0
self.loaded_commands = 0
self.plugin_stats: Dict[str, Dict[str, int]] = {} # 统计每个插件加载的组件数量
self.plugin_sources: Dict[str, str] = {} # 记录每个插件来自哪个路径
def load_all_plugins(self) -> Tuple[int, int]:
"""加载所有插件的所有组件
Returns:
Tuple[int, int]: (加载的动作数量, 加载的命令数量)
"""
# 定义插件搜索路径(优先级从高到低)
plugin_paths = [
("plugins", "plugins"), # 项目根目录的plugins文件夹
("src.plugins", os.path.join("src", "plugins")) # src下的plugins文件夹
("src.plugins", os.path.join("src", "plugins")), # src下的plugins文件夹
]
total_plugins_found = 0
for plugin_import_path, plugin_dir_path in plugin_paths:
try:
plugins_loaded = self._load_plugins_from_path(plugin_import_path, plugin_dir_path)
total_plugins_found += plugins_loaded
except Exception as e:
logger.error(f"从路径 {plugin_dir_path} 加载插件失败: {e}")
import traceback
logger.error(traceback.format_exc())
if total_plugins_found == 0:
logger.info("未找到任何插件目录或插件")
# 输出加载统计
self._log_loading_stats()
return self.loaded_actions, self.loaded_commands
def _load_plugins_from_path(self, plugin_import_path: str, plugin_dir_path: str) -> int:
"""从指定路径加载插件
Args:
plugin_import_path: 插件的导入路径 (如 "plugins" 或 "src.plugins")
plugin_dir_path: 插件目录的文件系统路径
Returns:
int: 找到的插件包数量
"""
@@ -62,9 +63,9 @@ class PluginLoader:
if not os.path.exists(plugin_dir_path):
logger.debug(f"插件目录 {plugin_dir_path} 不存在,跳过")
return 0
logger.info(f"正在从 {plugin_dir_path} 加载插件...")
# 导入插件包
try:
plugins_package = importlib.import_module(plugin_import_path)
@@ -72,122 +73,120 @@ class PluginLoader:
except ImportError as e:
logger.warning(f"导入插件包 {plugin_import_path} 失败: {e}")
return 0
# 遍历插件包中的所有子包
plugins_found = 0
for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + "."):
if not is_pkg:
continue
logger.debug(f"检测到插件: {plugin_name}")
# 记录插件来源
self.plugin_sources[plugin_name] = plugin_dir_path
self._load_single_plugin(plugin_name)
plugins_found += 1
if plugins_found > 0:
logger.info(f"{plugin_dir_path} 找到 {plugins_found} 个插件包")
else:
logger.debug(f"{plugin_dir_path} 未找到任何插件包")
return plugins_found
def _load_single_plugin(self, plugin_name: str) -> None:
"""加载单个插件的所有组件
Args:
plugin_name: 插件名称
"""
plugin_stats = {"actions": 0, "commands": 0}
# 加载动作组件
actions_count = self._load_plugin_actions(plugin_name)
plugin_stats["actions"] = actions_count
self.loaded_actions += actions_count
# 加载命令组件
commands_count = self._load_plugin_commands(plugin_name)
plugin_stats["commands"] = commands_count
self.loaded_commands += commands_count
# 记录插件统计信息
if actions_count > 0 or commands_count > 0:
self.plugin_stats[plugin_name] = plugin_stats
logger.info(f"插件 {plugin_name} 加载完成: {actions_count} 个动作, {commands_count} 个命令")
def _load_plugin_actions(self, plugin_name: str) -> int:
"""加载插件的动作组件
Args:
plugin_name: 插件名称
Returns:
int: 加载的动作数量
"""
loaded_count = 0
# 优先检查插件是否有actions子包
plugin_actions_path = f"{plugin_name}.actions"
plugin_actions_dir = plugin_name.replace(".", os.path.sep) + os.path.sep + "actions"
actions_loaded_from_subdir = False
# 首先尝试从actions子目录加载
if os.path.exists(plugin_actions_dir):
loaded_count += self._load_from_actions_subdir(plugin_name, plugin_actions_path, plugin_actions_dir)
if loaded_count > 0:
actions_loaded_from_subdir = True
# 如果actions子目录不存在或加载失败尝试从插件根目录加载
if not actions_loaded_from_subdir:
loaded_count += self._load_actions_from_root_dir(plugin_name)
return loaded_count
def _load_plugin_commands(self, plugin_name: str) -> int:
"""加载插件的命令组件
Args:
plugin_name: 插件名称
Returns:
int: 加载的命令数量
"""
loaded_count = 0
# 优先检查插件是否有commands子包
plugin_commands_path = f"{plugin_name}.commands"
plugin_commands_dir = plugin_name.replace(".", os.path.sep) + os.path.sep + "commands"
commands_loaded_from_subdir = False
# 首先尝试从commands子目录加载
if os.path.exists(plugin_commands_dir):
loaded_count += self._load_from_commands_subdir(plugin_name, plugin_commands_path, plugin_commands_dir)
if loaded_count > 0:
commands_loaded_from_subdir = True
# 如果commands子目录不存在或加载失败尝试从插件根目录加载
if not commands_loaded_from_subdir:
loaded_count += self._load_commands_from_root_dir(plugin_name)
return loaded_count
def _load_from_actions_subdir(self, plugin_name: str, plugin_actions_path: str, plugin_actions_dir: str) -> int:
"""从actions子目录加载动作"""
loaded_count = 0
try:
# 尝试导入插件的actions包
actions_module = importlib.import_module(plugin_actions_path)
logger.debug(f"成功加载插件动作模块: {plugin_actions_path}")
# 遍历actions目录中的所有Python文件
actions_dir = os.path.dirname(actions_module.__file__)
for file in os.listdir(actions_dir):
if file.endswith(".py") and file != "__init__.py":
action_module_name = f"{plugin_actions_path}.{file[:-3]}"
try:
importlib.import_module(action_module_name)
@@ -195,25 +194,25 @@ class PluginLoader:
loaded_count += 1
except Exception as e:
logger.error(f"加载动作失败: {action_module_name}, 错误: {e}")
except ImportError as e:
logger.debug(f"插件 {plugin_name} 的actions子包导入失败: {e}")
return loaded_count
def _load_from_commands_subdir(self, plugin_name: str, plugin_commands_path: str, plugin_commands_dir: str) -> int:
"""从commands子目录加载命令"""
loaded_count = 0
try:
# 尝试导入插件的commands包
commands_module = importlib.import_module(plugin_commands_path)
logger.debug(f"成功加载插件命令模块: {plugin_commands_path}")
# 遍历commands目录中的所有Python文件
commands_dir = os.path.dirname(commands_module.__file__)
for file in os.listdir(commands_dir):
if file.endswith(".py") and file != "__init__.py":
command_module_name = f"{plugin_commands_path}.{file[:-3]}"
try:
importlib.import_module(command_module_name)
@@ -221,29 +220,29 @@ class PluginLoader:
loaded_count += 1
except Exception as e:
logger.error(f"加载命令失败: {command_module_name}, 错误: {e}")
except ImportError as e:
logger.debug(f"插件 {plugin_name} 的commands子包导入失败: {e}")
return loaded_count
def _load_actions_from_root_dir(self, plugin_name: str) -> int:
"""从插件根目录加载动作文件"""
loaded_count = 0
try:
# 导入插件包本身
plugin_module = importlib.import_module(plugin_name)
logger.debug(f"尝试从插件根目录加载动作: {plugin_name}")
# 遍历插件根目录中的所有Python文件
plugin_dir = os.path.dirname(plugin_module.__file__)
for file in os.listdir(plugin_dir):
if file.endswith(".py") and file != "__init__.py":
# 跳过非动作文件(根据命名约定)
if not (file.endswith("_action.py") or file.endswith("_actions.py") or "action" in file):
continue
action_module_name = f"{plugin_name}.{file[:-3]}"
try:
importlib.import_module(action_module_name)
@@ -251,29 +250,29 @@ class PluginLoader:
loaded_count += 1
except Exception as e:
logger.error(f"加载动作失败: {action_module_name}, 错误: {e}")
except ImportError as e:
logger.debug(f"插件 {plugin_name} 导入失败: {e}")
return loaded_count
def _load_commands_from_root_dir(self, plugin_name: str) -> int:
"""从插件根目录加载命令文件"""
loaded_count = 0
try:
# 导入插件包本身
plugin_module = importlib.import_module(plugin_name)
logger.debug(f"尝试从插件根目录加载命令: {plugin_name}")
# 遍历插件根目录中的所有Python文件
plugin_dir = os.path.dirname(plugin_module.__file__)
for file in os.listdir(plugin_dir):
if file.endswith(".py") and file != "__init__.py":
# 跳过非命令文件(根据命名约定)
if not (file.endswith("_command.py") or file.endswith("_commands.py") or "command" in file):
continue
command_module_name = f"{plugin_name}.{file[:-3]}"
try:
importlib.import_module(command_module_name)
@@ -281,23 +280,25 @@ class PluginLoader:
loaded_count += 1
except Exception as e:
logger.error(f"加载命令失败: {command_module_name}, 错误: {e}")
except ImportError as e:
logger.debug(f"插件 {plugin_name} 导入失败: {e}")
return loaded_count
def _log_loading_stats(self) -> None:
"""输出加载统计信息"""
logger.success(f"插件加载完成: 总计 {self.loaded_actions} 个动作, {self.loaded_commands} 个命令")
if self.plugin_stats:
logger.info("插件加载详情:")
for plugin_name, stats in self.plugin_stats.items():
plugin_display_name = plugin_name.split(".")[-1] # 只显示插件名称,不显示完整路径
source_path = self.plugin_sources.get(plugin_name, "未知路径")
logger.info(f" {plugin_display_name} (来源: {source_path}): {stats['actions']} 动作, {stats['commands']} 命令")
logger.info(
f" {plugin_display_name} (来源: {source_path}): {stats['actions']} 动作, {stats['commands']} 命令"
)
# 创建全局插件加载器实例
plugin_loader = PluginLoader()
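# 编辑者补充的假设性用法示意(非本提交代码):全局实例 plugin_loader 通常在系统启动时
# 调用一次,返回值即 load_all_plugins() 文档字符串中描述的 (动作数量, 命令数量)。
# 导入路径仅为示意,以该模块在项目中的实际位置为准:
#   from src.plugins.plugin_loader import plugin_loader
#   actions_count, commands_count = plugin_loader.load_all_plugins()
#   logger.info(f"启动时共加载 {actions_count} 个动作、{commands_count} 个命令")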

View File

@@ -23,14 +23,14 @@ class TTSAction(PluginAction):
]
enable_plugin = True # 启用插件
associated_types = ["tts_text"]
focus_activation_type = ActionActivationType.LLM_JUDGE
normal_activation_type = ActionActivationType.KEYWORD
# 关键词配置 - Normal模式下使用关键词触发
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "", "朗读"]
keyword_case_sensitive = False
# 并行执行设置 - TTS可以与回复并行执行不覆盖回复内容
parallel_action = False

View File

@@ -22,11 +22,11 @@ class VTBAction(PluginAction):
]
enable_plugin = True # 启用插件
associated_types = ["vtb_text"]
# 激活类型设置
focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定精确识别情感表达需求
normal_activation_type = ActionActivationType.RANDOM # Normal模式使用随机激活增加趣味性
# LLM判定提示词用于Focus模式
llm_judge_prompt = """
判定是否需要使用VTB虚拟主播动作的条件
@@ -41,7 +41,7 @@ class VTBAction(PluginAction):
3. 不涉及情感的日常对话
4. 已经有足够的情感表达
"""
# Random激活概率用于Normal模式
random_activation_probability = 0.08 # 较低概率,避免过度使用
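# —— 编辑者补充的假设性示例(非本提交代码)——
# RANDOM 激活类型大致相当于 planner 在 Normal 模式下按概率抽签,决定是否把该动作列入候选;
# 示意逻辑如下(真实判定以 planner 实现为准):
import random


def _should_activate_sketch(probability: float) -> bool:
    """以给定概率返回 True,模拟 random_activation_probability 的作用(示意)"""
    return random.random() < probability


print(_should_activate_sketch(0.08))  # 约 8% 的调用会返回 True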