Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -30,8 +30,8 @@ from .utils.hash import get_sha256
|
||||
install(extra_lines=3)
|
||||
|
||||
# 多线程embedding配置常量
|
||||
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
|
||||
DEFAULT_MAX_WORKERS = 3 # 默认最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 5 # 默认每个线程处理的数据块大小
|
||||
MIN_CHUNK_SIZE = 1 # 最小分块大小
|
||||
MAX_CHUNK_SIZE = 50 # 最大分块大小
|
||||
MIN_WORKERS = 1 # 最小线程数
|
||||
@@ -124,60 +124,124 @@ class EmbeddingStore:
|
||||
self.faiss_index = None
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
def _get_embedding(s: str) -> list[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
# 使用新的事件循环运行异步方法
|
||||
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
return embedding
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
return []
|
||||
finally:
|
||||
# 确保事件循环被正确关闭
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def _get_embeddings_batch_threaded(
|
||||
strs: list[str],
|
||||
main_loop: asyncio.AbstractEventLoop,
|
||||
chunk_size: int = 10,
|
||||
max_workers: int = 10,
|
||||
progress_callback=None,
|
||||
strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> list[tuple[str, list[float]]]:
|
||||
"""使用多线程批量获取嵌入向量, 并通过 run_coroutine_threadsafe 在主事件循环中运行异步任务"""
|
||||
"""使用多线程批量获取嵌入向量
|
||||
|
||||
Args:
|
||||
strs: 要获取嵌入的字符串列表
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
max_workers: 最大线程数
|
||||
progress_callback: 进度回调函数,接收一个参数表示完成的数量
|
||||
|
||||
Returns:
|
||||
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
|
||||
"""
|
||||
if not strs:
|
||||
return []
|
||||
|
||||
# 导入必要的模块
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
# 在主线程(即主事件循环所在的线程)中创建LLMRequest实例
|
||||
# 这样可以确保它绑定到正确的事件循环
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
# 分块
|
||||
chunks = [(i, strs[i : i + chunk_size]) for i in range(0, len(strs), chunk_size)]
|
||||
chunks = []
|
||||
for i in range(0, len(strs), chunk_size):
|
||||
chunk = strs[i : i + chunk_size]
|
||||
chunks.append((i, chunk)) # 保存起始索引以维持顺序
|
||||
|
||||
# 结果存储,使用字典按索引存储以保证顺序
|
||||
results = {}
|
||||
|
||||
def process_chunk(chunk_data):
|
||||
"""在工作线程中运行的函数"""
|
||||
"""处理单个数据块的函数"""
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results = []
|
||||
|
||||
for i, s in enumerate(chunk_strs):
|
||||
embedding = []
|
||||
try:
|
||||
# 将异步的 get_embedding 调用提交到主事件循环
|
||||
future = asyncio.run_coroutine_threadsafe(llm.get_embedding(s), main_loop)
|
||||
# 同步等待结果,延长超时时间
|
||||
embedding_result, _ = future.result(timeout=60)
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
if embedding_result and len(embedding_result) > 0:
|
||||
embedding = embedding_result
|
||||
else:
|
||||
logger.error(f"获取嵌入失败(返回为空): {s}")
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"在线程中获取嵌入时发生异常: {s}, 错误: {type(e).__name__}: {e}")
|
||||
finally:
|
||||
chunk_results.append((start_idx + i, s, embedding))
|
||||
for i, s in enumerate(chunk_strs):
|
||||
try:
|
||||
# 在线程中创建独立的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
embedding = loop.run_until_complete(llm.get_embedding(s))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
# 每完成一个嵌入立即更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
# 如果创建LLM实例失败,返回空结果
|
||||
for i, s in enumerate(chunk_strs):
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
return chunk_results
|
||||
|
||||
# 使用线程池处理
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交所有任务
|
||||
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
|
||||
|
||||
# 收集结果(进度已在process_chunk中实时更新)
|
||||
for future in as_completed(future_to_chunk):
|
||||
try:
|
||||
chunk_results = future.result()
|
||||
@@ -185,14 +249,22 @@ class EmbeddingStore:
|
||||
results[idx] = (s, embedding)
|
||||
except Exception as e:
|
||||
chunk = future_to_chunk[future]
|
||||
logger.error(f"处理数据块时发生严重异常: {chunk}, 错误: {e}")
|
||||
logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
|
||||
# 为失败的块添加空结果
|
||||
start_idx, chunk_strs = chunk
|
||||
for i, s_item in enumerate(chunk_strs):
|
||||
if (start_idx + i) not in results:
|
||||
results[start_idx + i] = (s_item, [])
|
||||
for i, s in enumerate(chunk_strs):
|
||||
results[start_idx + i] = (s, [])
|
||||
|
||||
# 按原始顺序返回结果
|
||||
return [results.get(i, (strs[i], [])) for i in range(len(strs))]
|
||||
ordered_results = []
|
||||
for i in range(len(strs)):
|
||||
if i in results:
|
||||
ordered_results.append(results[i])
|
||||
else:
|
||||
# 防止遗漏
|
||||
ordered_results.append((strs[i], []))
|
||||
|
||||
return ordered_results
|
||||
|
||||
@staticmethod
|
||||
def get_test_file_path():
|
||||
@@ -202,17 +274,9 @@ class EmbeddingStore:
|
||||
"""保存测试字符串的嵌入到本地(使用多线程优化)"""
|
||||
logger.info("开始保存测试字符串的嵌入向量...")
|
||||
|
||||
# 获取当前正在运行的事件循环
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
|
||||
return
|
||||
|
||||
# 使用多线程批量获取测试字符串的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
main_loop,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
@@ -224,6 +288,8 @@ class EmbeddingStore:
|
||||
test_vectors[str(idx)] = embedding
|
||||
else:
|
||||
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||
# 使用原始单线程方法作为后备
|
||||
test_vectors[str(idx)] = self._get_embedding(s)
|
||||
|
||||
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
@@ -255,17 +321,9 @@ class EmbeddingStore:
|
||||
|
||||
logger.info("开始检验嵌入模型一致性...")
|
||||
|
||||
# 获取当前正在运行的事件循环
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
|
||||
return False
|
||||
|
||||
# 使用多线程批量获取当前模型的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
main_loop,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
@@ -325,20 +383,11 @@ class EmbeddingStore:
|
||||
progress.update(task, advance=already_processed)
|
||||
|
||||
if new_strs:
|
||||
try:
|
||||
main_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
logger.error("无法获取正在运行的事件循环。请确保在异步上下文中调用此方法。")
|
||||
# 更新进度条以反映未处理的项目
|
||||
progress.update(task, advance=len(new_strs))
|
||||
return
|
||||
|
||||
# 使用实例配置的参数,智能调整分块和线程数
|
||||
optimal_chunk_size = max(
|
||||
MIN_CHUNK_SIZE,
|
||||
min(
|
||||
self.chunk_size,
|
||||
len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size,
|
||||
self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
|
||||
),
|
||||
)
|
||||
optimal_max_workers = min(
|
||||
@@ -355,13 +404,12 @@ class EmbeddingStore:
|
||||
# 批量获取嵌入,并实时更新进度
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
new_strs,
|
||||
main_loop,
|
||||
chunk_size=optimal_chunk_size,
|
||||
max_workers=optimal_max_workers,
|
||||
progress_callback=update_progress,
|
||||
)
|
||||
|
||||
# 存入结果
|
||||
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
|
||||
for s, embedding in embedding_results:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
if embedding: # 只有成功获取到嵌入才存入
|
||||
|
||||
@@ -88,6 +88,8 @@ class QAManager:
|
||||
else:
|
||||
logger.info("未找到相关关系,将使用文段检索结果")
|
||||
result = paragraph_search_res
|
||||
if result and result[0][1] < global_config.lpmm_knowledge.qa_paragraph_threshold:
|
||||
result = []
|
||||
ppr_node_weights = None
|
||||
|
||||
# 过滤阈值
|
||||
|
||||
@@ -45,8 +45,8 @@ class MessageManager:
|
||||
self.chatter_manager = ChatterManager(self.action_manager)
|
||||
|
||||
# 消息缓存系统 - 直接集成到消息管理器
|
||||
self.message_caches: Dict[str, deque] = defaultdict(deque) # 每个流的消息缓存
|
||||
self.stream_processing_status: Dict[str, bool] = defaultdict(bool) # 流的处理状态
|
||||
self.message_caches: dict[str, deque] = defaultdict(deque) # 每个流的消息缓存
|
||||
self.stream_processing_status: dict[str, bool] = defaultdict(bool) # 流的处理状态
|
||||
self.cache_stats = {
|
||||
"total_cached_messages": 0,
|
||||
"total_flushed_messages": 0,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from datetime import datetime, time, timedelta
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
from .state_manager import SleepState, sleep_state_manager
|
||||
|
||||
logger = get_logger("sleep_logic")
|
||||
@@ -77,7 +77,7 @@ class SleepLogic:
|
||||
logger.info(f"当前时间 {now.strftime('%H:%M')} 已到达或超过预定起床时间 {wake_up_time.strftime('%H:%M')}。")
|
||||
sleep_state_manager.set_state(SleepState.AWAKE)
|
||||
|
||||
def _should_be_sleeping(self, now: datetime) -> Tuple[bool, Optional[datetime]]:
|
||||
def _should_be_sleeping(self, now: datetime) -> tuple[bool, datetime | None]:
|
||||
"""
|
||||
判断在当前时刻,是否应该处于睡眠时间。
|
||||
|
||||
@@ -108,10 +108,10 @@ class SleepLogic:
|
||||
return True, wake_up_time
|
||||
# 如果当前时间大于入睡时间,说明已经进入睡眠窗口
|
||||
return True, wake_up_time
|
||||
|
||||
|
||||
return False, None
|
||||
|
||||
def _get_fixed_sleep_times(self, now: datetime) -> Tuple[Optional[datetime], Optional[datetime]]:
|
||||
def _get_fixed_sleep_times(self, now: datetime) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
当使用“固定时间”模式时,从此方法计算睡眠和起床时间。
|
||||
会加入配置中的随机偏移量,让作息更自然。
|
||||
@@ -129,7 +129,7 @@ class SleepLogic:
|
||||
wake_up_t = datetime.strptime(sleep_config.fixed_wake_up_time, "%H:%M").time()
|
||||
|
||||
sleep_time = datetime.combine(now.date(), sleep_t) + timedelta(minutes=sleep_offset)
|
||||
|
||||
|
||||
# 如果起床时间比睡觉时间早,说明是第二天
|
||||
wake_up_day = now.date() + timedelta(days=1) if wake_up_t < sleep_t else now.date()
|
||||
wake_up_time = datetime.combine(wake_up_day, wake_up_t) + timedelta(minutes=wake_up_offset)
|
||||
@@ -139,7 +139,7 @@ class SleepLogic:
|
||||
logger.error(f"解析固定睡眠时间失败: {e}")
|
||||
return None, None
|
||||
|
||||
def _get_sleep_times_from_schedule(self, now: datetime) -> Tuple[Optional[datetime], Optional[datetime]]:
|
||||
def _get_sleep_times_from_schedule(self, now: datetime) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
当使用“日程表”模式时,从此方法获取睡眠时间。
|
||||
实现了核心逻辑:
|
||||
@@ -164,8 +164,8 @@ class SleepLogic:
|
||||
wake_up_time = None
|
||||
|
||||
return sleep_time, wake_up_time
|
||||
|
||||
def _get_wakeup_times_from_schedule(self, now: datetime) -> Tuple[Optional[datetime], Optional[datetime]]:
|
||||
|
||||
def _get_wakeup_times_from_schedule(self, now: datetime) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
当使用“日程表”模式时,从此方法获取睡眠时间。
|
||||
实现了核心逻辑:
|
||||
@@ -192,4 +192,4 @@ class SleepLogic:
|
||||
|
||||
|
||||
# 全局单例
|
||||
sleep_logic = SleepLogic()
|
||||
sleep_logic = SleepLogic()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
@@ -43,7 +43,7 @@ class SleepStateManager:
|
||||
"""
|
||||
初始化状态管理器,定义状态数据结构并从本地加载历史状态。
|
||||
"""
|
||||
self.state: Dict[str, Any] = {}
|
||||
self.state: dict[str, Any] = {}
|
||||
self._default_state()
|
||||
self.load_state()
|
||||
|
||||
@@ -115,9 +115,9 @@ class SleepStateManager:
|
||||
def set_state(
|
||||
self,
|
||||
new_state: SleepState,
|
||||
duration_seconds: Optional[float] = None,
|
||||
sleep_start: Optional[datetime] = None,
|
||||
wake_up: Optional[datetime] = None,
|
||||
duration_seconds: float | None = None,
|
||||
sleep_start: datetime | None = None,
|
||||
wake_up: datetime | None = None,
|
||||
):
|
||||
"""
|
||||
核心函数:切换到新的睡眠状态,并更新相关的状态数据。
|
||||
@@ -132,7 +132,7 @@ class SleepStateManager:
|
||||
if new_state == SleepState.AWAKE:
|
||||
self._default_state() # 醒来时重置所有状态
|
||||
self.state["state"] = SleepState.AWAKE # 确保状态正确
|
||||
|
||||
|
||||
elif new_state == SleepState.SLEEPING:
|
||||
self.state["sleep_start_time"] = (sleep_start or datetime.now()).isoformat()
|
||||
self.state["wake_up_time"] = wake_up.isoformat() if wake_up else None
|
||||
@@ -153,7 +153,7 @@ class SleepStateManager:
|
||||
self.state["last_checked"] = datetime.now().isoformat()
|
||||
self.save_state()
|
||||
|
||||
def get_wake_up_time(self) -> Optional[datetime]:
|
||||
def get_wake_up_time(self) -> datetime | None:
|
||||
"""获取预定的起床时间,如果已设置的话。"""
|
||||
wake_up_str = self.state.get("wake_up_time")
|
||||
if wake_up_str:
|
||||
@@ -163,7 +163,7 @@ class SleepStateManager:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_sleep_start_time(self) -> Optional[datetime]:
|
||||
def get_sleep_start_time(self) -> datetime | None:
|
||||
"""获取本次睡眠的开始时间,如果已设置的话。"""
|
||||
sleep_start_str = self.state.get("sleep_start_time")
|
||||
if sleep_start_str:
|
||||
@@ -187,4 +187,4 @@ class SleepStateManager:
|
||||
|
||||
|
||||
# 全局单例
|
||||
sleep_state_manager = SleepStateManager()
|
||||
sleep_state_manager = SleepStateManager()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
|
||||
from .sleep_logic import sleep_logic
|
||||
|
||||
logger = get_logger("sleep_tasks")
|
||||
|
||||
@@ -402,19 +402,31 @@ class ChatBot:
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
platform = message_data["message_info"].get("platform")
|
||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||
if not isinstance(message_data, dict):
|
||||
logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过")
|
||||
return
|
||||
message_info = message_data.get("message_info")
|
||||
if not isinstance(message_info, dict):
|
||||
logger.debug(
|
||||
"收到缺少 message_info 的消息,已跳过。可用字段: %s",
|
||||
", ".join(message_data.keys()),
|
||||
)
|
||||
return
|
||||
|
||||
platform = message_info.get("platform")
|
||||
|
||||
if platform == "amaidesu_default":
|
||||
await self.do_s4u(message_data)
|
||||
return
|
||||
|
||||
if message_data["message_info"].get("group_info") is not None:
|
||||
message_data["message_info"]["group_info"]["group_id"] = str(
|
||||
message_data["message_info"]["group_info"]["group_id"]
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str(
|
||||
message_info["group_info"]["group_id"]
|
||||
)
|
||||
if message_data["message_info"].get("user_info") is not None:
|
||||
message_data["message_info"]["user_info"]["user_id"] = str(
|
||||
message_data["message_info"]["user_info"]["user_id"]
|
||||
if message_info.get("user_info") is not None:
|
||||
message_info["user_info"]["user_id"] = str(
|
||||
message_info["user_info"]["user_id"]
|
||||
)
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
|
||||
@@ -11,7 +11,7 @@ from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -207,18 +207,18 @@ class ActionModifier:
|
||||
List[Tuple[str, str]]: 需要停用的 (action_name, reason) 元组列表
|
||||
"""
|
||||
deactivated_actions = []
|
||||
|
||||
|
||||
# 获取 Action 类注册表
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
actions_to_check = list(actions_with_info.items())
|
||||
random.shuffle(actions_to_check)
|
||||
|
||||
|
||||
# 创建并行任务列表
|
||||
activation_tasks = []
|
||||
task_action_names = []
|
||||
|
||||
|
||||
for action_name, action_info in actions_to_check:
|
||||
# 获取 Action 类
|
||||
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
||||
@@ -226,7 +226,7 @@ class ActionModifier:
|
||||
logger.warning(f"{self.log_prefix}未找到 Action 类: {action_name},默认不激活")
|
||||
deactivated_actions.append((action_name, "未找到 Action 类"))
|
||||
continue
|
||||
|
||||
|
||||
# 创建一个临时实例来调用 go_activate 方法
|
||||
# 注意:这里只是为了调用 go_activate,不需要完整的初始化
|
||||
try:
|
||||
@@ -237,24 +237,24 @@ class ActionModifier:
|
||||
action_instance.log_prefix = self.log_prefix
|
||||
# 设置聊天内容,用于激活判断
|
||||
action_instance._activation_chat_content = chat_content
|
||||
|
||||
|
||||
# 调用 go_activate 方法(不再需要传入 chat_content)
|
||||
task = action_instance.go_activate(
|
||||
llm_judge_model=self.llm_judge,
|
||||
)
|
||||
activation_tasks.append(task)
|
||||
task_action_names.append(action_name)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}创建 Action 实例 {action_name} 失败: {e}")
|
||||
deactivated_actions.append((action_name, f"创建实例失败: {e}"))
|
||||
|
||||
|
||||
# 并行执行所有激活判断
|
||||
if activation_tasks:
|
||||
logger.debug(f"{self.log_prefix}并行执行激活判断,任务数: {len(activation_tasks)}")
|
||||
try:
|
||||
task_results = await asyncio.gather(*activation_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 处理结果
|
||||
for action_name, result in zip(task_action_names, task_results, strict=False):
|
||||
if isinstance(result, Exception):
|
||||
@@ -267,7 +267,7 @@ class ActionModifier:
|
||||
else:
|
||||
# go_activate 返回 True,激活
|
||||
logger.debug(f"{self.log_prefix}激活动作: {action_name}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}并行激活判断失败: {e}")
|
||||
# 如果并行执行失败,为所有任务默认不激活
|
||||
|
||||
@@ -23,7 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.logger import get_logger
|
||||
@@ -1312,7 +1313,7 @@ class DefaultReplyer:
|
||||
}
|
||||
|
||||
# 设置超时
|
||||
timeout = 15.0 # 秒
|
||||
timeout = 45.0 # 秒
|
||||
|
||||
async def get_task_result(task_name, task):
|
||||
try:
|
||||
|
||||
@@ -8,13 +8,14 @@ import contextvars
|
||||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.prompt_component_manager import prompt_component_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
@@ -23,81 +24,6 @@ install(extra_lines=3)
|
||||
logger = get_logger("unified_prompt")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
notice_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
|
||||
|
||||
class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
@@ -132,7 +58,7 @@ class PromptContext:
|
||||
context_id = None
|
||||
|
||||
previous_context = self._current_context
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||
else:
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
@@ -185,16 +111,42 @@ class PromptManager:
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
"""异步获取提示模板"""
|
||||
async def get_prompt_async(self, name: str, parameters: PromptParameters | None = None) -> "Prompt":
|
||||
"""
|
||||
异步获取提示模板,并动态注入插件内容
|
||||
"""
|
||||
original_prompt = None
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
|
||||
if name not in self._prompts:
|
||||
original_prompt = context_prompt
|
||||
elif name in self._prompts:
|
||||
original_prompt = self._prompts[name]
|
||||
else:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
# 动态注入插件内容
|
||||
if original_prompt.name:
|
||||
# 确保我们有有效的parameters实例
|
||||
params_for_injection = parameters or original_prompt.parameters
|
||||
|
||||
components_prefix = await prompt_component_manager.execute_components_for(
|
||||
injection_point=original_prompt.name, params=params_for_injection
|
||||
)
|
||||
logger.info(components_prefix)
|
||||
if components_prefix:
|
||||
logger.info(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||
# 创建一个新的临时Prompt实例,不进行注册
|
||||
new_template = f"{components_prefix}\n\n{original_prompt.template}"
|
||||
temp_prompt = Prompt(
|
||||
template=new_template,
|
||||
name=original_prompt.name,
|
||||
parameters=original_prompt.parameters,
|
||||
should_register=False, # 确保不重新注册
|
||||
)
|
||||
return temp_prompt
|
||||
|
||||
return original_prompt
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
@@ -216,7 +168,9 @@ class PromptManager:
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
"""格式化提示模板"""
|
||||
prompt = await self.get_prompt_async(name)
|
||||
# 提取parameters用于注入
|
||||
parameters = kwargs.get("parameters")
|
||||
prompt = await self.get_prompt_async(name, parameters=parameters)
|
||||
result = prompt.format(**kwargs)
|
||||
return result
|
||||
|
||||
@@ -304,11 +258,14 @@ class Prompt:
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构建上下文数据
|
||||
# 1. 构建核心上下文数据
|
||||
context_data = await self._build_context_data()
|
||||
|
||||
# 格式化模板
|
||||
result = await self._format_with_context(context_data)
|
||||
# 2. 格式化主模板
|
||||
main_formatted_prompt = await self._format_with_context(context_data)
|
||||
|
||||
# 3. 拼接组件内容和主模板内容 (逻辑已前置到 get_prompt_async)
|
||||
result = main_formatted_prompt
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||
@@ -470,9 +427,13 @@ class Prompt:
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
|
||||
target_user_id = ""
|
||||
if self.parameters.target_user_info:
|
||||
target_user_id = self.parameters.target_user_info.get("user_id") or ""
|
||||
|
||||
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
|
||||
self.parameters.message_list_before_now_long,
|
||||
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
||||
target_user_id,
|
||||
self.parameters.sender,
|
||||
self.parameters.chat_id,
|
||||
)
|
||||
@@ -498,11 +459,14 @@ class Prompt:
|
||||
|
||||
# 创建临时生成器实例来使用其方法
|
||||
temp_generator = await get_replyer(None, chat_id, request_type="prompt_building")
|
||||
return await temp_generator.build_s4u_chat_history_prompts(
|
||||
message_list_before_now, target_user_id, sender, chat_id
|
||||
)
|
||||
if temp_generator:
|
||||
return await temp_generator.build_s4u_chat_history_prompts(
|
||||
message_list_before_now, target_user_id, sender, chat_id
|
||||
)
|
||||
return "", ""
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
return "", ""
|
||||
|
||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
@@ -589,10 +553,10 @@ class Prompt:
|
||||
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
|
||||
|
||||
# 处理可能的异常结果
|
||||
if isinstance(running_memories, Exception):
|
||||
if isinstance(running_memories, BaseException):
|
||||
logger.warning(f"长期记忆查询失败: {running_memories}")
|
||||
running_memories = []
|
||||
if isinstance(instant_memory, Exception):
|
||||
if isinstance(instant_memory, BaseException):
|
||||
logger.warning(f"即时记忆查询失败: {instant_memory}")
|
||||
instant_memory = None
|
||||
|
||||
@@ -763,20 +727,15 @@ class Prompt:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
try:
|
||||
from src.chat.knowledge.knowledge_lib import QAManager
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
|
||||
# 获取问题文本(当前消息)
|
||||
question = self.parameters.target or ""
|
||||
if not question:
|
||||
if not question or not qa_manager:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
# 创建QA管理器
|
||||
qa_manager = QAManager()
|
||||
|
||||
# 搜索相关知识
|
||||
knowledge_results = await qa_manager.get_knowledge(
|
||||
question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
|
||||
)
|
||||
knowledge_results = await qa_manager.get_knowledge(question=question)
|
||||
|
||||
# 构建知识块
|
||||
if knowledge_results and knowledge_results.get("knowledge_items"):
|
||||
@@ -786,12 +745,17 @@ class Prompt:
|
||||
content = item.get("content", "")
|
||||
source = item.get("source", "")
|
||||
relevance = item.get("relevance", 0.0)
|
||||
|
||||
if content:
|
||||
try:
|
||||
relevance_float = float(relevance)
|
||||
relevance_str = f"{relevance_float:.2f}"
|
||||
except (ValueError, TypeError):
|
||||
relevance_str = str(relevance)
|
||||
|
||||
if source:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})")
|
||||
knowledge_parts.append(f"- [{relevance_str}] {content} (来源: {source})")
|
||||
else:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content}")
|
||||
knowledge_parts.append(f"- [{relevance_str}] {content}")
|
||||
|
||||
if knowledge_results.get("summary"):
|
||||
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
|
||||
@@ -1108,8 +1072,24 @@ def create_prompt(
|
||||
async def create_prompt_async(
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
"""异步创建Prompt实例,并动态注入插件内容"""
|
||||
# 确保有可用的parameters实例
|
||||
final_params = parameters or PromptParameters(**kwargs)
|
||||
|
||||
# 动态注入插件内容
|
||||
if name:
|
||||
components_prefix = await prompt_component_manager.execute_components_for(
|
||||
injection_point=name, params=final_params
|
||||
)
|
||||
if components_prefix:
|
||||
logger.debug(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||
template = f"{components_prefix}\n\n{template}"
|
||||
|
||||
# 使用可能已修改的模板创建实例
|
||||
prompt = create_prompt(template, name, final_params)
|
||||
|
||||
# 如果在特定上下文中,则异步注册
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
109
src/chat/utils/prompt_component_manager.py
Normal file
109
src/chat/utils/prompt_component_manager.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import asyncio
|
||||
from typing import Type
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
logger = get_logger("prompt_component_manager")
|
||||
|
||||
|
||||
class PromptComponentManager:
    """Look up and run the `BasePrompt` components registered for a prompt.

    The manager:
    1. queries `component_registry` for the enabled PROMPT components,
    2. filters them by injection point (the name of the target prompt),
    3. instantiates and executes the matches and joins their output.
    """

    def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
        """Return the registered component classes targeting an injection point.

        Args:
            injection_point: Name of the target prompt.

        Returns:
            list[Type[BasePrompt]]: Component classes whose `injection_point`
            (a single name or a list of names) contains `injection_point`.
        """
        # All enabled Prompt components known to the registry.
        enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)

        matching_components: list[Type[BasePrompt]] = []

        for prompt_name, prompt_info in enabled_prompts.items():
            # Ignore entries that are not PromptInfo records.
            if not isinstance(prompt_info, PromptInfo):
                continue

            # A component may declare one target or a list of targets.
            injection_points = prompt_info.injection_point
            if isinstance(injection_points, str):
                injection_points = [injection_points]

            if injection_point in injection_points:
                component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
                if component_class and issubclass(component_class, BasePrompt):
                    matching_components.append(component_class)

        return matching_components

    async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
        """Instantiate and run every component for an injection point.

        Args:
            injection_point: Name of the target prompt.
            params: PromptParameters object handed to each component.

        Returns:
            str: The stripped, non-empty outputs of the components joined with
            newlines; an empty string when nothing matched or produced output.
        """
        component_classes = self.get_components_for(injection_point)
        if not component_classes:
            return ""

        # (component name, coroutine) pairs. Keeping the name next to its
        # coroutine means a failure is always reported against the component
        # that actually ran, even when an earlier component could not be
        # instantiated and therefore has no entry here.
        scheduled: list[tuple[str, Any]] = []
        for component_class in component_classes:
            try:
                # Look up the component's info to find its plugin's config.
                prompt_info = component_registry.get_component_info(
                    component_class.prompt_name, ComponentType.PROMPT
                )
                if not isinstance(prompt_info, PromptInfo):
                    logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
                    plugin_config = {}
                else:
                    plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)

                instance = component_class(params=params, plugin_config=plugin_config)
                scheduled.append((component_class.prompt_name, instance.execute()))
            except Exception as e:
                logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")

        if not scheduled:
            return ""

        # Run all components concurrently; failures come back as results.
        results = await asyncio.gather(*(coro for _, coro in scheduled), return_exceptions=True)

        # Drop failed executions and empty output.
        valid_results = []
        for (component_name, _), result in zip(scheduled, results):
            # BaseException also covers a cancelled component, which would
            # otherwise be dropped without any log line.
            if isinstance(result, BaseException):
                logger.error(f"执行 Prompt 组件 '{component_name}' 失败: {result}")
            elif result and isinstance(result, str) and result.strip():
                valid_results.append(result.strip())

        # Join all valid results with newlines.
        return "\n".join(valid_results)
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
prompt_component_manager = PromptComponentManager()
|
||||
79
src/chat/utils/prompt_params.py
Normal file
79
src/chat/utils/prompt_params.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
This module contains the PromptParameters class, which is used to define the parameters for a prompt.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
@@ -298,14 +298,14 @@ def random_remove_punctuation(text: str) -> str:
|
||||
def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]:
|
||||
"""识别并保护数学公式和代码块,返回处理后的文本和映射"""
|
||||
placeholder_map = {}
|
||||
|
||||
|
||||
# 第一层防护:优先保护标准Markdown格式
|
||||
# 使用 re.S 来让 . 匹配换行符
|
||||
markdown_patterns = {
|
||||
'code': r"```.*?```",
|
||||
'math': r"\$\$.*?\$\$",
|
||||
"code": r"```.*?```",
|
||||
"math": r"\$\$.*?\$\$",
|
||||
}
|
||||
|
||||
|
||||
placeholder_idx = 0
|
||||
for block_type, pattern in markdown_patterns.items():
|
||||
matches = re.findall(pattern, text, re.S)
|
||||
@@ -318,7 +318,7 @@ def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]:
|
||||
# 第二层防护:保护非标准的、可能是公式或代码的片段
|
||||
# 这个正则表达式寻找连续5个以上的、主要由非中文字符组成的片段
|
||||
general_pattern = r"(?:[a-zA-Z0-9\s.,;:(){}\[\]_+\-*/=<>^|&%?!'\"√²³ⁿ∑∫≠≥≤]){5,}"
|
||||
|
||||
|
||||
# 为了避免与已保护的占位符冲突,我们在剩余的文本上进行查找
|
||||
# 这是一个简化的处理,更稳妥的方式是分段查找,但目前这样足以应对多数情况
|
||||
try:
|
||||
@@ -327,7 +327,7 @@ def protect_special_blocks(text: str) -> tuple[str, dict[str, str]]:
|
||||
# 避免将包含占位符的片段再次保护
|
||||
if "__SPECIAL_" in match:
|
||||
continue
|
||||
|
||||
|
||||
placeholder = f"__SPECIAL_GENERAL_{placeholder_idx}__"
|
||||
text = text.replace(match, placeholder, 1)
|
||||
placeholder_map[placeholder] = match
|
||||
@@ -352,23 +352,23 @@ def protect_quoted_content(text: str) -> tuple[str, dict[str, str]]:
|
||||
placeholder_map = {}
|
||||
# 匹配中英文单双引号,使用非贪婪模式
|
||||
quote_pattern = re.compile(r'(".*?")|(\'.*?\')|(“.*?”)|(‘.*?’)')
|
||||
|
||||
|
||||
matches = quote_pattern.finditer(text)
|
||||
|
||||
|
||||
# 为了避免替换时索引错乱,我们从后往前替换
|
||||
# finditer 找到的是 match 对象,我们需要转换为 list 来反转
|
||||
match_list = list(matches)
|
||||
|
||||
|
||||
for idx, match in enumerate(reversed(match_list)):
|
||||
original_quoted_text = match.group(0)
|
||||
placeholder = f"__QUOTE_{len(match_list) - 1 - idx}__"
|
||||
|
||||
|
||||
# 直接在原始文本上操作,替换 match 对象的 span
|
||||
start, end = match.span()
|
||||
text = text[:start] + placeholder + text[end:]
|
||||
|
||||
|
||||
placeholder_map[placeholder] = original_quoted_text
|
||||
|
||||
|
||||
return text, placeholder_map
|
||||
|
||||
|
||||
@@ -389,13 +389,13 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
# --- 三层防护系统 ---
|
||||
# 第一层:保护颜文字
|
||||
protected_text, kaomoji_mapping = protect_kaomoji(text) if global_config.response_splitter.enable_kaomoji_protection else (text, {})
|
||||
|
||||
|
||||
# 第二层:保护引号内容
|
||||
protected_text, quote_mapping = protect_quoted_content(protected_text)
|
||||
|
||||
# 第三层:保护数学公式和代码块
|
||||
protected_text, special_blocks_mapping = protect_special_blocks(protected_text)
|
||||
|
||||
|
||||
# 提取被 () 或 [] 或 ()包裹且包含中文的内容
|
||||
pattern = re.compile(r"[(\[(](?=.*[一-鿿]).*?[)\])]")
|
||||
_extracted_contents = pattern.findall(protected_text)
|
||||
@@ -412,7 +412,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
|
||||
# 对清理后的文本进行进一步处理
|
||||
max_sentence_num = global_config.response_splitter.max_sentence_num
|
||||
|
||||
|
||||
# --- 移除总长度检查 ---
|
||||
# 原有的总长度检查会导致长回复被直接丢弃,现已移除,由后续的智能合并逻辑处理。
|
||||
# max_length = global_config.response_splitter.max_length * 2
|
||||
@@ -472,7 +472,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
break
|
||||
|
||||
# 寻找最短的相邻句子对
|
||||
min_len = float('inf')
|
||||
min_len = float("inf")
|
||||
merge_idx = -1
|
||||
for i in range(len(sentences) - 1):
|
||||
combined_len = len(sentences[i]) + len(sentences[i+1])
|
||||
@@ -488,7 +488,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
sentences[merge_idx] = merged_sentence
|
||||
# 删除后一个句子
|
||||
del sentences[merge_idx + 1]
|
||||
|
||||
|
||||
logger.info(f"智能合并完成,最终消息数量: {len(sentences)}")
|
||||
|
||||
# if extracted_contents:
|
||||
|
||||
@@ -79,7 +79,7 @@ class Server:
|
||||
logger.warning(f"端口 {self.port} 已被占用,正在尝试下一个端口...")
|
||||
self.port += 1
|
||||
|
||||
logger.info(f"将在 http://{self.host}:{self.port} 上启动服务器")
|
||||
logger.info(f"将在 {self.host}:{self.port} 上启动服务器")
|
||||
# 禁用 uvicorn 默认日志和访问日志
|
||||
config = Config(app=self.app, host=self.host, port=self.port, log_config=None, access_log=False)
|
||||
self._server = UvicornServer(config=config)
|
||||
|
||||
@@ -7,7 +7,7 @@ from src.config.config_base import ValidatedConfigBase
|
||||
"""
|
||||
须知:
|
||||
1. 本文件中记录了所有的配置项
|
||||
2. 重要的配置类继承自ValidatedConfigBase进行Pydantic验证
|
||||
2. 所有配置类必须继承自ValidatedConfigBase进行Pydantic验证
|
||||
3. 所有新增的class都应在config.py中的Config类中添加字段
|
||||
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
|
||||
"""
|
||||
@@ -492,6 +492,7 @@ class LPMMKnowledgeConfig(ValidatedConfigBase):
|
||||
info_extraction_workers: int = Field(default=3, description="信息提取工作线程数")
|
||||
qa_relation_search_top_k: int = Field(default=10, description="QA关系搜索Top K")
|
||||
qa_relation_threshold: float = Field(default=0.75, description="QA关系阈值")
|
||||
qa_paragraph_threshold: float = Field(default=0.3, description="QA段落阈值")
|
||||
qa_paragraph_search_top_k: int = Field(default=1000, description="QA段落搜索Top K")
|
||||
qa_paragraph_node_weight: float = Field(default=0.05, description="QA段落节点权重")
|
||||
qa_ent_filter_top_k: int = Field(default=10, description="QA实体过滤Top K")
|
||||
|
||||
@@ -13,6 +13,7 @@ from rich.traceback import install
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.memory_system.memory_manager import memory_manager
|
||||
from src.chat.message_manager.sleep_system.tasks import start_sleep_system_tasks
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
@@ -29,7 +30,6 @@ from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.schedule.monthly_plan_manager import monthly_plan_manager
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.chat.message_manager.sleep_system.tasks import start_sleep_system_tasks
|
||||
|
||||
# 插件系统现在使用统一的插件加载器
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -26,6 +26,7 @@ from .base import (
|
||||
ActionInfo,
|
||||
BaseAction,
|
||||
BaseCommand,
|
||||
BasePrompt,
|
||||
BaseEventHandler,
|
||||
BasePlugin,
|
||||
BaseTool,
|
||||
@@ -64,6 +65,7 @@ __all__ = [
|
||||
"BaseEventHandler",
|
||||
# 基础类
|
||||
"BasePlugin",
|
||||
"BasePrompt",
|
||||
"BaseTool",
|
||||
"ChatMode",
|
||||
"ChatType",
|
||||
|
||||
@@ -8,6 +8,7 @@ from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_plugin import BasePlugin
|
||||
from .base_prompt import BasePrompt
|
||||
from .base_tool import BaseTool
|
||||
from .command_args import CommandArgs
|
||||
from .component_types import (
|
||||
@@ -37,6 +38,7 @@ __all__ = [
|
||||
"BaseCommand",
|
||||
"BaseEventHandler",
|
||||
"BasePlugin",
|
||||
"BasePrompt",
|
||||
"BaseTool",
|
||||
"ChatMode",
|
||||
"ChatType",
|
||||
|
||||
@@ -615,15 +615,15 @@ class BaseAction(ABC):
|
||||
"""
|
||||
# 尝试从不同的实例属性中获取聊天内容
|
||||
# 优先级:_activation_chat_content > action_data['chat_content'] > ""
|
||||
|
||||
|
||||
# 1. 如果有专门设置的激活用聊天内容(由 ActionModifier 设置)
|
||||
if hasattr(self, '_activation_chat_content'):
|
||||
return getattr(self, '_activation_chat_content', "")
|
||||
|
||||
if hasattr(self, "_activation_chat_content"):
|
||||
return getattr(self, "_activation_chat_content", "")
|
||||
|
||||
# 2. 尝试从 action_data 中获取
|
||||
if hasattr(self, 'action_data') and isinstance(self.action_data, dict):
|
||||
return self.action_data.get('chat_content', "")
|
||||
|
||||
if hasattr(self, "action_data") and isinstance(self.action_data, dict):
|
||||
return self.action_data.get("chat_content", "")
|
||||
|
||||
# 3. 默认返回空字符串
|
||||
return ""
|
||||
|
||||
@@ -729,7 +729,7 @@ class BaseAction(ABC):
|
||||
|
||||
# 自动获取聊天内容
|
||||
chat_content = self._get_chat_content()
|
||||
|
||||
|
||||
search_text = chat_content
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
@@ -786,7 +786,7 @@ class BaseAction(ABC):
|
||||
try:
|
||||
# 自动获取聊天内容
|
||||
chat_content = self._get_chat_content()
|
||||
|
||||
|
||||
# 如果没有提供 LLM 模型,创建一个默认的
|
||||
if llm_judge_model is None:
|
||||
from src.config.config import model_config
|
||||
|
||||
@@ -8,6 +8,7 @@ from src.plugin_system.base.component_types import (
|
||||
EventHandlerInfo,
|
||||
InterestCalculatorInfo,
|
||||
PlusCommandInfo,
|
||||
PromptInfo,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
@@ -15,6 +16,7 @@ from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_interest_calculator import BaseInterestCalculator
|
||||
from .base_prompt import BasePrompt
|
||||
from .base_tool import BaseTool
|
||||
from .plugin_base import PluginBase
|
||||
from .plus_command import PlusCommand
|
||||
@@ -80,6 +82,13 @@ class BasePlugin(PluginBase):
|
||||
logger.warning("EventHandler的get_info逻辑尚未实现")
|
||||
return None
|
||||
|
||||
elif component_type == ComponentType.PROMPT:
|
||||
if hasattr(component_class, "get_prompt_info"):
|
||||
return component_class.get_prompt_info()
|
||||
else:
|
||||
logger.warning(f"Prompt类 {component_class.__name__} 缺少 get_prompt_info 方法")
|
||||
return None
|
||||
|
||||
else:
|
||||
logger.error(f"不支持的组件类型: {component_type}")
|
||||
return None
|
||||
@@ -109,6 +118,7 @@ class BasePlugin(PluginBase):
|
||||
| tuple[EventHandlerInfo, type[BaseEventHandler]]
|
||||
| tuple[ToolInfo, type[BaseTool]]
|
||||
| tuple[InterestCalculatorInfo, type[BaseInterestCalculator]]
|
||||
| tuple[PromptInfo, type[BasePrompt]]
|
||||
]:
|
||||
"""获取插件包含的组件列表
|
||||
|
||||
|
||||
95
src/plugin_system/base/base_prompt.py
Normal file
95
src/plugin_system/base/base_prompt.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||||
|
||||
logger = get_logger("base_prompt")
|
||||
|
||||
|
||||
class BasePrompt(ABC):
    """Base class for Prompt components.

    A Prompt component is a plugin component that injects extra context into
    an existing core prompt template, so that model behaviour can be extended
    without modifying core code.

    Subclasses configure themselves through class attributes:
    - prompt_name: unique name of the Prompt component.
    - injection_point: name (or list of names) of the target prompt(s).
    """

    # Name of the Prompt component.
    prompt_name: str = ""
    # Description of the Prompt component.
    prompt_description: str = ""

    # Which core prompt(s) this component wants to be injected into: either a
    # single string or a list of strings, e.g. "planner_prompt" or
    # ["s4u_style_prompt", "normal_style_prompt"].
    injection_point: str | list[str] = ""

    def __init__(self, params: PromptParameters, plugin_config: dict | None = None):
        """Store the prompt parameters and the owning plugin's configuration.

        Args:
            params: Unified prompt parameters carrying the build context.
            plugin_config: Plugin configuration dict; a falsy value is
                replaced by an empty dict.
        """
        self.params = params
        self.plugin_config = plugin_config or {}
        self.log_prefix = "[PromptComponent]"

        logger.debug(f"{self.log_prefix} Prompt组件 '{self.prompt_name}' 初始化完成")

    @abstractmethod
    async def execute(self) -> str:
        """Produce the text to inject; subclasses must implement this.

        Implementations are expected to build their output from `self.params`.

        Returns:
            str: The generated text content.
        """
        pass

    def get_config(self, key: str, default: Any = None) -> Any:
        """Read a plugin configuration value, supporting nested keys.

        Args:
            key: Configuration key; dots descend into nested dicts, e.g.
                "section.subsection.key".
            default: Value returned when the key cannot be resolved.

        Returns:
            Any: The configured value, or `default`.
        """
        if not self.plugin_config:
            return default

        node = self.plugin_config
        for part in key.split("."):
            # Stop as soon as the path leaves dict territory or is missing.
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node

    @classmethod
    def get_prompt_info(cls) -> "PromptInfo":
        """Build the PromptInfo record for this class from its attributes.

        Returns:
            PromptInfo: Info object carrying name, description and injection
            point.

        Raises:
            ValueError: If the class does not define `prompt_name`.
        """
        if not cls.prompt_name:
            raise ValueError("Prompt组件必须定义 'prompt_name' 类属性。")

        info = PromptInfo(
            name=cls.prompt_name,
            component_type=ComponentType.PROMPT,
            description=cls.prompt_description,
            injection_point=cls.injection_point,
        )
        return info
|
||||
@@ -20,6 +20,7 @@ class ComponentType(Enum):
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件
|
||||
CHATTER = "chatter" # 聊天处理器组件
|
||||
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
|
||||
PROMPT = "prompt" # Prompt组件
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
@@ -143,7 +144,7 @@ class ActionInfo(ComponentInfo):
|
||||
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
|
||||
action_require: list[str] = field(default_factory=list) # 动作需求说明
|
||||
associated_types: list[str] = field(default_factory=list) # 关联的消息类型
|
||||
|
||||
|
||||
# ==================================================================================
|
||||
# 激活类型相关字段(已废弃,建议使用 go_activate() 方法)
|
||||
# 保留这些字段是为了向后兼容,BaseAction.go_activate() 的默认实现会使用这些字段
|
||||
@@ -155,7 +156,7 @@ class ActionInfo(ComponentInfo):
|
||||
llm_judge_prompt: str = "" # 已废弃,建议在 go_activate() 中使用 _llm_judge_activation()
|
||||
activation_keywords: list[str] = field(default_factory=list) # 已废弃,建议在 go_activate() 中使用 _keyword_match()
|
||||
keyword_case_sensitive: bool = False # 已废弃
|
||||
|
||||
|
||||
# 模式和并行设置
|
||||
mode_enable: ChatMode = ChatMode.ALL
|
||||
parallel_action: bool = False
|
||||
@@ -266,6 +267,18 @@ class EventInfo(ComponentInfo):
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
class PromptInfo(ComponentInfo):
    """Component info record for a Prompt component."""

    # Name of the target prompt to inject into, or a list of such names.
    injection_point: str | list[str] = ""

    def __post_init__(self):
        # Run the base-class initialisation first, then pin the component
        # type so a PromptInfo is always tagged as PROMPT.
        super().__post_init__()
        self.component_type = ComponentType.PROMPT
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginInfo:
|
||||
"""插件信息"""
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import (
|
||||
ActionInfo,
|
||||
@@ -22,6 +23,7 @@ from src.plugin_system.base.component_types import (
|
||||
InterestCalculatorInfo,
|
||||
PluginInfo,
|
||||
PlusCommandInfo,
|
||||
PromptInfo,
|
||||
ToolInfo,
|
||||
)
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
@@ -37,6 +39,7 @@ ComponentClassType = (
|
||||
| type[PlusCommand]
|
||||
| type[BaseChatter]
|
||||
| type[BaseInterestCalculator]
|
||||
| type[BasePrompt]
|
||||
)
|
||||
|
||||
|
||||
@@ -183,6 +186,10 @@ class ComponentRegistry:
|
||||
assert isinstance(component_info, InterestCalculatorInfo)
|
||||
assert issubclass(component_class, BaseInterestCalculator)
|
||||
ret = self._register_interest_calculator_component(component_info, component_class)
|
||||
case ComponentType.PROMPT:
|
||||
assert isinstance(component_info, PromptInfo)
|
||||
assert issubclass(component_class, BasePrompt)
|
||||
ret = self._register_prompt_component(component_info, component_class)
|
||||
case _:
|
||||
logger.warning(f"未知组件类型: {component_type}")
|
||||
ret = False
|
||||
@@ -346,6 +353,31 @@ class ComponentRegistry:
|
||||
logger.debug(f"已注册InterestCalculator组件: {calculator_name}")
|
||||
return True
|
||||
|
||||
def _register_prompt_component(
    self, prompt_info: PromptInfo, prompt_class: "ComponentClassType"
) -> bool:
    """Register a Prompt component in the Prompt-specific registries.

    Args:
        prompt_info: Info record of the component; its `name` is the key.
        prompt_class: The component class being registered.

    Returns:
        bool: False when the component has no name, True once registered.
    """
    prompt_name = prompt_info.name
    if not prompt_name:
        logger.error(f"Prompt组件 {prompt_class.__name__} 必须指定名称")
        return False

    # The two registries are created on first use.
    for registry_attr in ("_prompt_registry", "_enabled_prompt_registry"):
        if not hasattr(self, registry_attr):
            setattr(self, registry_attr, {})

    owner_config = self.get_plugin_config(prompt_info.plugin_name) or {}
    _assign_plugin_attrs(prompt_class, prompt_info.plugin_name, owner_config)

    self._prompt_registry[prompt_name] = prompt_class  # type: ignore
    # Enabled components are additionally tracked in their own registry.
    if prompt_info.enabled:
        self._enabled_prompt_registry[prompt_name] = prompt_class  # type: ignore

    logger.debug(f"已注册Prompt组件: {prompt_name}")
    return True
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||
@@ -580,7 +612,17 @@ class ComponentRegistry:
|
||||
component_name: str,
|
||||
component_type: ComponentType | None = None,
|
||||
) -> (
|
||||
type[BaseCommand | BaseAction | BaseEventHandler | BaseTool | PlusCommand | BaseChatter | BaseInterestCalculator] | None
|
||||
type[
|
||||
BaseCommand
|
||||
| BaseAction
|
||||
| BaseEventHandler
|
||||
| BaseTool
|
||||
| PlusCommand
|
||||
| BaseChatter
|
||||
| BaseInterestCalculator
|
||||
| BasePrompt
|
||||
]
|
||||
| None
|
||||
):
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
@@ -829,6 +871,7 @@ class ComponentRegistry:
|
||||
events_handlers: int = 0
|
||||
plus_command_components: int = 0
|
||||
chatter_components: int = 0
|
||||
prompt_components: int = 0
|
||||
for component in self._components.values():
|
||||
if component.component_type == ComponentType.ACTION:
|
||||
action_components += 1
|
||||
@@ -842,6 +885,8 @@ class ComponentRegistry:
|
||||
plus_command_components += 1
|
||||
elif component.component_type == ComponentType.CHATTER:
|
||||
chatter_components += 1
|
||||
elif component.component_type == ComponentType.PROMPT:
|
||||
prompt_components += 1
|
||||
return {
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
@@ -849,6 +894,7 @@ class ComponentRegistry:
|
||||
"event_handlers": events_handlers,
|
||||
"plus_command_components": plus_command_components,
|
||||
"chatter_components": chatter_components,
|
||||
"prompt_components": prompt_components,
|
||||
"total_components": len(self._components),
|
||||
"total_plugins": len(self._plugins),
|
||||
"components_by_type": {
|
||||
|
||||
@@ -358,13 +358,14 @@ class PluginManager:
|
||||
event_handler_count = stats.get("event_handlers", 0)
|
||||
plus_command_count = stats.get("plus_command_components", 0)
|
||||
chatter_count = stats.get("chatter_components", 0)
|
||||
prompt_count = stats.get("prompt_components", 0)
|
||||
total_components = stats.get("total_components", 0)
|
||||
|
||||
# 📋 显示插件加载总览
|
||||
if total_registered > 0:
|
||||
logger.info("🎉 插件系统加载完成!")
|
||||
logger.info(
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})"
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count}, Prompt: {prompt_count})"
|
||||
)
|
||||
|
||||
# 显示详细的插件列表
|
||||
@@ -382,6 +383,13 @@ class PluginManager:
|
||||
|
||||
# 组件列表
|
||||
if plugin_info.components:
|
||||
|
||||
def format_component(c):
    """Render a component as ``name (description)`` for the load summary.

    Descriptions longer than 15 characters are cut to 15 and suffixed with
    "..."; a component with an empty description is rendered as its name only.
    """
    summary = c.description
    if len(summary) > 15:
        summary = summary[:15] + "..."
    if summary:
        return f"{c.name} ({summary})"
    return c.name
|
||||
|
||||
action_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
|
||||
]
|
||||
@@ -395,29 +403,35 @@ class PluginManager:
|
||||
plus_command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.PLUS_COMMAND
|
||||
]
|
||||
prompt_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.PROMPT
|
||||
]
|
||||
|
||||
if action_components:
|
||||
action_names = [c.name for c in action_components]
|
||||
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
|
||||
action_details = [format_component(c) for c in action_components]
|
||||
logger.info(f" 🎯 Action组件: {', '.join(action_details)}")
|
||||
|
||||
if command_components:
|
||||
command_names = [c.name for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||
command_details = [format_component(c) for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_details)}")
|
||||
if tool_components:
|
||||
tool_names = [c.name for c in tool_components]
|
||||
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
|
||||
tool_details = [format_component(c) for c in tool_components]
|
||||
logger.info(f" 🛠️ Tool组件: {', '.join(tool_details)}")
|
||||
if plus_command_components:
|
||||
plus_command_names = [c.name for c in plus_command_components]
|
||||
logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_names)}")
|
||||
plus_command_details = [format_component(c) for c in plus_command_components]
|
||||
logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_details)}")
|
||||
chatter_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.CHATTER
|
||||
]
|
||||
if chatter_components:
|
||||
chatter_names = [c.name for c in chatter_components]
|
||||
logger.info(f" 🗣️ Chatter组件: {', '.join(chatter_names)}")
|
||||
chatter_details = [format_component(c) for c in chatter_components]
|
||||
logger.info(f" 🗣️ Chatter组件: {', '.join(chatter_details)}")
|
||||
if event_handler_components:
|
||||
event_handler_names = [c.name for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||
event_handler_details = [format_component(c) for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_details)}")
|
||||
if prompt_components:
|
||||
prompt_details = [format_component(c) for c in prompt_components]
|
||||
logger.info(f" 📝 Prompt组件: {', '.join(prompt_details)}")
|
||||
|
||||
# 权限节点信息
|
||||
if plugin_instance := self.loaded_plugins.get(plugin_name):
|
||||
|
||||
@@ -155,88 +155,22 @@ class ChatterPlanFilter:
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
|
||||
schedule_block = ""
|
||||
# 优先检查是否被吵醒
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
|
||||
angry_prompt_addition = ""
|
||||
try:
|
||||
from src.plugins.built_in.sleep_system.api import get_wakeup_manager
|
||||
wakeup_mgr = get_wakeup_manager()
|
||||
except ImportError:
|
||||
logger.debug("无法导入睡眠系统API,将跳过相关检查。")
|
||||
wakeup_mgr = None
|
||||
|
||||
if wakeup_mgr:
|
||||
|
||||
# 双重检查确保愤怒状态不会丢失
|
||||
# 检查1: 直接从 wakeup_manager 获取
|
||||
if wakeup_mgr.is_in_angry_state():
|
||||
angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition()
|
||||
|
||||
# 检查2: 如果上面没获取到,再从 mood_manager 确认
|
||||
if not angry_prompt_addition:
|
||||
chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
if chat_mood_for_check.is_angry_from_wakeup:
|
||||
angry_prompt_addition = global_config.sleep_system.angry_prompt
|
||||
|
||||
if angry_prompt_addition:
|
||||
schedule_block = angry_prompt_addition
|
||||
elif global_config.planning_system.schedule_enable:
|
||||
if global_config.planning_system.schedule_enable:
|
||||
if activity_info := schedule_manager.get_current_activity():
|
||||
activity = activity_info.get("activity", "未知活动")
|
||||
schedule_block = f"你当前正在:{activity},但注意它与群聊的聊天无关。"
|
||||
|
||||
mood_block = ""
|
||||
# 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块
|
||||
if not angry_prompt_addition and global_config.mood.enable_mood:
|
||||
# 需要情绪模块打开才能获得情绪,否则会引发报错
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
||||
|
||||
if plan.mode == ChatMode.PROACTIVE:
|
||||
long_term_memory_block = await self._get_long_term_memory_context()
|
||||
|
||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
prompt = prompt_template.format(
|
||||
time_block=time_block,
|
||||
identity_block=identity_block,
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
long_term_memory_block=long_term_memory_block,
|
||||
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
)
|
||||
return prompt, message_id_list
|
||||
|
||||
# 构建已读/未读历史消息
|
||||
read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks(
|
||||
plan
|
||||
)
|
||||
|
||||
# 为了兼容性,保留原有的chat_content_block
|
||||
chat_content_block, _ = await build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
@@ -286,7 +220,7 @@ class ChatterPlanFilter:
|
||||
is_group_chat = plan.chat_type == ChatType.GROUP
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
if not is_group_chat and plan.target_info:
|
||||
chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方"
|
||||
chat_target_name = plan.target_info.person_name or plan.target_info.user_nickname or "对方"
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
action_options_block = await self._build_action_options(plan.available_actions)
|
||||
|
||||
@@ -9,7 +9,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
|
||||
@@ -55,6 +55,11 @@ class ChatterPlanGenerator:
|
||||
try:
|
||||
# 获取聊天类型和目标信息
|
||||
chat_type, target_info = await get_chat_type_and_target_info(self.chat_id)
|
||||
if chat_type:
|
||||
chat_type = ChatType.GROUP
|
||||
else:
|
||||
#遇到未知类型也当私聊处理
|
||||
chat_type = ChatType.PRIVATE
|
||||
|
||||
# 获取可用动作列表
|
||||
available_actions = await self._get_available_actions(chat_type, mode)
|
||||
@@ -62,12 +67,16 @@ class ChatterPlanGenerator:
|
||||
# 获取聊天历史记录
|
||||
recent_messages = await self._get_recent_messages()
|
||||
|
||||
# 构建计划对象
|
||||
# 使用 target_info 字典创建 TargetPersonInfo 实例
|
||||
target_person_info = TargetPersonInfo(**target_info) if target_info else TargetPersonInfo()
|
||||
|
||||
# 构建计划对象
|
||||
plan = Plan(
|
||||
chat_id=self.chat_id,
|
||||
chat_type=chat_type,
|
||||
mode=mode,
|
||||
target_info=target_info,
|
||||
target_info=target_person_info,
|
||||
available_actions=available_actions,
|
||||
chat_history=recent_messages,
|
||||
)
|
||||
@@ -77,6 +86,7 @@ class ChatterPlanGenerator:
|
||||
except Exception:
|
||||
# 如果生成失败,返回一个基本的空计划
|
||||
return Plan(
|
||||
chat_type = ChatType.PRIVATE,#空计划默认当成私聊
|
||||
chat_id=self.chat_id,
|
||||
mode=mode,
|
||||
target_info=TargetPersonInfo(),
|
||||
@@ -124,7 +134,7 @@ class ChatterPlanGenerator:
|
||||
try:
|
||||
# 获取最近的消息记录
|
||||
raw_messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length
|
||||
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size
|
||||
)
|
||||
|
||||
# 转换为 DatabaseMessages 对象
|
||||
|
||||
@@ -70,6 +70,7 @@ class ChatterActionPlanner:
|
||||
"replies_generated": 0,
|
||||
"other_actions_executed": 0,
|
||||
}
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def plan(self, context: "StreamContext | None" = None) -> tuple[list[dict[str, Any]], Any | None]:
|
||||
"""
|
||||
@@ -157,7 +158,9 @@ class ChatterActionPlanner:
|
||||
)
|
||||
|
||||
if interest_updates:
|
||||
asyncio.create_task(self._commit_interest_updates(interest_updates))
|
||||
task = asyncio.create_task(self._commit_interest_updates(interest_updates))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._handle_task_result)
|
||||
|
||||
# 检查兴趣度是否达到非回复动作阈值
|
||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
@@ -266,6 +269,17 @@ class ChatterActionPlanner:
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
|
||||
def _handle_task_result(self, task: asyncio.Task) -> None:
    """Done-callback for background tasks: report failures and drop the task.

    Calling ``task.result()`` re-raises whatever the task raised. A
    cancellation is treated as normal and ignored; any other exception is
    logged. In every case the task is removed from
    ``self._background_tasks``.

    Args:
        task: The finished task passed in by ``add_done_callback``.
    """
    try:
        task.result()
    except asyncio.CancelledError:
        # Cancellation is an expected outcome and is not reported.
        return
    except Exception as exc:
        logger.error(f"后台任务执行失败: {exc}", exc_info=True)
    finally:
        # Runs on every path, including the early return above.
        self._background_tasks.discard(task)
|
||||
|
||||
def get_planner_stats(self) -> dict[str, Any]:
    """Return a copy of the planner statistics.

    The copy is made with ``.copy()``, so callers can modify the returned
    mapping without touching ``self.planner_stats``.
    """
    stats_snapshot = self.planner_stats.copy()
    return stats_snapshot
|
||||
|
||||
@@ -15,7 +15,7 @@ logger = get_logger(__name__)
|
||||
|
||||
@register_plugin
|
||||
class ProactiveThinkerPlugin(BasePlugin):
|
||||
"""一个主动思考的插件,但现在还只是个空壳子"""
|
||||
"""一个主动思考的插件"""
|
||||
|
||||
plugin_name: str = "proactive_thinker"
|
||||
enable_plugin: bool = True
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.chat.message_manager.sleep_system.state_manager import SleepState, sleep_state_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -13,7 +14,6 @@ from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system import BaseEventHandler, EventType
|
||||
from src.plugin_system.apis import chat_api, message_api, person_api
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
from src.chat.message_manager.sleep_system.state_manager import SleepState, sleep_state_manager
|
||||
|
||||
from .proactive_thinker_executor import ProactiveThinkerExecutor
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ Base search engine interface
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class BaseSearchEngine(ABC):
|
||||
@@ -24,6 +24,12 @@ class BaseSearchEngine(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def read_url(self, url: str) -> Optional[str]:
    """Read the content behind ``url``.

    This default implementation always returns ``None``, which signals
    that the engine does not support reading URLs.

    Args:
        url: Address of the page to read (unused here).

    Returns:
        ``None``.
    """
    return None
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
|
||||
107
src/plugins/built_in/web_search_tool/engines/metaso_engine.py
Normal file
107
src/plugins/built_in/web_search_tool/engines/metaso_engine.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Metaso Search Engine (Chat Completions Mode)
|
||||
"""
|
||||
import json
|
||||
from typing import Any, List
|
||||
|
||||
import httpx
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MetasoClient:
    """Async client for the Metaso chat-completions endpoint.

    The query is sent as a single user message with ``stream=True``. The
    ``data:`` lines of the streamed response are decoded as JSON and their
    ``choices[0].delta.content`` pieces are concatenated into one answer.
    """

    def __init__(self, api_key: str):
        """Store the API key and build the request headers.

        Args:
            api_key: Token sent as ``Bearer`` in the ``Authorization`` header.
        """
        self.api_key = api_key
        self.base_url = "https://metaso.cn/api/v1"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _extract_content(event: Any) -> str:
        """Return the text delta carried by one decoded stream event.

        Events that do not have the ``choices[0].delta.content`` shape
        (for example an empty ``choices`` list) yield ``""`` instead of
        raising, so one odd event cannot abort the whole stream.
        """
        if not isinstance(event, dict):
            return ""
        choices = event.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first_choice = choices[0]
        if not isinstance(first_choice, dict):
            return ""
        delta = first_choice.get("delta")
        if not isinstance(delta, dict):
            return ""
        content = delta.get("content")
        return content if isinstance(content, str) else ""

    async def search(self, query: str, **kwargs) -> List[dict[str, Any]]:
        """Perform a search using the Metaso Chat Completions API.

        Args:
            query: The user's search query.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            A one-element list holding a result dict (``title``, ``url``,
            ``snippet``, ``provider``) whose snippet is the full streamed
            answer, or an empty list if the stream was empty or any error
            occurred (errors are logged, never raised).
        """
        payload = {"model": "fast", "stream": True, "messages": [{"role": "user", "content": query}]}
        search_url = f"{self.base_url}/chat/completions"
        content_parts: list[str] = []

        async with httpx.AsyncClient(timeout=90.0) as client:
            try:
                async with client.stream("POST", search_url, headers=self.headers, json=payload) as response:
                    try:
                        response.raise_for_status()
                    except httpx.HTTPStatusError:
                        # A streamed body is not loaded automatically. Read it
                        # while the stream is still open so the handler below
                        # can log ``e.response.text`` without ResponseNotRead.
                        await response.aread()
                        raise

                    async for line in response.aiter_lines():
                        if not line.startswith("data:"):
                            continue
                        data_str = line.removeprefix("data:").strip()
                        if data_str == "[DONE]":
                            break
                        try:
                            event = json.loads(data_str)
                        except json.JSONDecodeError:
                            logger.warning(f"Metaso stream: could not decode JSON line: {data_str}")
                            continue
                        content_parts.append(self._extract_content(event))
            except httpx.HTTPStatusError as e:
                logger.error(f"HTTP error occurred while searching with Metaso Chat: {e.response.text}")
                return []
            except Exception as e:
                logger.error(f"An error occurred while searching with Metaso Chat: {e}", exc_info=True)
                return []

        full_response_content = "".join(content_parts)
        if not full_response_content:
            logger.warning("Metaso search returned an empty stream.")
            return []

        return [
            {
                "title": query,
                "url": "https://metaso.cn/",
                "snippet": full_response_content,
                "provider": "Metaso (Chat)",
            }
        ]
|
||||
|
||||
|
||||
class MetasoSearchEngine(BaseSearchEngine):
    """Search engine implementation backed by ``MetasoClient``."""

    def __init__(self):
        """Set up the API-key manager from global configuration."""
        self._initialize_clients()

    def _initialize_clients(self):
        """Create ``self.api_manager`` for the configured Metaso keys.

        Reads ``web_search.metaso_api_keys`` from the global config and
        passes it, together with a ``MetasoClient`` factory, to
        ``create_api_key_manager_from_config``.
        """
        metaso_api_keys = config_api.get_global_config("web_search.metaso_api_keys", None)
        self.api_manager = create_api_key_manager_from_config(
            metaso_api_keys, lambda key: MetasoClient(api_key=key), "Metaso"
        )

    def is_available(self) -> bool:
        """Return whether the key manager reports the engine as available."""
        return self.api_manager.is_available()

    async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
        """Execute a Metaso search.

        Args:
            args: Tool arguments; the search text is read from ``"query"``.

        Returns:
            The result list from ``MetasoClient.search``, or an empty list
            when the engine is unavailable, the query is missing or empty,
            no client can be obtained, or the search raises.
        """
        if not self.is_available():
            return []

        # A missing or empty query used to raise KeyError (or hit the API
        # with nothing to search); return the usual empty result instead.
        query = args.get("query")
        if not query:
            logger.warning("Metaso search called without a query.")
            return []

        try:
            metaso_client = self.api_manager.get_next_client()
            if not metaso_client:
                logger.error("Could not get Metaso client.")
                return []

            return await metaso_client.search(query)
        except Exception as e:
            logger.error(f"Metaso search failed: {e}", exc_info=True)
            return []
|
||||
@@ -22,6 +22,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
提供网络搜索和URL解析功能,支持多种搜索引擎:
|
||||
- Exa (需要API密钥)
|
||||
- Tavily (需要API密钥)
|
||||
- Metaso (需要API密钥)
|
||||
- DuckDuckGo (免费)
|
||||
- Bing (免费)
|
||||
"""
|
||||
@@ -43,6 +44,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
from .engines.exa_engine import ExaSearchEngine
|
||||
from .engines.searxng_engine import SearXNGSearchEngine
|
||||
from .engines.tavily_engine import TavilySearchEngine
|
||||
from .engines.metaso_engine import MetasoSearchEngine
|
||||
|
||||
# 实例化所有搜索引擎,这会触发API密钥管理器的初始化
|
||||
exa_engine = ExaSearchEngine()
|
||||
@@ -50,14 +52,16 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
ddg_engine = DDGSearchEngine()
|
||||
bing_engine = BingSearchEngine()
|
||||
searxng_engine = SearXNGSearchEngine()
|
||||
|
||||
# 报告每个引擎的状态
|
||||
metaso_engine = MetasoSearchEngine()
|
||||
|
||||
# 报告每个引擎的状态
|
||||
engines_status = {
|
||||
"Exa": exa_engine.is_available(),
|
||||
"Tavily": tavily_engine.is_available(),
|
||||
"DuckDuckGo": ddg_engine.is_available(),
|
||||
"Bing": bing_engine.is_available(),
|
||||
"SearXNG": searxng_engine.is_available(),
|
||||
"Metaso": metaso_engine.is_available(),
|
||||
}
|
||||
|
||||
available_engines = [name for name, available in engines_status.items() if available]
|
||||
|
||||
@@ -15,6 +15,7 @@ from ..engines.ddg_engine import DDGSearchEngine
|
||||
from ..engines.exa_engine import ExaSearchEngine
|
||||
from ..engines.searxng_engine import SearXNGSearchEngine
|
||||
from ..engines.tavily_engine import TavilySearchEngine
|
||||
from ..engines.metaso_engine import MetasoSearchEngine
|
||||
from ..utils.formatters import deduplicate_results, format_search_results
|
||||
|
||||
logger = get_logger("web_search_tool")
|
||||
@@ -51,6 +52,7 @@ class WebSurfingTool(BaseTool):
|
||||
"ddg": DDGSearchEngine(),
|
||||
"bing": BingSearchEngine(),
|
||||
"searxng": SearXNGSearchEngine(),
|
||||
"metaso": MetasoSearchEngine(),
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user