ruff
@@ -8,6 +8,7 @@ logger = get_logger("base_action")
_ACTION_REGISTRY: Dict[str, Type["BaseAction"]] = {}
_DEFAULT_ACTIONS: Dict[str, str] = {}

# 动作激活类型枚举
class ActionActivationType:
ALWAYS = "always" # 默认参与到planner

@@ -15,12 +16,14 @@ class ActionActivationType:
RANDOM = "random" # 随机启用action到planner
KEYWORD = "keyword" # 关键词触发启用action到planner

# 聊天模式枚举
class ChatMode:
FOCUS = "focus" # Focus聊天模式
NORMAL = "normal" # Normal聊天模式
ALL = "all" # 所有聊天模式

def register_action(cls):
"""
动作注册装饰器

@@ -22,9 +22,7 @@ class EmojiAction(BaseAction):
action_parameters: dict[str:str] = {
"description": "文字描述你想要发送的表情包内容",
}
action_require: list[str] = [
"表达情绪时可以选择使用",
"重点:不要连续发,如果你已经发过[表情包],就不要选择此动作"]
action_require: list[str] = ["表达情绪时可以选择使用", "重点:不要连续发,如果你已经发过[表情包],就不要选择此动作"]

associated_types: list[str] = ["emoji"]

@@ -37,7 +35,6 @@ class EmojiAction(BaseAction):
parallel_action = True

llm_judge_prompt = """
判定是否需要使用表情动作的条件:
1. 用户明确要求使用表情包

@@ -31,7 +31,7 @@ class ReplyAction(BaseAction):
action_require: list[str] = [
"你想要闲聊或者随便附和",
"有人提到你",
"如果你刚刚进行了回复,不要对同一个话题重复回应"
"如果你刚刚进行了回复,不要对同一个话题重复回应",
]

associated_types: list[str] = ["text"]

@@ -122,7 +122,7 @@ class ReplyAction(BaseAction):
target = ""
if ":" in reply_to or ":" in reply_to:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1)
parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1)
if len(parts) == 2:
# sender = parts[0].strip()
target = parts[1].strip()

@@ -138,7 +138,6 @@ class ReplyAction(BaseAction):
self.chat_stream.platform, self.chat_stream.group_info, self.chat_stream
)

success, reply_set = await self.replyer.deal_reply(
cycle_timers=cycle_timers,
action_data=reply_data,

@@ -158,8 +157,9 @@ class ReplyAction(BaseAction):
return success, reply_text

async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None:
async def store_action_info(
self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True
) -> None:
"""存储action执行信息到数据库

Args:

@@ -188,7 +188,7 @@ class ReplyAction(BaseAction):
chat_info_platform=chat_stream.platform,
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else ""
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
)
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
except Exception as e:

@@ -1,10 +1,6 @@
import traceback
from typing import Tuple, Dict, List, Any, Optional, Union, Type
from typing import Tuple, Dict, Any, Optional
from src.chat.actions.base_action import BaseAction, register_action, ActionActivationType, ChatMode # noqa F401
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.common.logger_manager import get_logger
from src.config.config import global_config
import os
import inspect
import toml # 导入 toml 库

@@ -7,11 +7,11 @@ from src.chat.actions.plugin_api.stream_api import StreamAPI
from src.chat.actions.plugin_api.hearflow_api import HearflowAPI

__all__ = [
'MessageAPI',
'LLMAPI',
'DatabaseAPI',
'ConfigAPI',
'UtilsAPI',
'StreamAPI',
'HearflowAPI',
"MessageAPI",
"LLMAPI",
"DatabaseAPI",
"ConfigAPI",
"UtilsAPI",
"StreamAPI",
"HearflowAPI",
]

@@ -5,6 +5,7 @@ from src.person_info.person_info import person_info_manager

logger = get_logger("config_api")

class ConfigAPI:
"""配置API模块

@@ -8,13 +8,16 @@ from peewee import Model, DoesNotExist

logger = get_logger("database_api")

class DatabaseAPI:
"""数据库API模块

提供了数据库操作相关的功能
"""

async def store_action_info(self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True) -> None:
async def store_action_info(
self, action_build_into_prompt: bool = False, action_prompt_display: str = "", action_done: bool = True
) -> None:
"""存储action执行信息到数据库

Args:

@@ -44,7 +47,7 @@ class DatabaseAPI:
chat_info_platform=chat_stream.platform,
user_id=chat_stream.user_info.user_id if chat_stream.user_info else "",
user_nickname=chat_stream.user_info.user_nickname if chat_stream.user_info else "",
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else ""
user_cardname=chat_stream.user_info.user_cardname if chat_stream.user_info else "",
)
logger.debug(f"{self.log_prefix} 已存储action信息: {action_prompt_display}")
except Exception as e:

@@ -59,7 +62,7 @@ class DatabaseAPI:
data: Dict[str, Any] = None,
limit: int = None,
order_by: List[str] = None,
single_result: bool = False
single_result: bool = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作

@@ -198,10 +201,7 @@ class DatabaseAPI:
raise "unknown query type"

async def db_raw_query(
self,
sql: str,
params: List[Any] = None,
fetch_results: bool = True
self, sql: str, params: List[Any] = None, fetch_results: bool = True
) -> Union[List[Dict[str, Any]], int, None]:
"""执行原始SQL查询

@@ -241,11 +241,7 @@ class DatabaseAPI:
return None

async def db_save(
self,
model_class: Type[Model],
data: Dict[str, Any],
key_field: str = None,
key_value: Any = None
self, model_class: Type[Model], data: Dict[str, Any], key_field: str = None, key_value: Any = None
) -> Union[Dict[str, Any], None]:
"""保存数据到数据库(创建或更新)

@@ -280,9 +276,9 @@ class DatabaseAPI:
# 如果提供了key_field和key_value,尝试更新现有记录
if key_field and key_value is not None:
# 查找现有记录
existing_records = list(model_class.select().where(
getattr(model_class, key_field) == key_value
).limit(1))
existing_records = list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
)

if existing_records:
# 更新现有记录

@@ -292,18 +288,14 @@ class DatabaseAPI:
existing_record.save()

# 返回更新后的记录
updated_record = model_class.select().where(
model_class.id == existing_record.id
).dicts().get()
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get()
return updated_record

# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
new_record = model_class.create(**data)

# 返回创建的记录
created_record = model_class.select().where(
model_class.id == new_record.id
).dicts().get()
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get()
return created_record

except Exception as e:

@@ -312,11 +304,7 @@ class DatabaseAPI:
return None

async def db_get(
self,
model_class: Type[Model],
filters: Dict[str, Any] = None,
order_by: str = None,
limit: int = None
self, model_class: Type[Model], filters: Dict[str, Any] = None, order_by: str = None, limit: int = None
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录

@@ -1,4 +1,4 @@
from typing import Optional, List, Any, Tuple
from typing import Optional, List, Any
from src.common.logger_manager import get_logger
from src.chat.heart_flow.heartflow import heartflow
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState

@@ -36,7 +36,6 @@ class HearflowAPI:
logger.error(f"{self.log_prefix} 获取子心流实例时出错: {e}")
return None

def get_all_sub_hearflow_ids(self) -> List[str]:
"""获取所有子心流的ID列表

@@ -5,6 +5,7 @@ from src.config.config import global_config

logger = get_logger("llm_api")

class LLMAPI:
"""LLM API模块

@@ -26,11 +27,7 @@ class LLMAPI:
return models

async def generate_with_model(
self,
prompt: str,
model_config: Dict[str, Any],
request_type: str = "plugin.generate",
**kwargs
self, prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容

@@ -46,11 +43,7 @@ class LLMAPI:
try:
logger.info(f"{self.log_prefix} 使用模型生成内容,提示词: {prompt[:100]}...")

llm_request = LLMRequest(
model=model_config,
request_type=request_type,
**kwargs
)
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)

response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
return True, response, reasoning, model_name

@@ -14,11 +14,12 @@ from src.chat.focus_chat.info.obs_info import ObsInfo
# 新增导入
from src.chat.focus_chat.heartFC_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending
from maim_message import Seg, UserInfo, GroupInfo
from maim_message import Seg, UserInfo
from src.config.config import global_config

logger = get_logger("message_api")

class MessageAPI:
"""消息API模块

@@ -53,9 +54,11 @@ class MessageAPI:
# 群聊:从数据库查找对应的聊天流
target_stream = None
for stream_id, stream in chat_manager.streams.items():
if (stream.group_info and
str(stream.group_info.group_id) == str(target_id) and
stream.platform == platform):
if (
stream.group_info
and str(stream.group_info.group_id) == str(target_id)
and stream.platform == platform
):
target_stream = stream
break

@@ -66,9 +69,11 @@ class MessageAPI:
# 私聊:从数据库查找对应的聊天流
target_stream = None
for stream_id, stream in chat_manager.streams.items():
if (not stream.group_info and
str(stream.user_info.user_id) == str(target_id) and
stream.platform == platform):
if (
not stream.group_info
and str(stream.user_info.user_id) == str(target_id)
and stream.platform == platform
):
target_stream = stream
break

@@ -95,9 +100,7 @@ class MessageAPI:
message_segment = Seg(type=message_type, data=content)

# 创建空锚点消息(用于回复)
anchor_message = await create_empty_anchor_message(
platform, target_stream.group_info, target_stream
)
anchor_message = await create_empty_anchor_message(platform, target_stream.group_info, target_stream)

# 构建发送消息对象
bot_message = MessageSending(

@@ -114,12 +117,7 @@ class MessageAPI:
)

# 发送消息
sent_msg = await heart_fc_sender.send_message(
bot_message,
has_thinking=True,
typing=False,
set_reply=False
)
sent_msg = await heart_fc_sender.send_message(bot_message, has_thinking=True, typing=False, set_reply=False)

if sent_msg:
logger.info(f"{getattr(self, 'log_prefix', '')} 成功发送消息到 {platform}:{target_id}")

@@ -145,11 +143,7 @@ class MessageAPI:
bool: 是否发送成功
"""
return await self.send_message_to_target(
message_type="text",
content=text,
platform=platform,
target_id=group_id,
is_group=True
message_type="text", content=text, platform=platform, target_id=group_id, is_group=True
)

async def send_text_to_user(self, text: str, user_id: str, platform: str = "qq") -> bool:

@@ -164,11 +158,7 @@ class MessageAPI:
bool: 是否发送成功
"""
return await self.send_message_to_target(
message_type="text",
content=text,
platform=platform,
target_id=user_id,
is_group=False
message_type="text", content=text, platform=platform, target_id=user_id, is_group=False
)

async def send_message(self, type: str, data: str, target: Optional[str] = "", display_message: str = "") -> bool:

@@ -288,7 +278,9 @@ class MessageAPI:
return success

async def send_message_by_replyer(self, target: Optional[str] = None, extra_info_block: Optional[str] = None) -> bool:
async def send_message_by_replyer(
self, target: Optional[str] = None, extra_info_block: Optional[str] = None
) -> bool:
"""通过replyer发送消息的简化方法

Args:

@@ -1,8 +1,6 @@
import hashlib
from typing import Optional, List, Dict, Any
from src.common.logger_manager import get_logger
from src.chat.message_receive.chat_stream import ChatManager, ChatStream
from maim_message import GroupInfo, UserInfo

logger = get_logger("stream_api")

@@ -28,9 +26,11 @@ class StreamAPI:
# 遍历所有已加载的聊天流,查找匹配的群ID
for stream_id, stream in chat_manager.streams.items():
if (stream.group_info and
str(stream.group_info.group_id) == str(group_id) and
stream.platform == platform):
if (
stream.group_info
and str(stream.group_info.group_id) == str(group_id)
and stream.platform == platform
):
logger.info(f"{self.log_prefix} 通过群ID {group_id} 找到聊天流: {stream_id}")
return stream

@@ -55,8 +55,7 @@ class StreamAPI:
group_streams = []

for stream in chat_manager.streams.values():
if (stream.group_info and
stream.platform == platform):
if stream.group_info and stream.platform == platform:
group_streams.append(stream)

logger.info(f"{self.log_prefix} 找到 {len(group_streams)} 个群聊聊天流")

@@ -81,10 +80,12 @@ class StreamAPI:
# 遍历所有已加载的聊天流,查找匹配的用户ID(私聊)
for stream_id, stream in chat_manager.streams.items():
if (not stream.group_info and # 私聊没有群信息
stream.user_info and
str(stream.user_info.user_id) == str(user_id) and
stream.platform == platform):
if (
not stream.group_info # 私聊没有群信息
and stream.user_info
and str(stream.user_info.user_id) == str(user_id)
and stream.platform == platform
):
logger.info(f"{self.log_prefix} 通过用户ID {user_id} 找到私聊聊天流: {stream_id}")
return stream

@@ -111,20 +112,14 @@ class StreamAPI:
"platform": stream.platform,
"chat_type": "group" if stream.group_info else "private",
"create_time": stream.create_time,
"last_active_time": stream.last_active_time
"last_active_time": stream.last_active_time,
}

if stream.group_info:
info.update({
"group_id": stream.group_info.group_id,
"group_name": stream.group_info.group_name
})
info.update({"group_id": stream.group_info.group_id, "group_name": stream.group_info.group_name})

if stream.user_info:
info.update({
"user_id": stream.user_info.user_id,
"user_nickname": stream.user_info.user_nickname
})
info.update({"user_id": stream.user_info.user_id, "user_nickname": stream.user_info.user_nickname})

streams_info.append(info)

@@ -1,11 +1,12 @@
import os
import json
import time
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from src.common.logger_manager import get_logger

logger = get_logger("utils_api")

class UtilsAPI:
"""工具类API模块

@@ -19,6 +20,7 @@ class UtilsAPI:
str: 插件目录的绝对路径
"""
import inspect

plugin_module_path = inspect.getfile(self.__class__)
plugin_dir = os.path.dirname(plugin_module_path)
return plugin_dir

@@ -42,7 +44,7 @@ class UtilsAPI:
logger.warning(f"{self.log_prefix} 文件不存在: {file_path}")
return default

with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"{self.log_prefix} 读取JSON文件出错: {e}")

@@ -67,7 +69,7 @@ class UtilsAPI:
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)

with open(file_path, 'w', encoding='utf-8') as f:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=indent)
return True
except Exception as e:

@@ -93,6 +95,7 @@ class UtilsAPI:
str: 格式化后的时间字符串
"""
import datetime

if timestamp is None:
timestamp = time.time()
return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)

@@ -108,6 +111,7 @@ class UtilsAPI:
int: 时间戳(秒)
"""
import datetime

dt = datetime.datetime.strptime(time_str, format_str)
return int(dt.timestamp())

@@ -118,4 +122,5 @@ class UtilsAPI:
str: 唯一ID
"""
import uuid

return str(uuid.uuid4())

@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Type, Optional, Tuple, Pattern
from src.common.logger_manager import get_logger
from src.chat.message_receive.message import MessageRecv
from src.chat.actions.plugin_api.message_api import MessageAPI
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor

@@ -13,6 +12,7 @@ logger = get_logger("command_handler")
_COMMAND_REGISTRY: Dict[str, Type["BaseCommand"]] = {}
_COMMAND_PATTERNS: Dict[Pattern, Type["BaseCommand"]] = {}

class BaseCommand(ABC):
"""命令基类,所有自定义命令都应该继承这个类"""

@@ -72,9 +72,7 @@ class BaseCommand(ABC):
# 创建空的锚定消息
anchor_message = await create_empty_anchor_message(
chat_stream.platform,
chat_stream.group_info,
chat_stream
chat_stream.platform, chat_stream.group_info, chat_stream
)

# 创建表达器,传入chat_stream参数

@@ -99,6 +97,7 @@ class BaseCommand(ABC):
except Exception as e:
logger.error(f"{self.log_prefix} 发送命令回复消息失败: {e}")
import traceback

logger.error(traceback.format_exc())

@@ -115,7 +114,11 @@ def register_command(cls):
...
"""
# 检查类是否有必要的属性
if not hasattr(cls, "command_name") or not hasattr(cls, "command_description") or not hasattr(cls, "command_pattern"):
if (
not hasattr(cls, "command_name")
or not hasattr(cls, "command_description")
or not hasattr(cls, "command_pattern")
):
logger.error(f"命令类 {cls.__name__} 缺少必要的属性: command_name, command_description 或 command_pattern")
return cls

@@ -196,6 +199,7 @@ class CommandManager:
except Exception as e:
logger.error(f"执行命令 {command_cls.command_name} 时出错: {e}")
import traceback

logger.error(traceback.format_exc())

try:

@@ -227,8 +227,6 @@ class DefaultExpressor:
logger.info(f"想要表达:{in_mind_reply}||理由:{reason}")
logger.info(f"最终回复: {content}\n")

except Exception as llm_e:
# 精简报错信息
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")

@@ -178,7 +178,9 @@ class ExpressionLearner:
decay = a * (time_diff_days - h) ** 2 + k
return min(0.001, decay)

def apply_decay_to_expressions(self, expressions: List[Dict[str, Any]], current_time: float) -> List[Dict[str, Any]]:
def apply_decay_to_expressions(
self, expressions: List[Dict[str, Any]], current_time: float
) -> List[Dict[str, Any]]:
"""
对表达式列表应用衰减
返回衰减后的表达式列表,移除count小于0的项

@@ -118,7 +118,6 @@ class CycleDetail:
os.makedirs(dir_name, exist_ok=True)
# 写入文件

file_path = os.path.join(dir_name, os.path.basename(file_path))
# print("file_path:", file_path)
with open(file_path, "a", encoding="utf-8") as f:

@@ -113,8 +113,9 @@ class HeartFChatting:
for proc_name, (_proc_class, config_key) in PROCESSOR_CLASSES.items():
# 对于关系处理器,需要同时检查两个配置项
if proc_name == "RelationshipProcessor":
if (global_config.relationship.enable_relationship and
getattr(config_processor_settings, config_key, True)):
if global_config.relationship.enable_relationship and getattr(
config_processor_settings, config_key, True
):
self.enabled_processor_names.append(proc_name)
else:
# 其他处理器的原有逻辑

@@ -129,7 +130,6 @@ class HeartFChatting:
self.expressor = DefaultExpressor(chat_stream=self.chat_stream)
self.replyer = DefaultReplyer(chat_stream=self.chat_stream)

self.action_manager = ActionManager()
self.action_planner = PlannerFactory.create_planner(
log_prefix=self.log_prefix, action_manager=self.action_manager

@@ -138,7 +138,6 @@ class HeartFChatting:
self.action_observation = ActionObservation(observe_id=self.stream_id)
self.action_observation.set_action_manager(self.action_manager)

self._processing_lock = asyncio.Lock()

# 循环控制内部状态

@@ -182,7 +181,13 @@ class HeartFChatting:
if processor_info:
processor_actual_class = processor_info[0] # 获取实际的类定义
# 根据处理器类名判断是否需要 subheartflow_id
if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor", "RelationshipProcessor"]:
if name in [
"MindProcessor",
"ToolProcessor",
"WorkingMemoryProcessor",
"SelfProcessor",
"RelationshipProcessor",
]:
self.processors.append(processor_actual_class(subheartflow_id=self.stream_id))
elif name == "ChattingInfoProcessor":
self.processors.append(processor_actual_class())

@@ -203,9 +208,7 @@ class HeartFChatting:
)

if self.processors:
logger.info(
f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}"
)
logger.info(f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}")
else:
logger.warning(f"{self.log_prefix} 没有注册任何处理器。这可能是由于配置错误或所有处理器都被禁用了。")

@@ -292,7 +295,9 @@ class HeartFChatting:
self._current_cycle_detail.set_loop_info(loop_info)

# 从observations列表中获取HFCloopObservation
hfcloop_observation = next((obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None)
hfcloop_observation = next(
(obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None
)
if hfcloop_observation:
hfcloop_observation.add_loop_info(self._current_cycle_detail)
else:

@@ -474,9 +479,6 @@ class HeartFChatting:
action_modify_task, memory_task, processor_task
)

loop_processor_info = {
"all_plan_info": all_plan_info,
"processor_time_costs": processor_time_costs,

@@ -594,9 +596,7 @@ class HeartFChatting:
else:
success, reply_text = result
command = ""
logger.debug(
f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'"
)
logger.debug(f"{self.log_prefix} 麦麦执行了'{action}', 返回结果'{success}', '{reply_text}', '{command}'")

return success, reply_text, command

@@ -181,7 +181,6 @@ class HeartFCMessageReceiver:
userinfo = message.message_info.user_info
messageinfo = message.message_info

chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform,
user_info=userinfo,

@@ -11,7 +11,6 @@ from datetime import datetime
from typing import Dict
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import asyncio

logger = get_logger("processor")

@@ -84,7 +84,6 @@ def init_prompt():
Prompt(fetch_info_prompt, "fetch_info_prompt")

class RelationshipProcessor(BaseProcessor):
log_prefix = "关系"

@@ -93,7 +92,9 @@ class RelationshipProcessor(BaseProcessor):
self.subheartflow_id = subheartflow_id
self.info_fetching_cache: List[Dict[str, any]] = []
self.info_fetched_cache: Dict[str, Dict[str, any]] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}}
self.info_fetched_cache: Dict[
str, Dict[str, any]
] = {} # {person_id: {"info": str, "ttl": int, "start_time": float}}
self.person_engaged_cache: List[Dict[str, any]] = [] # [{person_id: str, start_time: float, rounds: int}]
self.grace_period_rounds = 5

@@ -156,26 +157,27 @@ class RelationshipProcessor(BaseProcessor):
for record in list(self.person_engaged_cache):
record["rounds"] += 1
time_elapsed = current_time - record["start_time"]
message_count = len(get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time))
message_count = len(
get_raw_msg_by_timestamp_with_chat(self.subheartflow_id, record["start_time"], current_time)
)

print(record)

# 根据消息数量和时间设置不同的触发条件
should_trigger = (
message_count >= 50 or # 50条消息必定满足
(message_count >= 35 and time_elapsed >= 300) or # 35条且10分钟
(message_count >= 25 and time_elapsed >= 900) or # 25条且30分钟
(message_count >= 10 and time_elapsed >= 2000) # 10条且1小时
message_count >= 50 # 50条消息必定满足
or (message_count >= 35 and time_elapsed >= 300) # 35条且10分钟
or (message_count >= 25 and time_elapsed >= 900) # 25条且30分钟
or (message_count >= 10 and time_elapsed >= 2000) # 10条且1小时
)

if should_trigger:
logger.info(f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。消息数:{message_count},时长:{time_elapsed:.0f}秒")
logger.info(
f"{self.log_prefix} 用户 {record['person_id']} 满足关系构建条件,开始构建关系。消息数:{message_count},时长:{time_elapsed:.0f}秒"
)
asyncio.create_task(
self.update_impression_on_cache_expiry(
record["person_id"],
self.subheartflow_id,
record["start_time"],
current_time
record["person_id"], self.subheartflow_id, record["start_time"], current_time
)
)
self.person_engaged_cache.remove(record)

@@ -187,12 +189,16 @@ class RelationshipProcessor(BaseProcessor):
if self.info_fetched_cache[person_id][info_type]["ttl"] <= 0:
# 在删除前查找匹配的info_fetching_cache记录
matched_record = None
min_time_diff = float('inf')
min_time_diff = float("inf")
for record in self.info_fetching_cache:
if (record["person_id"] == person_id and
record["info_type"] == info_type and
not record["forget"]):
time_diff = abs(record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"])
if (
record["person_id"] == person_id
and record["info_type"] == info_type
and not record["forget"]
):
time_diff = abs(
record["start_time"] - self.info_fetched_cache[person_id][info_type]["start_time"]
)
if time_diff < min_time_diff:
min_time_diff = time_diff
matched_record = record

@@ -238,13 +244,15 @@ class RelationshipProcessor(BaseProcessor):
for person_name, info_type in content_json.items():
person_id = person_info_manager.get_person_id_by_person_name(person_name)
if person_id:
self.info_fetching_cache.append({
self.info_fetching_cache.append(
{
"person_id": person_id,
"person_name": person_name,
"info_type": info_type,
"start_time": time.time(),
"forget": False,
})
}
)
if len(self.info_fetching_cache) > 20:
self.info_fetching_cache.pop(0)
else:

@@ -256,18 +264,18 @@ class RelationshipProcessor(BaseProcessor):
# 检查person_engaged_cache中是否已存在该person_id
person_exists = any(record["person_id"] == person_id for record in self.person_engaged_cache)
if not person_exists:
self.person_engaged_cache.append({
"person_id": person_id,
"start_time": time.time(),
"rounds": 0
})
self.person_engaged_cache.append(
{"person_id": person_id, "start_time": time.time(), "rounds": 0}
)

if ENABLE_INSTANT_INFO_EXTRACTION:
# 收集即时提取任务
instant_tasks.append((person_id, info_type, time.time()))
else:
# 使用原来的异步模式
async_tasks.append(asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time())))
async_tasks.append(
asyncio.create_task(self.fetch_person_info(person_id, [info_type], start_time=time.time()))
)

# 执行即时提取任务
if ENABLE_INSTANT_INFO_EXTRACTION and instant_tasks:

@@ -314,8 +322,7 @@ class RelationshipProcessor(BaseProcessor):
info_type = record["info_type"]

# 检查是否已经在info_fetched_cache中有结果
if (person_id in self.info_fetched_cache and
info_type in self.info_fetched_cache[person_id]):
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
continue

# 按人物组织正在调取的信息

@@ -343,8 +350,7 @@ class RelationshipProcessor(BaseProcessor):
extraction_tasks = []
for person_id, info_type, start_time in instant_tasks:
# 检查缓存中是否已存在且未过期的信息
if (person_id in self.info_fetched_cache and
info_type in self.info_fetched_cache[person_id]):
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已存在且未过期,跳过调取。")
continue

@@ -356,7 +362,6 @@ class RelationshipProcessor(BaseProcessor):
await asyncio.gather(*extraction_tasks, return_exceptions=True)
logger.info(f"{self.log_prefix} [即时提取] 批量提取完成")

async def _fetch_single_info_instant(self, person_id: str, info_type: str, start_time: float):
"""
使用小模型提取单个信息类型

@@ -374,10 +379,7 @@ class RelationshipProcessor(BaseProcessor):
points = await person_info_manager.get_value(person_id, "points")
if points:
points_text = "\n".join([
f"{point[2]}:{point[0]}"
for point in points
])
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
else:
points_text = "你不记得ta最近发生了什么"

@@ -410,7 +412,9 @@ class RelationshipProcessor(BaseProcessor):
"person_name": person_name,
"unknow": False,
}
logger.info(f"{self.log_prefix} [即时提取] 成功获取 {person_name} 的 {info_type}: {info_content}")
logger.info(
f"{self.log_prefix} [即时提取] 成功获取 {person_name} 的 {info_type}: {info_content}"
)
else:
if person_id not in self.info_fetched_cache:
self.info_fetched_cache[person_id] = {}

@@ -423,7 +427,9 @@ class RelationshipProcessor(BaseProcessor):
}
logger.info(f"{self.log_prefix} [即时提取] {person_name} 的 {info_type} 信息不明确")
else:
logger.warning(f"{self.log_prefix} [即时提取] 小模型返回空结果,获取 {person_name} 的 {info_type} 信息失败。")
logger.warning(
f"{self.log_prefix} [即时提取] 小模型返回空结果,获取 {person_name} 的 {info_type} 信息失败。"
)
except Exception as e:
logger.error(f"{self.log_prefix} [即时提取] 执行小模型请求获取用户信息时出错: {e}")
logger.error(traceback.format_exc())

@@ -436,8 +442,7 @@ class RelationshipProcessor(BaseProcessor):
info_types_to_fetch = []

for info_type in info_types:
if (person_id in self.info_fetched_cache and
info_type in self.info_fetched_cache[person_id]):
if person_id in self.info_fetched_cache and info_type in self.info_fetched_cache[person_id]:
logger.info(f"{self.log_prefix} 用户 {person_id} 的 {info_type} 信息已存在且未过期,跳过调取。")
continue
info_types_to_fetch.append(info_type)

@@ -454,7 +459,7 @@ class RelationshipProcessor(BaseProcessor):
info_json_str = ""
for info_type in info_types_to_fetch:
info_type_str += f"{info_type},"
info_json_str += f"\"{info_type}\": \"信息内容\","
info_json_str += f'"{info_type}": "信息内容",'
info_type_str = info_type_str[:-1]
info_json_str = info_json_str[:-1]

@@ -464,18 +469,13 @@ class RelationshipProcessor(BaseProcessor):
else:
impression_block = f"{person_impression}"

points = await person_info_manager.get_value(person_id, "points")

if points:
points_text = "\n".join([
f"{point[2]}:{point[0]}"
for point in points
])
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
else:
points_text = "你不记得ta最近发生了什么"

prompt = (await global_prompt_manager.get_prompt_async("fetch_info_prompt")).format(
name_block=name_block,
info_type=info_type_str,

@@ -510,7 +510,7 @@ class RelationshipProcessor(BaseProcessor):
self.info_fetched_cache[person_id] = {}

self.info_fetched_cache[person_id][info_type] = {
"info":"unknow",
"info": "unknow",
"ttl": 10,
"start_time": start_time,
"person_name": person_name,

@@ -525,16 +525,12 @@ class RelationshipProcessor(BaseProcessor):
logger.error(f"{self.log_prefix} 执行LLM请求获取用户信息时出错: {e}")
logger.error(traceback.format_exc())

async def update_impression_on_cache_expiry(
self, person_id: str, chat_id: str, start_time: float, end_time: float
):
async def update_impression_on_cache_expiry(self, person_id: str, chat_id: str, start_time: float, end_time: float):
"""
在缓存过期时,获取聊天记录并更新用户印象
"""
logger.info(f"缓存过期,开始为 {person_id} 更新印象。时间范围:{start_time} -> {end_time}")
try:

impression_messages = get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time)
if impression_messages:
logger.info(f"为 {person_id} 获取到 {len(impression_messages)} 条消息用于印象更新。")

@@ -122,9 +122,7 @@ class SelfProcessor(BaseProcessor):
)
# 获取聊天内容
chat_observe_info = observation.get_observe_info()
person_list = observation.person_list
if isinstance(observation, HFCloopObservation):
# hfcloop_observe_info = observation.get_observe_info()
pass

nickname_str = ""

@@ -134,8 +132,6 @@ class SelfProcessor(BaseProcessor):
personality_block = individuality.get_personality_prompt(x_person=2, level=2)

identity_block = individuality.get_identity_prompt(x_person=2, level=2)

prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(

@@ -118,7 +118,7 @@ class ToolProcessor(BaseProcessor):
is_group_chat = observation.is_group_chat

chat_observe_info = observation.get_observe_info()
person_list = observation.person_list
# person_list = observation.person_list

memory_str = ""
if running_memorys:

@@ -141,9 +141,7 @@ class ToolProcessor(BaseProcessor):
# 调用LLM,专注于工具使用
# logger.info(f"开始执行工具调用{prompt}")
response, other_info = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools
)
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)

if len(other_info) == 3:
reasoning_content, model_name, tool_calls = other_info

@@ -118,10 +118,8 @@ class WorkingMemoryProcessor(BaseProcessor):
memory_str=memory_choose_str,
)

# print(f"prompt: {prompt}")

# 调用LLM处理记忆
content = ""
try:

@@ -5,9 +5,6 @@ from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger_manager import get_logger
import importlib
import pkgutil
import os

# 不再需要导入动作类,因为已经在main.py中导入
# import src.chat.actions.default_actions # noqa

@@ -120,7 +117,7 @@ class ActionManager:
try:
# 插件动作已在main.py中加载,这里只需要从_ACTION_REGISTRY获取
self._load_registered_actions()
logger.info(f"从注册表加载插件动作成功")
logger.info("从注册表加载插件动作成功")
except Exception as e:
logger.error(f"加载插件动作失败: {e}")

@@ -106,7 +106,9 @@ class ActionModifier:
if not chat_context.check_types(data["associated_types"]):
type_mismatched_actions.append(action_name)
associated_types_str = ", ".join(data["associated_types"])
logger.info(f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str})")
logger.info(
f"{self.log_prefix}移除动作: {action_name},原因: 关联类型不匹配(需要: {associated_types_str})"
)

if type_mismatched_actions:
# 合并到移除列表中

@@ -123,7 +125,9 @@ class ActionModifier:
self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}应用移除动作: {action_name},原因集合: {reasons}")

logger.info(f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}")
logger.info(
f"{self.log_prefix}传统动作修改完成,当前使用动作: {list(self.action_manager.get_using_actions().keys())}"
)

# === 第二阶段:激活类型判定 ===
# 如果提供了聊天上下文,则进行激活类型判定

@@ -180,7 +184,9 @@ class ActionModifier:
logger.info(f"{self.log_prefix}激活类型判定完成,最终可用动作: {list(final_activated_actions.keys())}")

logger.info(f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}")
logger.info(
f"{self.log_prefix}完整动作修改流程结束,最终动作集: {list(self.action_manager.get_using_actions().keys())}"
)

async def _apply_activation_type_filtering(
self,

@@ -271,10 +277,7 @@ class ActionModifier:
return activated_actions

async def process_actions_for_planner(
self,
observed_messages_str: str = "",
chat_context: Optional[str] = None,
extra_context: Optional[str] = None
self, observed_messages_str: str = "", chat_context: Optional[str] = None, extra_context: Optional[str] = None
) -> Dict[str, Any]:
"""
[已废弃] 此方法现在已被整合到 modify_actions() 中

@@ -286,7 +289,9 @@ class ActionModifier:
1. 主循环调用 modify_actions() 处理完整的动作管理流程
2. 规划器直接使用 ActionManager.get_using_actions() 获取最终动作集
"""
logger.warning(f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()")
logger.warning(
f"{self.log_prefix}process_actions_for_planner() 已废弃,建议规划器直接使用 ActionManager.get_using_actions()"
)

# 为了向后兼容,仍然返回当前使用的动作集
current_using_actions = self.action_manager.get_using_actions()

@@ -303,9 +308,7 @@ class ActionModifier:
def _generate_context_hash(self, chat_content: str) -> str:
"""生成上下文的哈希值用于缓存"""
context_content = f"{chat_content}"
return hashlib.md5(context_content.encode('utf-8')).hexdigest()

return hashlib.md5(context_content.encode("utf-8")).hexdigest()

async def _process_llm_judge_actions_parallel(
self,

@@ -337,11 +340,14 @@ class ActionModifier:
cache_key = f"{action_name}_{current_context_hash}"

# 检查是否有有效的缓存
if (cache_key in self._llm_judge_cache and
current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time):

if (
cache_key in self._llm_judge_cache
and current_time - self._llm_judge_cache[cache_key]["timestamp"] < self._cache_expiry_time
):
results[action_name] = self._llm_judge_cache[cache_key]["result"]
logger.debug(f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}")
logger.debug(
f"{self.log_prefix}使用缓存结果 {action_name}: {'激活' if results[action_name] else '未激活'}"
)
else:
# 需要进行LLM判定
tasks_to_run[action_name] = action_info

@@ -368,7 +374,7 @@ class ActionModifier:
task_results = await asyncio.gather(*tasks, return_exceptions=True)

# 处理结果并更新缓存
for i, (action_name, result) in enumerate(zip(task_names, task_results)):
for _, (action_name, result) in enumerate(zip(task_names, task_results)):
if isinstance(result, Exception):
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
results[action_name] = False

@@ -377,10 +383,7 @@ class ActionModifier:
# 更新缓存
cache_key = f"{action_name}_{current_context_hash}"
self._llm_judge_cache[cache_key] = {
"result": result,
"timestamp": current_time
}
self._llm_judge_cache[cache_key] = {"result": result, "timestamp": current_time}

logger.debug(f"{self.log_prefix}并行LLM判定完成,耗时: {time.time() - current_time:.2f}s")

@@ -434,7 +437,6 @@ class ActionModifier:
action_require = action_info.get("require", [])
custom_prompt = action_info.get("llm_judge_prompt", "")

# 构建基础判定提示词
base_prompt = f"""
你需要判断在当前聊天情况下,是否应该激活名为"{action_name}"的动作。

@@ -452,7 +454,6 @@ class ActionModifier:
if chat_content:
base_prompt += f"\n当前聊天记录:\n{chat_content}\n"

base_prompt += """
请根据以上信息判断是否应该激活这个动作。
只需要回答"是"或"否",不要有其他内容。

@@ -467,10 +468,11 @@ class ActionModifier:
# print(base_prompt)
print(f"LLM判定动作 {action_name}:响应='{response}'")

should_activate = "是" in response or "yes" in response or "true" in response

logger.debug(f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}")
logger.debug(
f"{self.log_prefix}LLM判定动作 {action_name}:响应='{response}',结果={'激活' if should_activate else '不激活'}"
)
return should_activate

except Exception as e:

@@ -568,7 +570,9 @@ class ActionModifier:
result["remove"].append("no_reply")
result["remove"].append("reply")
no_reply_ratio = no_reply_count / len(recent_cycles)
logger.info(f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f},达到退出聊天阈值,将添加exit_focus_chat并移除no_reply/reply动作")
logger.info(
f"{self.log_prefix}检测到高no_reply比例: {no_reply_ratio:.2f},达到退出聊天阈值,将添加exit_focus_chat并移除no_reply/reply动作"
)

# 计算连续回复的相关阈值

@@ -593,7 +597,7 @@ class ActionModifier:
if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num):
# 如果最近max_reply_num次都是reply,直接移除
result["remove"].append("reply")
reply_count = len(last_max_reply_num) - no_reply_count
# reply_count = len(last_max_reply_num) - no_reply_count
logger.info(
f"{self.log_prefix}移除reply动作,原因: 连续回复过多(最近{len(last_max_reply_num)}次全是reply,超过阈值{max_reply_num})"
)

@@ -622,8 +626,6 @@ class ActionModifier:
f"{self.log_prefix}连续回复检测:最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,未触发"
)
else:
logger.debug(
f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常"
)
logger.debug(f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常")

return result

@@ -197,7 +197,6 @@ class ActionPlanner(BasePlanner):
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
logger.info(f"{self.log_prefix}规划器推理: {reasoning_content}")

except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
reasoning = f"LLM 请求失败,你的模型出现问题: {req_e}"

@@ -303,7 +302,6 @@ class ActionPlanner(BasePlanner):
) -> str:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:

if relation_info_block:
relation_info_block = f"以下是你和别人的关系描述:\n{relation_info_block}"
else:

@@ -351,16 +349,14 @@ class ActionPlanner(BasePlanner):
param_text = "\n"
for param_name, param_description in using_actions_info["parameters"].items():
param_text += f'    "{param_name}":"{param_description}"\n'
param_text = param_text.rstrip('\n')
param_text = param_text.rstrip("\n")
else:
param_text = ""

require_text = ""
for require_item in using_actions_info["require"]:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip('\n')

require_text = require_text.rstrip("\n")

using_action_prompt = using_action_prompt.format(
action_name=using_actions_name,

@@ -240,14 +240,13 @@ class DefaultReplyer:
# current_temp = float(global_config.model.normal["temp"]) * arousal_multiplier
# self.express_model.params["temperature"] = current_temp # 动态调整温度

reply_to = action_data.get("reply_to", "none")

sender = ""
targer = ""
if ":" in reply_to or ":" in reply_to:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r'[::]', string=reply_to, maxsplit=1)
parts = re.split(pattern=r"[::]", string=reply_to, maxsplit=1)
if len(parts) == 2:
sender = parts[0].strip()
targer = parts[1].strip()

@@ -375,8 +374,6 @@ class DefaultReplyer:
style_habbits_str = "\n".join(style_habbits)
grammar_habbits_str = "\n".join(grammar_habbits)

# 关键词检测与反应
keywords_reaction_prompt = ""
try:

@@ -409,15 +406,14 @@ class DefaultReplyer:
# logger.debug("开始构建 focus prompt")

if sender_name:
reply_target_block = f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
reply_target_block = (
f"现在{sender_name}说的:{target_message}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif target_message:
reply_target_block = f"现在{target_message}引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
reply_target_block = "现在,你想要在群里发言或者回复消息。"

# --- Choose template based on chat type ---
if is_group_chat:
template_name = "default_replyer_prompt"

@@ -667,7 +663,7 @@ def find_similar_expressions(input_text: str, expressions: List[Dict], top_k: in
return []

# 准备文本数据
texts = [expr['situation'] for expr in expressions]
texts = [expr["situation"] for expr in expressions]
texts.append(input_text) # 添加输入文本

# 使用TF-IDF向量化

@@ -68,7 +68,6 @@ class ChattingObservation(Observation):
self.talking_message = initial_messages
self.talking_message_str = build_readable_messages(self.talking_message, show_actions=True)

def to_dict(self) -> dict:
"""将观察对象转换为可序列化的字典"""
return {

@@ -42,9 +42,7 @@ class SubHeartflow:
self.history_chat_state: List[Tuple[ChatState, float]] = []

self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
self.log_prefix = (
chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id
)
self.log_prefix = chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id
# 兴趣消息集合
self.interest_dict: Dict[str, tuple[MessageRecv, float, bool]] = {}

@@ -199,7 +197,6 @@ class SubHeartflow:
# 如果实例不存在,则创建并启动
logger.info(f"{log_prefix} 麦麦准备开始专注聊天...")
try:

self.heart_fc_instance = HeartFChatting(
chat_id=self.subheartflow_id,
# observations=self.observations,

@@ -3,7 +3,7 @@ import os

from .global_logger import logger
from .lpmmconfig import global_config
from src.chat.knowledge.utils import get_sha256
from src.chat.knowledge.utils.hash import get_sha256

def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:

@@ -346,7 +346,9 @@ class Hippocampus:
# 使用LLM提取关键词
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
# logger.info(f"提取关键词数量: {topic_num}")
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
self.find_topic_llm(text, topic_num)
)

# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)

@@ -701,7 +703,9 @@ class Hippocampus:
# 使用LLM提取关键词
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
# logger.info(f"提取关键词数量: {topic_num}")
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
self.find_topic_llm(text, topic_num)
)

# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)

@@ -929,22 +933,26 @@ class EntorhinalCortex:
continue

if concept not in db_nodes:
nodes_to_create.append({
nodes_to_create.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": memory_hash,
"created_time": created_time,
"last_modified": last_modified,
})
}
)
else:
db_node = db_nodes[concept]
if db_node.hash != memory_hash:
nodes_to_update.append({
nodes_to_update.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": memory_hash,
"last_modified": last_modified,
})
}
)

# 计算需要删除的节点
memory_concepts = {concept for concept, _ in memory_nodes}

@@ -954,13 +962,13 @@ class EntorhinalCortex:
if nodes_to_create:
batch_size = 100
for i in range(0, len(nodes_to_create), batch_size):
batch = nodes_to_create[i:i + batch_size]
batch = nodes_to_create[i : i + batch_size]
GraphNodes.insert_many(batch).execute()

if nodes_to_update:
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i:i + batch_size]
batch = nodes_to_update[i : i + batch_size]
for node_data in batch:
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
GraphNodes.concept == node_data["concept"]

@@ -992,22 +1000,26 @@ class EntorhinalCortex:
last_modified = data.get("last_modified", current_time)

if edge_key not in db_edge_dict:
edges_to_create.append({
edges_to_create.append(
{
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"created_time": created_time,
"last_modified": last_modified,
})
}
)
elif db_edge_dict[edge_key]["hash"] != edge_hash:
edges_to_update.append({
edges_to_update.append(
{
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"last_modified": last_modified,
})
}
)

# 计算需要删除的边
memory_edge_keys = {(source, target) for source, target, _ in memory_edges}

@@ -1017,13 +1029,13 @@ class EntorhinalCortex:
if edges_to_create:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i:i + batch_size]
batch = edges_to_create[i : i + batch_size]
GraphEdges.insert_many(batch).execute()

if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i:i + batch_size]
batch = edges_to_update[i : i + batch_size]
for edge_data in batch:
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])

@@ -1031,9 +1043,7 @@ class EntorhinalCortex:
if edges_to_delete:
for source, target in edges_to_delete:
GraphEdges.delete().where(
(GraphEdges.source == source) & (GraphEdges.target == target)
).execute()
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()

end_time = time.time()
logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}秒")

@@ -1069,13 +1079,15 @@ class EntorhinalCortex:
if not memory_items_json:
continue

nodes_data.append({
nodes_data.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
"created_time": data.get("created_time", current_time),
"last_modified": data.get("last_modified", current_time),
})
}
)
except Exception as e:
logger.error(f"准备节点 {concept} 数据时发生错误: {e}")
continue

@@ -1084,14 +1096,16 @@ class EntorhinalCortex:
edges_data = []
for source, target, data in memory_edges:
try:
edges_data.append({
edges_data.append(
{
"source": source,
"target": target,
"strength": data.get("strength", 1),
"hash": self.hippocampus.calculate_edge_hash(source, target),
"created_time": data.get("created_time", current_time),
"last_modified": data.get("last_modified", current_time),
})
}
)
except Exception as e:
logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
continue

@@ -1102,7 +1116,7 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
with GraphNodes._meta.database.atomic():
for i in range(0, len(nodes_data), batch_size):
batch = nodes_data[i:i + batch_size]
batch = nodes_data[i : i + batch_size]
GraphNodes.insert_many(batch).execute()
node_end = time.time()
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒")

@@ -1113,7 +1127,7 @@ class EntorhinalCortex:
batch_size = 500 # 增加批量大小
with GraphEdges._meta.database.atomic():
for i in range(0, len(edges_data), batch_size):
batch = edges_data[i:i + batch_size]
batch = edges_data[i : i + batch_size]
GraphEdges.insert_many(batch).execute()
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒")

@@ -24,7 +24,6 @@ class MessageStorage:
else:
filtered_processed_plain_text = ""

if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
@@ -13,8 +13,6 @@ from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from .normal_chat_generator import NormalChatGenerator
|
||||
from ..message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
from src.chat.message_receive.message_sender import message_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.normal_chat.willing.willing_manager import willing_manager
|
||||
from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats
|
||||
from src.config.config import global_config
|
||||
@@ -193,7 +191,9 @@ class NormalChat:
|
||||
return
|
||||
|
||||
timing_results = {}
|
||||
reply_probability = 1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0 # 如果被提及,且开启了提及必回复,则基础概率为1,否则需要意愿判断
|
||||
reply_probability = (
|
||||
1.0 if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply else 0.0
|
||||
) # 如果被提及,且开启了提及必回复,则基础概率为1,否则需要意愿判断
|
||||
|
||||
# 意愿管理器:设置当前message信息
|
||||
willing_manager.setup(message, self.chat_stream, is_mentioned, interested_rate)
|
||||
@@ -269,12 +269,16 @@ class NormalChat:
|
||||
sender_name = self._get_sender_name(message)
|
||||
|
||||
no_action = {
|
||||
"action_result": {"action_type": "no_action", "action_data": {}, "reasoning": "规划器初始化默认", "is_parallel": True},
|
||||
"action_result": {
|
||||
"action_type": "no_action",
|
||||
"action_data": {},
|
||||
"reasoning": "规划器初始化默认",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"chat_context": "",
|
||||
"action_prompt": "",
|
||||
}
|
||||
|
||||
|
||||
# 检查是否应该跳过规划
|
||||
if self.action_modifier.should_skip_planning():
|
||||
logger.debug(f"[{self.stream_name}] 没有可用动作,跳过规划")
|
||||
@@ -288,7 +292,9 @@ class NormalChat:
|
||||
reasoning = plan_result["action_result"]["reasoning"]
|
||||
is_parallel = plan_result["action_result"].get("is_parallel", False)
|
||||
|
||||
logger.info(f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}")
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Planner决策: {action_type}, 理由: {reasoning}, 并行执行: {is_parallel}"
|
||||
)
|
||||
self.action_type = action_type # 更新实例属性
|
||||
self.is_parallel_action = is_parallel # 新增:保存并行执行标志
|
||||
|
||||
@@ -307,7 +313,12 @@ class NormalChat:
|
||||
else:
|
||||
logger.warning(f"[{self.stream_name}] 额外动作 {action_type} 执行失败")
|
||||
|
||||
return {"action_type": action_type, "action_data": action_data, "reasoning": reasoning, "is_parallel": is_parallel}
|
||||
return {
|
||||
"action_type": action_type,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"is_parallel": is_parallel,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] Planner执行失败: {e}")
|
||||
@@ -333,11 +344,17 @@ class NormalChat:
|
||||
logger.debug(f"[{self.stream_name}] 额外动作处理完成: {self.action_type}")
|
||||
|
||||
if not response_set or (
|
||||
self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action
|
||||
self.enable_planner
|
||||
and self.action_type not in ["no_action", "change_to_focus_chat"]
|
||||
and not self.is_parallel_action
|
||||
):
|
||||
if not response_set:
|
||||
logger.info(f"[{self.stream_name}] 模型未生成回复内容")
|
||||
elif self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"] and not self.is_parallel_action:
|
||||
elif (
|
||||
self.enable_planner
|
||||
and self.action_type not in ["no_action", "change_to_focus_chat"]
|
||||
and not self.is_parallel_action
|
||||
):
|
||||
logger.info(f"[{self.stream_name}] 模型选择其他动作(非并行动作)")
|
||||
# 如果模型未生成回复,移除思考消息
|
||||
container = await message_manager.get_container(self.stream_id) # 使用 self.stream_id
|
||||
@@ -364,7 +381,6 @@ class NormalChat:
|
||||
|
||||
# 检查 first_bot_msg 是否为 None (例如思考消息已被移除的情况)
|
||||
if first_bot_msg:
|
||||
|
||||
# 记录回复信息到最近回复列表中
|
||||
reply_info = {
|
||||
"time": time.time(),
|
||||
@@ -396,7 +412,6 @@ class NormalChat:
|
||||
await self._check_switch_to_focus()
|
||||
pass
|
||||
|
||||
|
||||
# with Timer("关系更新", timing_results):
|
||||
# await self._update_relationship(message, response_set)
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ class NormalChatActionModifier:
# 第二阶段:应用激活类型判定
# 构建聊天内容 - 使用与planner一致的方式
chat_content = ""
if chat_stream and hasattr(chat_stream, 'stream_id'):
if chat_stream and hasattr(chat_stream, "stream_id"):
try:
# 获取消息历史,使用与normal_chat_planner相同的方法
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@@ -105,9 +105,7 @@ class NormalChatActionModifier:
# print(f"current_actions: {current_actions}")
# print(f"chat_content: {chat_content}")
final_activated_actions = await self._apply_normal_activation_filtering(
current_actions,
chat_content,
message_content
current_actions, chat_content, message_content
)
# print(f"final_activated_actions: {final_activated_actions}")

@@ -124,7 +122,11 @@ class NormalChatActionModifier:
all_actions_to_remove.add(action_name)

# 统计移除原因(避免重复)
activation_failed_actions = [name for name in current_actions.keys() if name not in final_activated_actions and name not in type_mismatched_actions]
activation_failed_actions = [
name
for name in current_actions.keys()
if name not in final_activated_actions and name not in type_mismatched_actions
]
if activation_failed_actions:
reasons.append(f"移除{activation_failed_actions}(激活类型判定未通过)")

@@ -210,12 +212,7 @@ class NormalChatActionModifier:

# 3. 处理KEYWORD类型(关键词匹配)
for action_name, action_info in keyword_actions.items():
should_activate = self._check_keyword_activation(
action_name,
action_info,
chat_content,
message_content
)
should_activate = self._check_keyword_activation(action_name, action_info, chat_content, message_content)
if should_activate:
activated_actions[action_name] = action_info
keywords = action_info.get("activation_keywords", [])
@@ -256,7 +253,7 @@ class NormalChatActionModifier:
return False

# 使用构建好的聊天内容作为检索文本
search_text = chat_content +message_content
search_text = chat_content + message_content

# 如果不区分大小写,转换为小写
if not case_sensitive:
@@ -269,7 +266,6 @@ class NormalChatActionModifier:
if check_keyword in search_text:
matched_keywords.append(keyword)


# print(f"search_text: {search_text}")
# print(f"activation_keywords: {activation_keywords}")


@@ -9,7 +9,7 @@ import time
from typing import List, Optional, Tuple, Dict, Any
from src.chat.message_receive.message import MessageRecv, MessageSending, MessageThinking, Seg
from src.chat.message_receive.message import UserInfo
from src.chat.message_receive.chat_stream import ChatStream,chat_manager
from src.chat.message_receive.chat_stream import ChatStream, chat_manager
from src.chat.message_receive.message_sender import message_manager
from src.config.config import global_config
from src.common.logger_manager import get_logger

@@ -25,9 +25,7 @@ class NormalChatGenerator:
request_type="normal.chat_2",
)

self.model_sum = LLMRequest(
model=global_config.model.memory_summary, temperature=0.7, request_type="relation"
)
self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation")
self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model"

@@ -68,7 +66,6 @@ class NormalChatGenerator:
enable_planner: bool = False,
available_actions=None,
):

person_id = person_info_manager.get_person_id(
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
)
@@ -103,7 +100,6 @@ class NormalChatGenerator:

logger.info(f"对 {message.processed_plain_text} 的回复:{content}")


except Exception:
logger.exception("生成回复时出错")
return None

@@ -110,7 +110,12 @@ class NormalChatPlanner:
if not current_available_actions:
logger.debug(f"{self.log_prefix}规划器: 没有可用动作,返回no_action")
return {
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": True},
"action_result": {
"action_type": action,
"action_data": action_data,
"reasoning": reasoning,
"is_parallel": True,
},
"chat_context": "",
"action_prompt": "",
}
@@ -141,7 +146,12 @@ class NormalChatPlanner:
if not prompt:
logger.warning(f"{self.log_prefix}规划器: 构建提示词失败")
return {
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning, "is_parallel": False},
"action_result": {
"action_type": action,
"action_data": action_data,
"reasoning": reasoning,
"is_parallel": False,
},
"chat_context": chat_context,
"action_prompt": "",
}
@@ -202,7 +212,9 @@ class NormalChatPlanner:
action_info = current_available_actions[action]
is_parallel = action_info.get("parallel_action", False)

logger.debug(f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}")
logger.debug(
f"{self.log_prefix}规划器决策动作:{action}, 动作信息: '{action_data}', 理由: {reasoning}, 并行执行: {is_parallel}"
)

# 恢复到默认动作集
self.action_manager.restore_actions()
@@ -216,7 +228,7 @@ class NormalChatPlanner:
"action_data": action_data,
"reasoning": reasoning,
"timestamp": time.time(),
"model_name": model_name if 'model_name' in locals() else None
"model_name": model_name if "model_name" in locals() else None,
}

action_result = {
@@ -224,7 +236,7 @@ class NormalChatPlanner:
"action_data": action_data,
"reasoning": reasoning,
"is_parallel": is_parallel,
"action_record": json.dumps(action_record, ensure_ascii=False)
"action_record": json.dumps(action_record, ensure_ascii=False),
}

plan_result = {
@@ -248,9 +260,7 @@ class NormalChatPlanner:

# 添加特殊的change_to_focus_chat动作
action_options_text += "动作:change_to_focus_chat\n"
action_options_text += (
"该动作的描述:当聊天变得热烈、自己回复条数很多或需要深入交流时使用,正常回复消息并切换到focus_chat模式\n"
)
action_options_text += "该动作的描述:当聊天变得热烈、自己回复条数很多或需要深入交流时使用,正常回复消息并切换到focus_chat模式\n"

action_options_text += "使用该动作的场景:\n"
action_options_text += "- 聊天上下文中自己的回复条数较多(超过3-4条)\n"
@@ -260,12 +270,9 @@ class NormalChatPlanner:

action_options_text += "输出要求:\n"
action_options_text += "{{"
action_options_text += " \"action\": \"change_to_focus_chat\""
action_options_text += ' "action": "change_to_focus_chat"'
action_options_text += "}}\n\n"




for action_name, action_info in current_available_actions.items():
action_description = action_info.get("description", "")
action_parameters = action_info.get("parameters", {})
@@ -276,15 +283,14 @@ class NormalChatPlanner:
print(action_parameters)
for param_name, param_description in action_parameters.items():
param_text += f' "{param_name}":"{param_description}"\n'
param_text = param_text.rstrip('\n')
param_text = param_text.rstrip("\n")
else:
param_text = ""


require_text = ""
for require_item in action_require:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip('\n')
require_text = require_text.rstrip("\n")

# 构建单个动作的提示
action_prompt = await global_prompt_manager.format_prompt(
@@ -316,6 +322,4 @@ class NormalChatPlanner:
return ""



init_prompt()

@@ -214,7 +214,6 @@ class PromptBuilder:
except Exception as e:
logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True)


moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"

# 构建action描述 (如果启用planner)

@@ -42,9 +42,7 @@ class ClassicalWillingManager(BaseWillingManager):

self.chat_reply_willing[chat_id] = min(current_willing, 3.0)

reply_probability = min(
max((current_willing - 0.5), 0.01) * 2, 1
)
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)

# 检查群组权限(如果是群聊)
if (

@@ -336,7 +336,7 @@ def _build_readable_messages_internal(
"start_time": message_details[0][0],
"end_time": message_details[0][0],
"content": [message_details[0][2]],
"is_action": message_details[0][3]
"is_action": message_details[0][3],
}

for i in range(1, len(message_details)):
@@ -352,7 +352,7 @@ def _build_readable_messages_internal(
"start_time": timestamp,
"end_time": timestamp,
"content": [content],
"is_action": is_action
"is_action": is_action,
}
continue

@@ -369,7 +369,7 @@ def _build_readable_messages_internal(
"start_time": timestamp,
"end_time": timestamp,
"content": [content],
"is_action": is_action
"is_action": is_action,
}
# 添加最后一个合并块
merged_messages.append(current_merge)
@@ -381,11 +381,10 @@ def _build_readable_messages_internal(
"start_time": timestamp, # 起始和结束时间相同
"end_time": timestamp,
"content": [content], # 内容只有一个元素
"is_action": is_action
"is_action": is_action,
}
)


# 4 & 5: 格式化为字符串
output_lines = []
for _i, merged in enumerate(merged_messages):
@@ -473,11 +472,13 @@ def build_readable_messages(
chat_id = copy_messages[0].get("chat_id") if copy_messages else None

# 获取这个时间范围内的动作记录,并匹配chat_id
actions = ActionRecords.select().where(
(ActionRecords.time >= min_time) &
(ActionRecords.time <= max_time) &
(ActionRecords.chat_id == chat_id)
).order_by(ActionRecords.time)
actions = (
ActionRecords.select()
.where(
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
)
.order_by(ActionRecords.time)
)

# 将动作记录转换为消息格式
for action in actions:
@@ -505,15 +506,12 @@ def build_readable_messages(
# for message in messages:
#     print(f"message:{message}")


formatted_string, _ = _build_readable_messages_internal(
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate
)

# print(f"formatted_string:{formatted_string}")



return formatted_string
else:
# 按 read_mark 分割消息

@@ -155,6 +155,7 @@ class Messages(BaseModel):
# database = db # 继承自 BaseModel
table_name = "messages"


class ActionRecords(BaseModel):
"""
用于存储动作记录数据的模型。
@@ -246,7 +247,6 @@ class PersonInfo(BaseModel):
know_since = FloatField(null=True) # 首次印象总结时间
last_know = FloatField(null=True) # 最后一次印象总结时间


class Meta:
# database = db # 继承自 BaseModel
table_name = "person_info"
@@ -403,20 +403,20 @@ def initialize_database():
logger.info(f"表 '{table_name}' 缺失字段 '{field_name}',正在添加...")
field_type = field_obj.__class__.__name__
sql_type = {
'TextField': 'TEXT',
'IntegerField': 'INTEGER',
'FloatField': 'FLOAT',
'DoubleField': 'DOUBLE',
'BooleanField': 'INTEGER',
'DateTimeField': 'DATETIME'
}.get(field_type, 'TEXT')
alter_sql = f'ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}'
"TextField": "TEXT",
"IntegerField": "INTEGER",
"FloatField": "FLOAT",
"DoubleField": "DOUBLE",
"BooleanField": "INTEGER",
"DateTimeField": "DATETIME",
}.get(field_type, "TEXT")
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
if field_obj.null:
alter_sql += ' NULL'
alter_sql += " NULL"
else:
alter_sql += ' NOT NULL'
if hasattr(field_obj, 'default') and field_obj.default is not None:
alter_sql += f' DEFAULT {field_obj.default}'
alter_sql += " NOT NULL"
if hasattr(field_obj, "default") and field_obj.default is not None:
alter_sql += f" DEFAULT {field_obj.default}"
db.execute_sql(alter_sql)
logger.info(f"字段 '{field_name}' 添加成功")


@@ -58,6 +58,7 @@ class RelationshipConfig(ConfigBase):
build_relationship_interval: int = 600
"""构建关系间隔 单位秒,如果为0则不构建关系"""


@dataclass
class ChatConfig(ConfigBase):
"""聊天配置类"""
@@ -329,6 +330,7 @@ class KeywordReactionConfig(ConfigBase):
if not isinstance(rule, KeywordRuleConfig):
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")


@dataclass
class ResponsePostProcessConfig(ConfigBase):
"""回复后处理配置类"""

@@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from typing import List, Dict, Any, Callable

from playhouse import shortcuts

# from src.common.database.database import db # Peewee db 导入
from src.common.database.database_model import Messages # Peewee Messages 模型导入
from playhouse.shortcuts import model_to_dict # 用于将模型实例转换为字典

model_to_dict: Callable[..., dict] = shortcuts.model_to_dict # Peewee 模型转换为字典的快捷函数


class MessageStorage(ABC):

@@ -98,9 +98,10 @@ class PersonalityExpression:
count = meta_data.get("count", 0)

# 检查是否有任何变化
if (current_style_text != last_style_text or
current_personality != last_personality):
logger.info(f"检测到变化:\n风格: '{last_style_text}' -> '{current_style_text}'\n人格: '{last_personality}' -> '{current_personality}'")
if current_style_text != last_style_text or current_personality != last_personality:
logger.info(
f"检测到变化:\n风格: '{last_style_text}' -> '{current_style_text}'\n人格: '{last_personality}' -> '{current_personality}'"
)
count = 0
if os.path.exists(self.expressions_file_path):
try:
@@ -196,7 +197,7 @@ class PersonalityExpression:
"last_style_text": current_style_text,
"last_personality": current_personality,
"count": count,
"last_update_time": current_time
"last_update_time": current_time,
}
)
logger.info(f"成功处理。当前配置的计数现在是 {count},最后更新时间:{current_time}。")

src/main.py
@@ -19,6 +19,7 @@ from .common.server import global_server, Server
from rich.traceback import install
from .chat.focus_chat.expressors.exprssion_learner import expression_learner
from .api.main import start_api_server

# 导入actions模块,确保装饰器被执行
import src.chat.actions.default_actions # noqa

@@ -144,16 +145,18 @@ class MainSystem:
# 加载命令处理系统
try:
# 导入命令处理系统
from src.chat.command.command_handler import command_manager

logger.success("命令处理系统加载成功")
except Exception as e:
logger.error(f"加载命令处理系统失败: {e}")
import traceback

logger.error(traceback.format_exc())

except Exception as e:
logger.error(f"加载插件失败: {e}")
import traceback

logger.error(traceback.format_exc())

async def schedule_tasks(self):
@@ -168,11 +171,13 @@ class MainSystem:

# 根据配置条件性地添加记忆系统相关任务
if global_config.memory.enable_memory and self.hippocampus_manager:
tasks.extend([
tasks.extend(
[
self.build_memory_task(),
self.forget_memory_task(),
self.consolidate_memory_task(),
])
]
)

tasks.append(self.learn_and_store_expression_task())

@@ -190,9 +195,7 @@ class MainSystem:
while True:
await asyncio.sleep(global_config.memory.forget_memory_interval)
logger.info("[记忆遗忘] 开始遗忘记忆...")
await self.hippocampus_manager.forget_memory(
percentage=global_config.memory.memory_forget_percentage
)
await self.hippocampus_manager.forget_memory(percentage=global_config.memory.memory_forget_percentage)
logger.info("[记忆遗忘] 记忆遗忘完成")

async def consolidate_memory_task(self):

@@ -11,6 +11,7 @@ from collections import defaultdict

logger = get_logger("relation")


# 暂时弃用,改为实时更新
class ImpressionUpdateTask(AsyncTask):
def __init__(self):
@@ -85,7 +86,9 @@ class ImpressionUpdateTask(AsyncTask):
# 计算权重
for bot_msg in bot_messages:
bot_time = bot_msg["time"]
context_messages = [msg for msg in relevant_messages if abs(msg["time"] - bot_time) <= 600] # 前后10分钟
context_messages = [
msg for msg in relevant_messages if abs(msg["time"] - bot_time) <= 600
] # 前后10分钟
logger.debug(f"Bot消息 {bot_time} 的上下文消息数: {len(context_messages)}")

for msg in context_messages:
@@ -121,7 +124,7 @@ class ImpressionUpdateTask(AsyncTask):
weights = [user[1]["weight"] for user in sorted_users]
total_weight = sum(weights)
# 计算每个用户的概率
probabilities = [w/total_weight for w in weights]
probabilities = [w / total_weight for w in weights]
# 使用累积概率进行选择
selected_indices = []
remaining_indices = list(range(len(sorted_users)))
@@ -131,7 +134,7 @@ class ImpressionUpdateTask(AsyncTask):
# 计算剩余索引的累积概率
remaining_probs = [probabilities[i] for i in remaining_indices]
# 归一化概率
remaining_probs = [p/sum(remaining_probs) for p in remaining_probs]
remaining_probs = [p / sum(remaining_probs) for p in remaining_probs]
# 选择索引
chosen_idx = random.choices(remaining_indices, weights=remaining_probs, k=1)[0]
selected_indices.append(chosen_idx)
@@ -160,12 +163,9 @@ class ImpressionUpdateTask(AsyncTask):
logger.info(f"首次认识用户: {user_nickname}")
await relationship_manager.first_knowing_some_one(platform, user_id, user_nickname, cardname)


logger.info(f"开始更新用户 {user_nickname} 的印象")
await relationship_manager.update_person_impression(
person_id=person_id,
timestamp=last_bot_msg["time"],
bot_engaged_messages=relevant_messages
person_id=person_id, timestamp=last_bot_msg["time"], bot_engaged_messages=relevant_messages
)

logger.debug("印象更新任务执行完成")

@@ -46,7 +46,6 @@ person_info_default = {
"info_list": None,
"points": None,
"forgotten_points": None,

}


@@ -199,7 +198,6 @@ class PersonInfoManager:
if data and "user_id" in data:
creation_data["user_id"] = data["user_id"]


await self.create_person_info(person_id, creation_data)

@staticmethod
@@ -332,7 +330,9 @@ class PersonInfoManager:
await self.update_one_field(person_id, "person_name", generated_nickname)
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))

logger.info(f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}")
logger.info(
f"成功给用户{user_nickname} {person_id} 取名 {generated_nickname},理由:{result.get('reason', '未提供理由')}"
)

self.person_name_list[person_id] = generated_nickname
return result
@@ -416,7 +416,7 @@ class PersonInfoManager:

@staticmethod
def get_value_sync(person_id: str, field_name: str):
""" 同步获取指定用户指定字段的值 """
"""同步获取指定用户指定字段的值"""
default_value_for_field = person_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = []
@@ -534,7 +534,7 @@ class PersonInfoManager:
"last_know": int(datetime.datetime.now().timestamp()),
"impression": None,
"points": [],
"forgotten_points": []
"forgotten_points": [],
}
model_fields = PersonInfo._meta.fields.keys()
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}

@@ -7,7 +7,6 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.manager.mood_manager import mood_manager
from src.individuality.individuality import individuality
import json
from json_repair import repair_json
from datetime import datetime
@@ -90,9 +89,7 @@ class RelationshipManager:
return is_known

@staticmethod
async def first_knowing_some_one(
platform: str, user_id: str, user_nickname: str, user_cardname: str
):
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
"""判断是否认识某人"""
person_id = person_info_manager.get_person_id(platform, user_id)
# 生成唯一的 person_name
@@ -124,7 +121,7 @@ class RelationshipManager:
person_name = await person_info_manager.get_value(person_id, "person_name")
if not person_name or person_name == "none":
return ""
impression = await person_info_manager.get_value(person_id, "impression")
# impression = await person_info_manager.get_value(person_id, "impression")
points = await person_info_manager.get_value(person_id, "points") or []

if isinstance(points, str):
@@ -139,11 +136,9 @@ class RelationshipManager:
platform = await person_info_manager.get_value(person_id, "platform")
relation_prompt = f"'{person_name}' ,ta在{platform}上的昵称是{nickname_str}。"


# if impression:
#     relation_prompt += f"你对ta的印象是:{impression}。"


if random_points:
for point in random_points:
# print(f"point: {point}")
@@ -152,7 +147,6 @@ class RelationshipManager:
point_str = f"时间:{point[2]}。内容:{point[0]}"
relation_prompt += f"你记得{person_name}最近的点是:{point_str}。"


return relation_prompt

async def _update_list_field(self, person_id: str, field_name: str, new_items: list) -> None:
@@ -181,8 +175,8 @@ class RelationshipManager:
nickname = await person_info_manager.get_value(person_id, "nickname")

alias_str = ", ".join(global_config.bot.alias_names)
personality_block = individuality.get_personality_prompt(x_person=2, level=2)
identity_block = individuality.get_identity_prompt(x_person=2, level=2)
# personality_block = individuality.get_personality_prompt(x_person=2, level=2)
# identity_block = individuality.get_identity_prompt(x_person=2, level=2)

user_messages = bot_engaged_messages

@@ -219,19 +213,13 @@ class RelationshipManager:

# 其他用户映射
if replace_person_name not in name_mapping:
if current_user > 'Z':
current_user = 'A'
if current_user > "Z":
current_user = "A"
user_count += 1
name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}"
current_user = chr(ord(current_user) + 1)



readable_messages = self.build_focus_readable_messages(
messages=user_messages,
target_person_id=person_id
)
readable_messages = self.build_focus_readable_messages(messages=user_messages, target_person_id=person_id)

for original_name, mapped_name in name_mapping.items():
# print(f"original_name: {original_name}, mapped_name: {mapped_name}")
@@ -318,7 +306,9 @@ class RelationshipManager:
elif not isinstance(current_points, list):
current_points = []
current_points.extend(points_list)
await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
await person_info_manager.update_one_field(
person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
)

# 将新记录添加到现有记录中
if isinstance(current_points, list):
@@ -359,7 +349,7 @@ class RelationshipManager:
else:
current_points = points_list

# 如果points超过10条,按权重随机选择多余的条目移动到forgotten_points
# 如果points超过10条,按权重随机选择多余的条目移动到forgotten_points
if len(current_points) > 10:
# 获取现有forgotten_points
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
@@ -421,11 +411,9 @@ class RelationshipManager:
forgotten_points.sort(key=lambda x: x[2])

# 构建points文本
points_text = "\n".join([
f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}"
for point in forgotten_points
])

points_text = "\n".join(
[f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" for point in forgotten_points]
)

impression = await person_info_manager.get_value(person_id, "impression") or ""

@@ -457,21 +445,21 @@ class RelationshipManager:

forgotten_points = []


# 这句代码的作用是:将更新后的 forgotten_points(遗忘的记忆点)列表,序列化为 JSON 字符串后,写回到数据库中的 forgotten_points 字段
await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None))
await person_info_manager.update_one_field(
person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)
)

# 更新数据库
await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
await person_info_manager.update_one_field(
person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
)
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
await person_info_manager.update_one_field(person_id, "last_know", timestamp)


logger.info(f"印象更新完成 for {person_name}")



def build_focus_readable_messages(self, messages: list, target_person_id: str = None) -> str:
"""格式化消息,只保留目标用户和bot消息附近的内容"""
# 找到目标用户和bot的消息索引
@@ -522,10 +510,7 @@ class RelationshipManager:
if i > 0:
result.append("...")
group_text = build_readable_messages(
messages=group,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
truncate=False
messages=group, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=False
)
result.append(group_text)


@@ -40,7 +40,7 @@ DEFAULT_CONFIG = {
"default_guidance_scale": 2.5,
"default_seed": 42,
"cache_enabled": True,
"cache_max_size": 10
"cache_max_size": 10,
}



@@ -91,7 +91,7 @@ class PicAction(PluginAction):
"""清理缓存,保持大小在限制内"""
if len(cls._request_cache) > cls._cache_max_size:
# 简单的FIFO策略,移除最旧的条目
keys_to_remove = list(cls._request_cache.keys())[:-cls._cache_max_size//2]
keys_to_remove = list(cls._request_cache.keys())[: -cls._cache_max_size // 2]
for key in keys_to_remove:
del cls._request_cache[key]

@@ -370,7 +370,7 @@ class PicAction(PluginAction):
def _validate_image_size(self, image_size: str) -> bool:
"""验证图片尺寸格式"""
try:
width, height = map(int, image_size.split('x'))
width, height = map(int, image_size.split("x"))
return 100 <= width <= 10000 and 100 <= height <= 10000
except (ValueError, TypeError):
return False

@@ -5,6 +5,7 @@ import random

logger = get_logger("custom_prefix_command")


@register_command
class DiceCommand(BaseCommand):
"""骰子命令,使用!前缀而不是/前缀"""

@@ -4,6 +4,7 @@ from typing import Tuple, Optional

logger = get_logger("help_command")


@register_command
class HelpCommand(BaseCommand):
"""帮助命令,显示所有可用命令的帮助信息"""
@@ -59,11 +60,7 @@ class HelpCommand(BaseCommand):
examples = getattr(command_cls, "command_examples", [])

# 构建帮助信息
help_info = [
f"【命令】: {command_name}",
f"【描述】: {description}",
f"【用法】: {help_text}"
]
help_info = [f"【命令】: {command_name}", f"【描述】: {description}", f"【用法】: {help_text}"]

# 添加示例
if examples:
@@ -81,8 +78,7 @@ class HelpCommand(BaseCommand):
"""
# 获取所有已启用的命令
enabled_commands = {
name: cls for name, cls in _COMMAND_REGISTRY.items()
if getattr(cls, "enable_command", True)
name: cls for name, cls in _COMMAND_REGISTRY.items() if getattr(cls, "enable_command", True)
}

if not enabled_commands:

@@ -1,10 +1,10 @@
from src.common.logger_manager import get_logger
from src.chat.command.command_handler import BaseCommand, register_command
from typing import Tuple, Optional
import json

logger = get_logger("message_info_command")


@register_command
class MessageInfoCommand(BaseCommand):
"""消息信息查看命令,展示发送命令的原始消息和相关信息"""
@@ -48,46 +48,56 @@ class MessageInfoCommand(BaseCommand):

# 发送者信息
user = message.message_info.user_info
info_lines.extend([
info_lines.extend(
[
"",
"👤 发送者信息:",
f" 用户ID: {user.user_id}",
f" 昵称: {user.user_nickname}",
f" 群名片: {user.user_cardname or '无'}",
])
]
)

# 群聊信息(如果是群聊)
if message.message_info.group_info:
group = message.message_info.group_info
info_lines.extend([
info_lines.extend(
[
"",
"👥 群聊信息:",
f" 群ID: {group.group_id}",
f" 群名: {group.group_name or '未知'}",
])
]
)
else:
info_lines.extend([
info_lines.extend(
[
"",
"💬 消息类型: 私聊消息",
])
]
)

# 消息内容
info_lines.extend([
info_lines.extend(
[
"",
"📝 消息内容:",
f" 原始文本: {message.processed_plain_text}",
f" 是否表情: {'是' if getattr(message, 'is_emoji', False) else '否'}",
])
]
)

# 聊天流信息
if hasattr(message, 'chat_stream') and message.chat_stream:
if hasattr(message, "chat_stream") and message.chat_stream:
chat_stream = message.chat_stream
info_lines.extend([
info_lines.extend(
[
"",
"🔄 聊天流信息:",
f" 流ID: {chat_stream.stream_id}",
f" 是否激活: {'是' if chat_stream.is_active else '否'}",
])
]
)

return "\n".join(info_lines)

@@ -101,7 +111,8 @@ class MessageInfoCommand(BaseCommand):
]

# 消息基础信息
info_lines.extend([
info_lines.extend(
[
"",
"🔍 基础消息信息:",
f" 消息ID: {message.message_info.message_id}",
@@ -109,46 +120,54 @@ class MessageInfoCommand(BaseCommand):
f" 平台: {message.message_info.platform}",
f" 处理后文本: {message.processed_plain_text}",
f" 详细文本: {message.detailed_plain_text[:100]}{'...' if len(message.detailed_plain_text) > 100 else ''}",
])
]
)

# 用户详细信息
user = message.message_info.user_info
info_lines.extend([
info_lines.extend(
[
"",
"👤 发送者详细信息:",
f" 用户ID: {user.user_id}",
f" 昵称: {user.user_nickname}",
f" 群名片: {user.user_cardname or '无'}",
f" 平台: {user.platform}",
])
]
)

# 群聊详细信息
if message.message_info.group_info:
group = message.message_info.group_info
info_lines.extend([
info_lines.extend(
[
"",
"👥 群聊详细信息:",
f" 群ID: {group.group_id}",
f" 群名: {group.group_name or '未知'}",
f" 平台: {group.platform}",
])
]
)
else:
info_lines.append("\n💬 消息类型: 私聊消息")

# 消息段信息
if message.message_segment:
info_lines.extend([
info_lines.extend(
[
"",
"📦 消息段信息:",
f" 类型: {message.message_segment.type}",
f" 数据类型: {type(message.message_segment.data).__name__}",
f" 数据预览: {str(message.message_segment.data)[:200]}{'...' if len(str(message.message_segment.data)) > 200 else ''}",
])
]
)

# 聊天流详细信息
if hasattr(message, 'chat_stream') and message.chat_stream:
if hasattr(message, "chat_stream") and message.chat_stream:
chat_stream = message.chat_stream
info_lines.extend([
info_lines.extend(
[
"",
"🔄 聊天流详细信息:",
f" 流ID: {chat_stream.stream_id}",
@@ -156,25 +175,30 @@ class MessageInfoCommand(BaseCommand):
f" 是否激活: {'是' if chat_stream.is_active else '否'}",
f" 用户信息: {chat_stream.user_info.user_nickname} ({chat_stream.user_info.user_id})",
f" 群信息: {getattr(chat_stream.group_info, 'group_name', '私聊') if chat_stream.group_info else '私聊'}",
])
]
)

# 回复信息
if hasattr(message, 'reply') and message.reply:
info_lines.extend([
if hasattr(message, "reply") and message.reply:
info_lines.extend(
[
"",
"↩️ 回复信息:",
f" 回复消息ID: {message.reply.message_info.message_id}",
f" 回复内容: {message.reply.processed_plain_text[:100]}{'...' if len(message.reply.processed_plain_text) > 100 else ''}",
])
]
)

# 原始消息数据(如果存在)
if hasattr(message, 'raw_message') and message.raw_message:
info_lines.extend([
if hasattr(message, "raw_message") and message.raw_message:
info_lines.extend(
[
"",
"🗂️ 原始消息数据:",
f" 数据类型: {type(message.raw_message).__name__}",
f" 数据大小: {len(str(message.raw_message))} 字符",
])
]
)

return "\n".join(info_lines)

@@ -205,12 +229,14 @@ class SenderInfoCommand(BaseCommand):
]

if group:
info_lines.extend([
info_lines.extend(
[
"",
"👥 当前群聊:",
f"🆔 群ID: {group.group_id}",
f"📝 群名: {group.group_name or '未知'}",
])
]
)
else:
info_lines.append("\n💬 当前在私聊中")

@@ -235,7 +261,7 @@ class ChatStreamInfoCommand(BaseCommand):
async def execute(self) -> Tuple[bool, Optional[str]]:
"""执行聊天流信息查看命令"""
try:
if not hasattr(self.message, 'chat_stream') or not self.message.chat_stream:
if not hasattr(self.message, "chat_stream") or not self.message.chat_stream:
return False, "无法获取聊天流信息"

chat_stream = self.message.chat_stream
@@ -249,31 +275,37 @@ class ChatStreamInfoCommand(BaseCommand):

# 用户信息
if chat_stream.user_info:
info_lines.extend([
info_lines.extend(
[
"",
"👤 关联用户:",
f" ID: {chat_stream.user_info.user_id}",
f" 昵称: {chat_stream.user_info.user_nickname}",
])
]
)

# 群信息
if chat_stream.group_info:
info_lines.extend([
info_lines.extend(
[
"",
"👥 关联群聊:",
f" 群ID: {chat_stream.group_info.group_id}",
f" 群名: {chat_stream.group_info.group_name or '未知'}",
])
]
)
else:
info_lines.append("\n💬 类型: 私聊流")

# 最近消息统计
if hasattr(chat_stream, 'last_messages'):
if hasattr(chat_stream, "last_messages"):
msg_count = len(chat_stream.last_messages)
info_lines.extend([
info_lines.extend(
[
"",
f"📈 消息统计: 记录了 {msg_count} 条最近消息",
])
]
)

return True, "\n".join(info_lines)


@@ -5,6 +5,7 @@ from typing import Tuple, Optional

logger = get_logger("send_msg_command")


@register_command
class SendMessageCommand(BaseCommand, MessageAPI):
"""发送消息命令,可以向指定群聊或私聊发送消息"""
@@ -13,10 +14,7 @@ class SendMessageCommand(BaseCommand, MessageAPI):
command_description = "向指定群聊或私聊发送消息"
command_pattern = r"^/send\s+(?P<target_type>group|user)\s+(?P<target_id>\d+)\s+(?P<content>.+)$"
command_help = "使用方法: /send <group|user> <ID> <消息内容> - 发送消息到指定群聊或用户"
command_examples = [
"/send group 123456789 大家好!",
"/send user 987654321 私聊消息"
]
command_examples = ["/send group 123456789 大家好!", "/send user 987654321 私聊消息"]
enable_command = True

def __init__(self, message):
@@ -76,7 +74,7 @@ class SendMessageCommand(BaseCommand, MessageAPI):
success = await self.send_text_to_group(
text=content,
group_id=group_id,
platform="qq" # 默认使用QQ平台
platform="qq", # 默认使用QQ平台
)

if success:
@@ -104,7 +102,7 @@ class SendMessageCommand(BaseCommand, MessageAPI):
success = await self.send_text_to_user(
text=content,
user_id=user_id,
platform="qq" # 默认使用QQ平台
platform="qq", # 默认使用QQ平台
)

if success:

@@ -5,6 +5,7 @@ from typing import Tuple, Optional

logger = get_logger("send_msg_enhanced")


@register_command
class SendMessageEnhancedCommand(BaseCommand, MessageAPI):
"""增强版发送消息命令,支持多种消息类型和平台"""
@@ -17,7 +18,7 @@ class SendMessageEnhancedCommand(BaseCommand, MessageAPI):
"/sendfull text group 123456789 qq 大家好!这是文本消息",
"/sendfull image user 987654321 https://example.com/image.jpg",
"/sendfull emoji group 123456789 😄",
"/sendfull text user 987654321 qq 私聊消息"
"/sendfull text user 987654321 qq 私聊消息",
]
enable_command = True

@@ -51,22 +52,14 @@ class SendMessageEnhancedCommand(BaseCommand, MessageAPI):
logger.info(f"{self.log_prefix} 执行发送命令: {msg_type} -> {target_type}:{target_id} (平台:{platform})")

# 根据消息类型和目标类型发送消息
is_group = (target_type == "group")
is_group = target_type == "group"
success = await self.send_message_to_target(
message_type=msg_type,
content=content,
platform=platform,
target_id=target_id,
is_group=is_group
message_type=msg_type, content=content, platform=platform, target_id=target_id, is_group=is_group
)

# 构建结果消息
target_desc = f"{'群聊' if is_group else '用户'} {target_id} (平台: {platform})"
msg_type_desc = {
"text": "文本",
"image": "图片",
"emoji": "表情"
}.get(msg_type, msg_type)
msg_type_desc = {"text": "文本", "image": "图片", "emoji": "表情"}.get(msg_type, msg_type)

if success:
return True, f"✅ {msg_type_desc}消息已成功发送到 {target_desc}"
@@ -86,10 +79,7 @@ class SendQuickCommand(BaseCommand, MessageAPI):
command_description = "快速发送文本消息到群聊"
command_pattern = r"^/msg\s+(?P<group_id>\d+)\s+(?P<content>.+)$"
command_help = "使用方法: /msg <群ID> <消息内容> - 快速发送文本到指定群聊"
command_examples = [
"/msg 123456789 大家好!",
"/msg 987654321 这是一条快速消息"
]
command_examples = ["/msg 123456789 大家好!", "/msg 987654321 这是一条快速消息"]
enable_command = True

def __init__(self, message):
@@ -108,11 +98,7 @@ class SendQuickCommand(BaseCommand, MessageAPI):

logger.info(f"{self.log_prefix} 快速发送到群 {group_id}: {content[:50]}...")

success = await self.send_text_to_group(
text=content,
group_id=group_id,
platform="qq"
)
success = await self.send_text_to_group(text=content, group_id=group_id, platform="qq")

if success:
return True, f"✅ 消息已发送到群 {group_id}"
@@ -132,10 +118,7 @@ class SendPrivateCommand(BaseCommand, MessageAPI):
command_description = "发送私聊消息到指定用户"
command_pattern = r"^/pm\s+(?P<user_id>\d+)\s+(?P<content>.+)$"
command_help = "使用方法: /pm <用户ID> <消息内容> - 发送私聊消息"
command_examples = [
"/pm 123456789 你好!",
"/pm 987654321 这是私聊消息"
]
command_examples = ["/pm 123456789 你好!", "/pm 987654321 这是私聊消息"]
enable_command = True

def __init__(self, message):
@@ -154,11 +137,7 @@ class SendPrivateCommand(BaseCommand, MessageAPI):

logger.info(f"{self.log_prefix} 发送私聊到用户 {user_id}: {content[:50]}...")

success = await self.send_text_to_user(
text=content,
user_id=user_id,
platform="qq"
)
success = await self.send_text_to_user(text=content, user_id=user_id, platform="qq")

if success:
return True, f"✅ 私聊消息已发送到用户 {user_id}"

@@ -6,19 +6,22 @@ import time

logger = get_logger("send_msg_with_context")


@register_command
class ContextAwareSendCommand(BaseCommand, MessageAPI):
"""上下文感知的发送消息命令,展示如何利用原始消息信息"""

command_name = "csend"
command_description = "带上下文感知的发送消息命令"
command_pattern = r"^/csend\s+(?P<target_type>group|user|here|reply)\s+(?P<target_id_or_content>.*?)(?:\s+(?P<content>.*))?$"
command_pattern = (
r"^/csend\s+(?P<target_type>group|user|here|reply)\s+(?P<target_id_or_content>.*?)(?:\s+(?P<content>.*))?$"
)
command_help = "使用方法: /csend <target_type> <参数> [内容]"
command_examples = [
"/csend group 123456789 大家好!",
"/csend user 987654321 私聊消息",
"/csend here 在当前聊天发送",
"/csend reply 回复当前群/私聊"
"/csend reply 回复当前群/私聊",
]
enable_command = True

@@ -83,19 +86,13 @@ class ContextAwareSendCommand(BaseCommand, MessageAPI):
# 群聊
formatted_content = f"[管理员转发 {timestamp}] {sender.user_nickname}({sender.user_id}): {content}"
success = await self.send_text_to_group(
text=formatted_content,
group_id=current_group.group_id,
platform="qq"
text=formatted_content, group_id=current_group.group_id, platform="qq"
)
target_desc = f"当前群聊 {current_group.group_name}({current_group.group_id})"
else:
# 私聊
formatted_content = f"[管理员消息 {timestamp}]: {content}"
success = await self.send_text_to_user(
text=formatted_content,
user_id=sender.user_id,
platform="qq"
)
success = await self.send_text_to_user(text=formatted_content, user_id=sender.user_id, platform="qq")
target_desc = "当前私聊"

if success:
@@ -124,26 +121,17 @@ class ContextAwareSendCommand(BaseCommand, MessageAPI):
context_info.append("💬 当前环境: 私聊")
target_desc = "私聊"

context_info.extend([
f"📝 回复内容: {content}",
"─" * 30
])
context_info.extend([f"📝 回复内容: {content}", "─" * 30])

formatted_content = "\n".join(context_info)

# 发送消息
if current_group:
success = await self.send_text_to_group(
text=formatted_content,
group_id=current_group.group_id,
platform="qq"
text=formatted_content, group_id=current_group.group_id, platform="qq"
)
else:
success = await self.send_text_to_user(
text=formatted_content,
user_id=sender.user_id,
platform="qq"
)
success = await self.send_text_to_user(text=formatted_content, user_id=sender.user_id, platform="qq")

if success:
return True, f"✅ 带上下文的回复已发送到{target_desc}"
@@ -159,18 +147,10 @@ class ContextAwareSendCommand(BaseCommand, MessageAPI):
formatted_content = f"{tracking_info}\n{content}"

if target_type == "group":
success = await self.send_text_to_group(
text=formatted_content,
group_id=target_id,
platform="qq"
)
success = await self.send_text_to_group(text=formatted_content, group_id=target_id, platform="qq")
target_desc = f"群聊 {target_id}"
else: # user
success = await self.send_text_to_user(
text=formatted_content,
user_id=target_id,
platform="qq"
)
success = await self.send_text_to_user(text=formatted_content, user_id=target_id, platform="qq")
target_desc = f"用户 {target_id}"

if success:
@@ -214,37 +194,45 @@ class MessageContextCommand(BaseCommand):
]

if group:
context_lines.extend([
context_lines.extend(
[
"",
"👥 群聊环境:",
f" 群ID: {group.group_id}",
f" 群名: {group.group_name or '未知'}",
f" 平台: {group.platform}",
])
]
)
else:
context_lines.extend([
context_lines.extend(
[
"",
"💬 私聊环境",
])
]
)

# 添加聊天流信息
if hasattr(message, 'chat_stream') and message.chat_stream:
if hasattr(message, "chat_stream") and message.chat_stream:
chat_stream = message.chat_stream
context_lines.extend([
context_lines.extend(
[
"",
"🔄 聊天流:",
f" 流ID: {chat_stream.stream_id}",
f" 激活状态: {'激活' if chat_stream.is_active else '非激活'}",
])
]
)

# 添加消息内容信息
context_lines.extend([
context_lines.extend(
[
"",
"📝 消息内容:",
f" 原始内容: {message.processed_plain_text}",
f" 消息长度: {len(message.processed_plain_text)} 字符",
f" 消息ID: {message.message_info.message_id}",
])
]
)

return True, "\n".join(context_lines)


@@ -1,2 +1,3 @@
"""测试插件动作模块"""

from . import mute_action # noqa

@@ -22,7 +22,7 @@ class MuteAction(PluginAction):
"当有人刷屏时使用",
"当有人发了擦边,或者色情内容时使用",
"当有人要求禁言自己时使用",
"如果某人已经被禁言了,就不要再次禁言了,除非你想追加时间!!"
"如果某人已经被禁言了,就不要再次禁言了,除非你想追加时间!!",
]
enable_plugin = False # 启用插件
associated_types = ["command", "text"]
@@ -32,7 +32,6 @@ class MuteAction(PluginAction):
focus_activation_type = ActionActivationType.LLM_JUDGE # Focus模式使用LLM判定,确保谨慎
normal_activation_type = ActionActivationType.KEYWORD # Normal模式使用关键词激活,快速响应


# 关键词设置(用于Normal模式)
activation_keywords = ["禁言", "mute", "ban", "silence"]
keyword_case_sensitive = False
@@ -130,11 +129,10 @@ log_mute_history = true

def _get_template_message(self, target: str, duration_str: str, reason: str) -> str:
"""获取模板化的禁言消息"""
templates = self.config.get("templates", [
"好的,禁言 {target} {duration},理由:{reason}"
])
templates = self.config.get("templates", ["好的,禁言 {target} {duration},理由:{reason}"])

import random

template = random.choice(templates)
return template.format(target=target, duration=duration_str, reason=reason)

@@ -170,7 +168,9 @@ log_mute_history = true
error_msg = "禁言时长必须大于0"
logger.error(f"{self.log_prefix} {error_msg}")
error_templates = self.config.get("error_messages", ["禁言时长必须是正数哦~"])
await self.send_message_by_expressor(error_templates[2] if len(error_templates) > 2 else "禁言时长必须是正数哦~")
await self.send_message_by_expressor(
error_templates[2] if len(error_templates) > 2 else "禁言时长必须是正数哦~"
)
return False, error_msg

# 限制禁言时长范围
@@ -181,11 +181,13 @@ log_mute_history = true
duration_int = max_duration
logger.info(f"{self.log_prefix} 禁言时长过长,调整为{max_duration}秒")

except (ValueError, TypeError) as e:
except (ValueError, TypeError):
error_msg = f"禁言时长格式无效: {duration}"
logger.error(f"{self.log_prefix} {error_msg}")
error_templates = self.config.get("error_messages", ["禁言时长必须是数字哦~"])
await self.send_message_by_expressor(error_templates[3] if len(error_templates) > 3 else "禁言时长必须是数字哦~")
await self.send_message_by_expressor(
error_templates[3] if len(error_templates) > 3 else "禁言时长必须是数字哦~"
)
return False, error_msg

# 获取用户ID

@@ -1,7 +1,7 @@
import importlib
import pkgutil
import os
from typing import Dict, List, Tuple
from typing import Dict, Tuple
from src.common.logger_manager import get_logger

logger = get_logger("plugin_loader")
@@ -25,7 +25,7 @@ class PluginLoader:
# 定义插件搜索路径(优先级从高到低)
plugin_paths = [
("plugins", "plugins"), # 项目根目录的plugins文件夹
("src.plugins", os.path.join("src", "plugins")) # src下的plugins文件夹
("src.plugins", os.path.join("src", "plugins")), # src下的plugins文件夹
]

total_plugins_found = 0
@@ -38,6 +38,7 @@ class PluginLoader:
except Exception as e:
logger.error(f"从路径 {plugin_dir_path} 加载插件失败: {e}")
import traceback

logger.error(traceback.format_exc())

if total_plugins_found == 0:
@@ -75,9 +76,7 @@ class PluginLoader:

# 遍历插件包中的所有子包
plugins_found = 0
for _, plugin_name, is_pkg in pkgutil.iter_modules(
plugins_package.__path__, plugins_package.__name__ + "."
):
for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + "."):
if not is_pkg:
continue

@@ -187,7 +186,7 @@ class PluginLoader:
# 遍历actions目录中的所有Python文件
actions_dir = os.path.dirname(actions_module.__file__)
for file in os.listdir(actions_dir):
if file.endswith('.py') and file != '__init__.py':
if file.endswith(".py") and file != "__init__.py":
action_module_name = f"{plugin_actions_path}.{file[:-3]}"
try:
importlib.import_module(action_module_name)
@@ -213,7 +212,7 @@ class PluginLoader:
# 遍历commands目录中的所有Python文件
commands_dir = os.path.dirname(commands_module.__file__)
for file in os.listdir(commands_dir):
if file.endswith('.py') and file != '__init__.py':
if file.endswith(".py") and file != "__init__.py":
command_module_name = f"{plugin_commands_path}.{file[:-3]}"
try:
importlib.import_module(command_module_name)
@@ -239,9 +238,9 @@ class PluginLoader:
# 遍历插件根目录中的所有Python文件
plugin_dir = os.path.dirname(plugin_module.__file__)
for file in os.listdir(plugin_dir):
if file.endswith('.py') and file != '__init__.py':
if file.endswith(".py") and file != "__init__.py":
# 跳过非动作文件(根据命名约定)
if not (file.endswith('_action.py') or file.endswith('_actions.py') or 'action' in file):
if not (file.endswith("_action.py") or file.endswith("_actions.py") or "action" in file):
continue

action_module_name = f"{plugin_name}.{file[:-3]}"
@@ -269,9 +268,9 @@ class PluginLoader:
# 遍历插件根目录中的所有Python文件
plugin_dir = os.path.dirname(plugin_module.__file__)
for file in os.listdir(plugin_dir):
if file.endswith('.py') and file != '__init__.py':
if file.endswith(".py") and file != "__init__.py":
# 跳过非命令文件(根据命名约定)
if not (file.endswith('_command.py') or file.endswith('_commands.py') or 'command' in file):
if not (file.endswith("_command.py") or file.endswith("_commands.py") or "command" in file):
continue

command_module_name = f"{plugin_name}.{file[:-3]}"
@@ -294,9 +293,11 @@ class PluginLoader:
if self.plugin_stats:
logger.info("插件加载详情:")
for plugin_name, stats in self.plugin_stats.items():
plugin_display_name = plugin_name.split('.')[-1] # 只显示插件名称,不显示完整路径
plugin_display_name = plugin_name.split(".")[-1] # 只显示插件名称,不显示完整路径
source_path = self.plugin_sources.get(plugin_name, "未知路径")
logger.info(f" {plugin_display_name} (来源: {source_path}): {stats['actions']} 动作, {stats['commands']} 命令")
logger.info(
f" {plugin_display_name} (来源: {source_path}): {stats['actions']} 动作, {stats['commands']} 命令"
)


# 创建全局插件加载器实例