Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
2
bot.py
2
bot.py
@@ -21,6 +21,7 @@ initialize_logging()
|
||||
from src.main import MainSystem # noqa
|
||||
from src import BaseMain # noqa
|
||||
from src.manager.async_task_manager import async_task_manager # noqa
|
||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge # noqa
|
||||
from src.config.config import global_config # noqa
|
||||
from src.common.database.database import initialize_sql_database # noqa
|
||||
from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa
|
||||
@@ -228,6 +229,7 @@ if __name__ == "__main__":
|
||||
try:
|
||||
# 执行初始化和任务调度
|
||||
loop.run_until_complete(main_system.initialize())
|
||||
initialize_lpmm_knowledge()
|
||||
# Schedule tasks returns a future that runs forever.
|
||||
# We can run console_input_loop concurrently.
|
||||
main_tasks = loop.create_task(main_system.schedule_tasks())
|
||||
|
||||
@@ -1,48 +1,56 @@
|
||||
|
||||
# 事件系统使用指南
|
||||
|
||||
## 概述
|
||||
|
||||
本项目的事件系统是一个基于插件架构的异步事件处理框架,允许插件通过事件驱动的方式进行通信和协作。事件系统采用发布-订阅模式,支持动态事件注册、处理器管理、权重排序和链式处理。
|
||||
本项目的事件系统是一个基于插件架构的异步事件处理框架,允许插件通过事件驱动的方式进行通信和协作。事件系统采用发布-订阅模式,支持动态事件注册、处理器管理、权重排序、链式处理和细粒度鉴权机制。
|
||||
|
||||
## 核心概念
|
||||
|
||||
### 事件 (Event)
|
||||
事件是系统中发生的特定动作或状态变化,可以被多个处理器监听和响应。
|
||||
事件是系统中发生的特定动作或状态变化,可以被多个处理器监听和响应。每个事件可以配置订阅者和触发者的白名单权限。
|
||||
|
||||
### 事件处理器 (Event Handler)
|
||||
事件处理器是响应特定事件的代码单元,可以订阅一个或多个事件。
|
||||
事件处理器是响应特定事件的代码单元,可以订阅一个或多个事件。处理器支持权重排序和链式处理控制。
|
||||
|
||||
### 事件管理器 (Event Manager)
|
||||
事件管理器是事件系统的核心,负责事件的注册、处理器的管理以及事件的触发。
|
||||
事件管理器是事件系统的核心,负责事件的注册、处理器的管理、权限验证以及事件的触发。
|
||||
|
||||
### 鉴权机制 (Authentication Mechanism)
|
||||
系统提供双重鉴权机制:
|
||||
- **订阅者白名单** (`allowed_subscribers`): 控制哪些处理器可以订阅事件
|
||||
- **触发者白名单** (`allowed_triggers`): 控制哪些插件可以触发事件
|
||||
|
||||
## 系统架构
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[插件系统] --> B[事件管理器]
|
||||
B --> C[事件注册]
|
||||
B --> D[处理器注册]
|
||||
B --> E[事件触发]
|
||||
B --> C[事件注册与鉴权]
|
||||
B --> D[处理器注册与鉴权]
|
||||
B --> E[事件触发与权限验证]
|
||||
C --> F[BaseEvent实例]
|
||||
D --> G[BaseEventHandler实例]
|
||||
E --> H[处理器执行]
|
||||
H --> I[结果收集]
|
||||
H --> I[结果收集与汇总]
|
||||
C --> J[订阅者白名单验证]
|
||||
E --> K[触发者白名单验证]
|
||||
```
|
||||
|
||||
## 内置事件类型
|
||||
|
||||
系统预定义了以下事件类型:
|
||||
|
||||
| 事件名称 | 描述 | 触发时机 |
|
||||
|---------|------|----------|
|
||||
| `on_start` | 启动事件 | 系统启动时 |
|
||||
| `on_stop` | 停止事件 | 系统停止时 |
|
||||
| `on_message` | 消息事件 | 收到新消息时 |
|
||||
| `on_plan` | 计划事件 | 执行计划任务时 |
|
||||
| `post_llm` | LLM后处理事件 | LLM处理完成后 |
|
||||
| `after_llm` | LLM后事件 | LLM响应后 |
|
||||
| `post_send` | 发送后处理事件 | 消息发送后 |
|
||||
| `after_send` | 发送后事件 | 消息完全发送后 |
|
||||
| 事件名称 | 描述 | 触发时机 | 默认权限 |
|
||||
|---------|------|----------|----------|
|
||||
| `on_start` | 启动事件 | 系统启动时 | SYSTEM |
|
||||
| `on_stop` | 停止事件 | 系统停止时 | SYSTEM |
|
||||
| `on_message` | 消息事件 | 收到新消息时 | SYSTEM |
|
||||
| `on_plan` | 计划事件 | 执行计划任务时 | SYSTEM |
|
||||
| `post_llm` | 准备LLM事件 | 准备LLM时 | SYSTEM |
|
||||
| `after_llm` | LLM后事件 | LLM响应后 | SYSTEM |
|
||||
| `post_send` | 准备发送消息事件 | 准备发送消息时 | SYSTEM |
|
||||
| `after_send` | 发送后事件 | 消息完全发送后 | SYSTEM |
|
||||
|
||||
## 快速开始
|
||||
|
||||
@@ -72,7 +80,7 @@ class MyEventHandler(BaseEventHandler):
|
||||
|
||||
return HandlerResult(
|
||||
success=True,
|
||||
continue_process=True, # 是否继续让其他处理器处理
|
||||
continue_process=True, # 是否阻断后续流程
|
||||
message="处理成功",
|
||||
handler_name=self.handler_name
|
||||
)
|
||||
@@ -104,47 +112,178 @@ class MyPlugin(BasePlugin):
|
||||
|
||||
### 3. 触发事件
|
||||
|
||||
使用事件管理器触发事件:
|
||||
使用事件管理器触发事件,支持权限验证:
|
||||
|
||||
```python
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
# 触发内置事件
|
||||
await event_manager.trigger_event(EventType.ON_MESSAGE, message="Hello World")
|
||||
# 触发内置事件(需要SYSTEM权限)
|
||||
await event_manager.trigger_event(EventType.ON_MESSAGE, permission_group="SYSTEM", message="Hello World")
|
||||
|
||||
# 触发自定义事件
|
||||
await event_manager.trigger_event("custom_event", data={"key": "value"})
|
||||
# 触发自定义事件(需要相应权限)
|
||||
await event_manager.trigger_event("custom_event", permission_group="my_plugin", data={"key": "value"})
|
||||
```
|
||||
|
||||
## 鉴权机制详解
|
||||
|
||||
### 事件注册时的权限控制
|
||||
|
||||
注册事件时可以指定订阅者和触发者的白名单:
|
||||
|
||||
```python
|
||||
# 注册事件,限制只有特定处理器可以订阅,特定插件可以触发
|
||||
event_manager.register_event(
|
||||
"sensitive_event",
|
||||
allowed_subscribers=["audit_handler", "log_handler"], # 订阅者白名单
|
||||
allowed_triggers=["security_plugin", "admin_plugin"] # 触发者白名单
|
||||
)
|
||||
```
|
||||
|
||||
### 权限验证流程
|
||||
|
||||
1. **订阅权限验证**:处理器订阅事件时检查 `allowed_subscribers`
|
||||
2. **触发权限验证**:插件触发事件时检查 `allowed_triggers`
|
||||
3. **默认权限**:内置事件默认只允许 `SYSTEM` 权限组触发
|
||||
|
||||
### 权限组说明
|
||||
|
||||
- `SYSTEM`: 系统核心组件权限
|
||||
- `插件名称`: 各个插件的权限标识
|
||||
- 空字符串: 无权限组(无法触发有白名单的事件)
|
||||
|
||||
## 使用模式
|
||||
|
||||
### 接口式模式(内部编写handler,外部触发)
|
||||
|
||||
**适用场景**:插件提供事件处理能力,供其他组件调用
|
||||
|
||||
```python
|
||||
# 服务提供者插件
|
||||
class DataProcessorHandler(BaseEventHandler):
|
||||
handler_name = "data_processor"
|
||||
handler_description = "数据处理服务"
|
||||
weight = 10 # 权重,越大越先执行
|
||||
intercept_message = False # 是否拦截消息
|
||||
init_subscribe = ["process_data_request"]
|
||||
|
||||
async def execute(self, params: dict) -> HandlerResult:
|
||||
data = params.get("data")
|
||||
processed = self.process_data(data)
|
||||
return HandlerResult(True, True, processed)
|
||||
|
||||
# 服务消费者插件
|
||||
async def use_data_service():
|
||||
result = await event_manager.trigger_event(
|
||||
"process_data_request",
|
||||
permission_group="consumer_plugin",
|
||||
data={"input": "test"}
|
||||
)
|
||||
if result:
|
||||
processed_data = result.get_message_result()
|
||||
```
|
||||
|
||||
### 通知式模式(外部编写handler,内部触发)
|
||||
|
||||
**适用场景**:插件内部发生事件,通知外部处理器
|
||||
|
||||
```python
|
||||
# 事件生产者插件
|
||||
class EventProducerPlugin(BasePlugin):
|
||||
def __init__(self):
|
||||
# 注册自定义事件,允许其他处理器订阅
|
||||
event_manager.register_event("custom_alert")
|
||||
|
||||
async def detect_anomaly(self):
|
||||
if anomaly_detected:
|
||||
# 触发事件通知订阅者
|
||||
await event_manager.trigger_event(
|
||||
"custom_alert",
|
||||
permission_group=self.plugin_name,
|
||||
anomaly_type="security",
|
||||
severity="high"
|
||||
)
|
||||
|
||||
# 事件消费者插件
|
||||
class AlertHandler(BaseEventHandler):
|
||||
handler_name = "alert_handler"
|
||||
init_subscribe = ["custom_alert"]
|
||||
|
||||
async def execute(self, params: dict) -> HandlerResult:
|
||||
anomaly_type = params.get("anomaly_type")
|
||||
severity = params.get("severity")
|
||||
self.handle_alert(anomaly_type, severity)
|
||||
return HandlerResult(True, True, "Alert handled")
|
||||
```
|
||||
|
||||
## 高级用法
|
||||
|
||||
### 动态事件管理
|
||||
### 动态订阅管理
|
||||
|
||||
#### 注册自定义事件
|
||||
```python
|
||||
# 注册新事件
|
||||
event_manager.register_event("my_custom_event")
|
||||
|
||||
# 检查事件是否存在
|
||||
event = event_manager.get_event("my_custom_event")
|
||||
```
|
||||
|
||||
#### 动态订阅管理
|
||||
```python
|
||||
# 动态订阅处理器到事件
|
||||
event_manager.subscribe_handler_to_event("handler_name", "event_name")
|
||||
success = event_manager.subscribe_handler_to_event("handler_name", "event_name")
|
||||
|
||||
# 取消订阅
|
||||
event_manager.unsubscribe_handler_from_event("handler_name", "event_name")
|
||||
success = event_manager.unsubscribe_handler_from_event("handler_name", "event_name")
|
||||
|
||||
# 处理器自管理订阅
|
||||
class DynamicHandler(BaseEventHandler):
|
||||
async def setup_subscriptions(self):
|
||||
self.subscribe("event1")
|
||||
self.subscribe("event2")
|
||||
|
||||
async def cleanup(self):
|
||||
self.unsubscribe("event1")
|
||||
```
|
||||
|
||||
#### 启用/禁用事件
|
||||
```python
|
||||
# 禁用事件
|
||||
event_manager.disable_event("event_name")
|
||||
### 参数传递机制
|
||||
|
||||
# 启用事件
|
||||
event_manager.enable_event("event_name")
|
||||
事件支持灵活的参数传递:
|
||||
|
||||
```python
|
||||
# 触发事件时传递复杂参数
|
||||
await event_manager.trigger_event(
|
||||
"complex_event",
|
||||
permission_group="my_plugin",
|
||||
user_info={"id": 123, "name": "test"},
|
||||
metadata={"timestamp": "2024-01-01", "source": "api"},
|
||||
nested_data={"level1": {"level2": "value"}}
|
||||
)
|
||||
|
||||
# 处理器接收参数
|
||||
async def execute(self, params: dict) -> HandlerResult:
|
||||
user_info = params.get("user_info", {})
|
||||
metadata = params.get("metadata", {})
|
||||
# 处理参数...
|
||||
```
|
||||
|
||||
### 结果汇总与处理
|
||||
|
||||
事件触发后返回 `HandlerResultsCollection`,提供丰富的查询方法:
|
||||
|
||||
```python
|
||||
results = await event_manager.trigger_event("my_event", permission_group="my_plugin", data=data)
|
||||
|
||||
# 获取处理摘要
|
||||
summary = results.get_summary()
|
||||
print(f"总处理器数: {summary['total_handlers']}")
|
||||
print(f"成功数: {summary['success_count']}")
|
||||
print(f"失败数: {summary['failure_count']}")
|
||||
print(f"失败处理器: {summary['failed_handlers']}")
|
||||
|
||||
# 获取特定处理器结果
|
||||
specific_result = results.get_handler_result("my_handler")
|
||||
if specific_result and specific_result.success:
|
||||
print(f"处理器结果: {specific_result.message}")
|
||||
|
||||
# 检查处理链状态
|
||||
if results.all_continue_process():
|
||||
print("所有处理器都允许继续处理")
|
||||
else:
|
||||
print("有处理器中断了处理链")
|
||||
|
||||
# 获取所有消息结果
|
||||
all_messages = results.get_message_result()
|
||||
```
|
||||
|
||||
### 事件处理器权重
|
||||
@@ -152,11 +291,14 @@ event_manager.enable_event("event_name")
|
||||
事件处理器支持权重机制,权重越高的处理器越先执行:
|
||||
|
||||
```python
|
||||
class HighPriorityHandler(BaseEventHandler):
|
||||
weight = 100 # 高优先级
|
||||
class CriticalHandler(BaseEventHandler):
|
||||
weight = 100 # 高优先级,最先执行
|
||||
|
||||
class LowPriorityHandler(BaseEventHandler):
|
||||
weight = 1 # 低优先级
|
||||
class NormalHandler(BaseEventHandler):
|
||||
weight = 50 # 中等优先级
|
||||
|
||||
class BackgroundHandler(BaseEventHandler):
|
||||
weight = 1 # 低优先级,最后执行
|
||||
```
|
||||
|
||||
### 事件链式处理
|
||||
@@ -171,28 +313,9 @@ class FilterHandler(BaseEventHandler):
|
||||
return HandlerResult(True, True, "继续处理")
|
||||
```
|
||||
|
||||
### 事件结果处理
|
||||
|
||||
事件触发后返回 `HandlerResultsCollection`,可以获取详细的处理结果:
|
||||
|
||||
```python
|
||||
results = await event_manager.trigger_event("my_event", data=data)
|
||||
|
||||
# 获取处理摘要
|
||||
summary = results.get_summary()
|
||||
print(f"总处理器数: {summary['total_handlers']}")
|
||||
print(f"成功数: {summary['success_count']}")
|
||||
print(f"失败处理器: {summary['failed_handlers']}")
|
||||
|
||||
# 获取特定处理器结果
|
||||
result = results.get_handler_result("my_handler")
|
||||
if result and result.success:
|
||||
print("处理器执行成功")
|
||||
```
|
||||
|
||||
## 完整示例
|
||||
|
||||
### 示例1:消息监控插件
|
||||
### 示例1:消息监控插件(带权限控制)
|
||||
|
||||
```python
|
||||
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType
|
||||
@@ -218,9 +341,10 @@ class MessageMonitorHandler(BaseEventHandler):
|
||||
# 关键词检测
|
||||
if "重要" in str(message):
|
||||
self.keyword_hits += 1
|
||||
# 触发特殊事件
|
||||
# 触发特殊事件(需要相应权限)
|
||||
await event_manager.trigger_event(
|
||||
"important_message_detected",
|
||||
permission_group=self.plugin_name,
|
||||
message=message,
|
||||
count=self.keyword_hits
|
||||
)
|
||||
@@ -247,18 +371,21 @@ class MessageMonitorPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# 注册自定义事件
|
||||
event_manager.register_event("important_message_detected")
|
||||
# 注册自定义事件,设置权限控制
|
||||
event_manager.register_event(
|
||||
"important_message_detected",
|
||||
allowed_subscribers=["important_handler", "audit_handler"], # 只允许特定处理器订阅
|
||||
allowed_triggers=["message_monitor"] # 只允许本插件触发
|
||||
)
|
||||
|
||||
def get_plugin_components(self):
|
||||
return [
|
||||
(MessageMonitorHandler.get_handler_info(), MessageMonitorHandler),
|
||||
(ImportantMessageHandler.get_handler_info(), ImportantMessageHandler),
|
||||
]
|
||||
|
||||
```
|
||||
|
||||
### 示例2:系统监控插件
|
||||
### 示例2:系统监控插件(带结果汇总)
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
@@ -284,12 +411,19 @@ class SystemMonitorHandler(BaseEventHandler):
|
||||
memory = psutil.virtual_memory()
|
||||
|
||||
if cpu_percent > 80 or memory.percent > 80:
|
||||
await event_manager.trigger_event(
|
||||
# 触发系统警报事件
|
||||
result = await event_manager.trigger_event(
|
||||
"system_alert",
|
||||
permission_group=self.plugin_name,
|
||||
cpu_percent=cpu_percent,
|
||||
memory_percent=memory.percent,
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
# 处理结果汇总
|
||||
if result:
|
||||
summary = result.get_summary()
|
||||
print(f"警报处理结果: {summary['success_count']}成功, {summary['failure_count']}失败")
|
||||
|
||||
await asyncio.sleep(30)
|
||||
|
||||
@@ -303,66 +437,86 @@ class AlertHandler(BaseEventHandler):
|
||||
async def execute(self, params):
|
||||
cpu = params.get("cpu_percent")
|
||||
memory = params.get("memory_percent")
|
||||
print(f"🚨 系统警报: CPU {cpu}%, 内存 {memory}%")
|
||||
timestamp = params.get("timestamp")
|
||||
print(f"🚨 系统警报({timestamp}): CPU {cpu}%, 内存 {memory}%")
|
||||
return HandlerResult(True, True, "警报已处理")
|
||||
|
||||
class AlertNotifierHandler(BaseEventHandler):
|
||||
handler_name = "alert_notifier"
|
||||
handler_description = "通知系统警报"
|
||||
weight = 15
|
||||
intercept_message = False
|
||||
init_subscribe = ["system_alert"]
|
||||
|
||||
async def execute(self, params):
|
||||
cpu = params.get("cpu_percent")
|
||||
memory = params.get("memory_percent")
|
||||
# 发送通知...
|
||||
return HandlerResult(True, True, "通知已发送")
|
||||
|
||||
@register_plugin
|
||||
class SystemMonitorPlugin(BasePlugin):
|
||||
plugin_name = "system_monitor"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# 注册系统警报事件
|
||||
event_manager.register_event("system_alert")
|
||||
|
||||
def get_plugin_components(self):
|
||||
return [
|
||||
(SystemMonitorHandler.get_handler_info(), SystemMonitorHandler),
|
||||
(AlertHandler.get_handler_info(), AlertHandler),
|
||||
(AlertNotifierHandler.get_handler_info(), AlertNotifierHandler),
|
||||
]
|
||||
|
||||
```
|
||||
|
||||
## 调试和监控
|
||||
|
||||
### 查看事件系统状态
|
||||
|
||||
```python
|
||||
# 获取事件系统摘要
|
||||
summary = event_manager.get_event_summary()
|
||||
print(f"事件总数: {summary['total_events']}")
|
||||
print(f"启用事件: {summary['enabled_events']}")
|
||||
print(f"禁用事件: {summary['disabled_events']}")
|
||||
print(f"处理器总数: {summary['total_handlers']}")
|
||||
print(f"事件列表: {summary['event_names']}")
|
||||
print(f"处理器列表: {summary['handler_names']}")
|
||||
```
|
||||
|
||||
### 查看事件订阅情况
|
||||
|
||||
```python
|
||||
# 查看特定事件的订阅者
|
||||
subscribers = event_manager.get_event_subscribers(EventType.ON_MESSAGE)
|
||||
for name, handler in subscribers.items():
|
||||
print(f"订阅者: {name}, 权重: {handler.weight}")
|
||||
|
||||
# 查看事件的权限设置
|
||||
event = event_manager.get_event("important_message_detected")
|
||||
if event:
|
||||
print(f"允许的订阅者: {event.allowed_subscribers}")
|
||||
print(f"允许的触发者: {event.allowed_triggers}")
|
||||
```
|
||||
|
||||
## 最佳实践
|
||||
### 事件执行监控
|
||||
|
||||
1. **权重设置**:合理设置处理器权重,避免权重冲突
|
||||
2. **错误处理**:始终在处理器中添加异常处理
|
||||
3. **性能考虑**:避免在处理器中执行耗时操作,可使用异步任务
|
||||
4. **事件命名**:使用清晰的事件名称,避免与内置事件冲突
|
||||
5. **资源清理**:在插件卸载时取消订阅相关事件
|
||||
6. **日志记录**:适当记录处理日志,便于调试和监控
|
||||
```python
|
||||
# 监控事件执行性能
|
||||
import time
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q: 如何处理事件的执行顺序?
|
||||
A: 通过设置处理器的 `weight` 属性来控制执行顺序,权重越高越先执行。
|
||||
|
||||
### Q: 如何阻止后续处理器执行?
|
||||
A: 在处理器返回 `HandlerResult` 时设置 `continue_process=False`。
|
||||
|
||||
### Q: 如何动态注册事件?
|
||||
A: 使用 `event_manager.register_event("event_name")` 方法注册新事件。
|
||||
|
||||
### Q: 如何获取事件处理结果?
|
||||
A: 事件触发后会返回 `HandlerResultsCollection`,可以获取详细的处理结果和摘要信息。
|
||||
|
||||
### Q: 如何处理异步事件?
|
||||
A: 所有事件处理器都是异步的,可以在 `execute` 方法中使用 `await` 进行异步操作。
|
||||
async def monitored_trigger(event_name, **kwargs):
|
||||
start_time = time.time()
|
||||
results = await event_manager.trigger_event(event_name, **kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
if results:
|
||||
execution_time = end_time - start_time
|
||||
summary = results.get_summary()
|
||||
print(f"事件 {event_name} 执行时间: {execution_time:.3f}s")
|
||||
print(f"处理器执行统计: {summary}")
|
||||
|
||||
return results
|
||||
```
|
||||
@@ -1,269 +0,0 @@
|
||||
# try:
|
||||
# import src.plugins.knowledge.lib.quick_algo
|
||||
# except ImportError:
|
||||
# print("未找到quick_algo库,无法使用quick_algo算法")
|
||||
# print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace")
|
||||
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
from time import sleep
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
|
||||
|
||||
# 添加项目根目录到 sys.path
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
logger = get_logger("OpenIE导入")
|
||||
|
||||
|
||||
def ensure_openie_dir():
|
||||
"""确保OpenIE数据目录存在"""
|
||||
if not os.path.exists(OPENIE_DIR):
|
||||
os.makedirs(OPENIE_DIR)
|
||||
logger.info(f"创建OpenIE数据目录:{OPENIE_DIR}")
|
||||
else:
|
||||
logger.info(f"OpenIE数据目录已存在:{OPENIE_DIR}")
|
||||
|
||||
|
||||
def hash_deduplicate(
|
||||
raw_paragraphs: dict[str, str],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
stored_pg_hashes: set,
|
||||
stored_paragraph_hashes: set,
|
||||
):
|
||||
"""Hash去重
|
||||
|
||||
Args:
|
||||
raw_paragraphs: 索引的段落原文
|
||||
triple_list_data: 索引的三元组列表
|
||||
stored_pg_hashes: 已存储的段落hash集合
|
||||
stored_paragraph_hashes: 已存储的段落hash集合
|
||||
|
||||
Returns:
|
||||
new_raw_paragraphs: 去重后的段落
|
||||
new_triple_list_data: 去重后的三元组
|
||||
"""
|
||||
# 保存去重后的段落
|
||||
new_raw_paragraphs = {}
|
||||
# 保存去重后的三元组
|
||||
new_triple_list_data = {}
|
||||
|
||||
for _, (raw_paragraph, triple_list) in enumerate(
|
||||
zip(raw_paragraphs.values(), triple_list_data.values(), strict=False)
|
||||
):
|
||||
# 段落hash
|
||||
paragraph_hash = get_sha256(raw_paragraph)
|
||||
# 使用与EmbeddingStore中一致的命名空间格式:namespace-hash
|
||||
paragraph_key = f"paragraph-{paragraph_hash}"
|
||||
if paragraph_key in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
||||
continue
|
||||
new_raw_paragraphs[paragraph_hash] = raw_paragraph
|
||||
new_triple_list_data[paragraph_hash] = triple_list
|
||||
|
||||
return new_raw_paragraphs, new_triple_list_data
|
||||
|
||||
|
||||
def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
|
||||
# sourcery skip: extract-method
|
||||
# 从OpenIE数据中提取段落原文与三元组列表
|
||||
# 索引的段落原文
|
||||
raw_paragraphs = openie_data.extract_raw_paragraph_dict()
|
||||
# 索引的实体列表
|
||||
entity_list_data = openie_data.extract_entity_dict()
|
||||
# 索引的三元组列表
|
||||
triple_list_data = openie_data.extract_triple_dict()
|
||||
# print(openie_data.docs)
|
||||
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
|
||||
logger.error("OpenIE数据存在异常")
|
||||
logger.error(f"原始段落数量:{len(raw_paragraphs)}")
|
||||
logger.error(f"实体列表数量:{len(entity_list_data)}")
|
||||
logger.error(f"三元组列表数量:{len(triple_list_data)}")
|
||||
logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致")
|
||||
logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况")
|
||||
logger.error("或者一段中只有符号的情况")
|
||||
# 新增:检查docs中每条数据的完整性
|
||||
logger.error("系统将于2秒后开始检查数据完整性")
|
||||
sleep(2)
|
||||
found_missing = False
|
||||
missing_idxs = []
|
||||
for doc in getattr(openie_data, "docs", []):
|
||||
idx = doc.get("idx", "<无idx>")
|
||||
passage = doc.get("passage", "<无passage>")
|
||||
missing = []
|
||||
# 检查字段是否存在且非空
|
||||
if "passage" not in doc or not doc.get("passage"):
|
||||
missing.append("passage")
|
||||
if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list):
|
||||
missing.append("名词列表缺失")
|
||||
elif len(doc.get("extracted_entities", [])) == 0:
|
||||
missing.append("名词列表为空")
|
||||
if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list):
|
||||
missing.append("主谓宾三元组缺失")
|
||||
elif len(doc.get("extracted_triples", [])) == 0:
|
||||
missing.append("主谓宾三元组为空")
|
||||
# 输出所有doc的idx
|
||||
# print(f"检查: idx={idx}")
|
||||
if missing:
|
||||
found_missing = True
|
||||
missing_idxs.append(idx)
|
||||
logger.error("\n")
|
||||
logger.error("数据缺失:")
|
||||
logger.error(f"对应哈希值:{idx}")
|
||||
logger.error(f"对应文段内容内容:{passage}")
|
||||
logger.error(f"非法原因:{', '.join(missing)}")
|
||||
# 确保提示在所有非法数据输出后再输出
|
||||
if not found_missing:
|
||||
logger.info("所有数据均完整,没有发现缺失字段。")
|
||||
return False
|
||||
# 新增:提示用户是否删除非法文段继续导入
|
||||
# 将print移到所有logger.error之后,确保不会被冲掉
|
||||
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
|
||||
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
|
||||
user_choice = input().strip().lower()
|
||||
if user_choice != "y":
|
||||
logger.info("用户选择不删除非法文段,程序终止。")
|
||||
sys.exit(1)
|
||||
# 删除非法文段
|
||||
logger.info("正在删除非法文段并继续导入...")
|
||||
# 过滤掉非法文段
|
||||
openie_data.docs = [
|
||||
doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs
|
||||
]
|
||||
# 重新提取数据
|
||||
raw_paragraphs = openie_data.extract_raw_paragraph_dict()
|
||||
entity_list_data = openie_data.extract_entity_dict()
|
||||
triple_list_data = openie_data.extract_triple_dict()
|
||||
# 再次校验
|
||||
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
|
||||
logger.error("删除非法文段后,数据仍不一致,程序终止。")
|
||||
sys.exit(1)
|
||||
# 将索引换为对应段落的hash值
|
||||
logger.info("正在进行段落去重与重索引")
|
||||
raw_paragraphs, triple_list_data = hash_deduplicate(
|
||||
raw_paragraphs,
|
||||
triple_list_data,
|
||||
embed_manager.stored_pg_hashes,
|
||||
kg_manager.stored_paragraph_hashes,
|
||||
)
|
||||
if len(raw_paragraphs) != 0:
|
||||
# 获取嵌入并保存
|
||||
logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}")
|
||||
logger.info("开始Embedding")
|
||||
embed_manager.store_new_data_set(raw_paragraphs, triple_list_data)
|
||||
# Embedding-Faiss重索引
|
||||
logger.info("正在重新构建向量索引")
|
||||
embed_manager.rebuild_faiss_index()
|
||||
logger.info("向量索引构建完成")
|
||||
embed_manager.save_to_file()
|
||||
logger.info("Embedding完成")
|
||||
# 构建新段落的RAG
|
||||
logger.info("开始构建RAG")
|
||||
kg_manager.build_kg(triple_list_data, embed_manager)
|
||||
kg_manager.save_to_file()
|
||||
logger.info("RAG构建完成")
|
||||
else:
|
||||
logger.info("无新段落需要处理")
|
||||
return True
|
||||
|
||||
|
||||
async def main_async(): # sourcery skip: dict-comprehension
|
||||
# 新增确认提示
|
||||
print("=== 重要操作确认 ===")
|
||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
|
||||
print("推荐使用硅基流动的Pro/BAAI/bge-m3")
|
||||
print("每百万Token费用为0.7元")
|
||||
print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
|
||||
print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G")
|
||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||
if confirm != "y":
|
||||
logger.info("用户取消操作")
|
||||
print("操作已取消")
|
||||
sys.exit(1)
|
||||
print("\n" + "=" * 40 + "\n")
|
||||
ensure_openie_dir() # 确保OpenIE目录存在
|
||||
logger.info("----开始导入openie数据----\n")
|
||||
|
||||
logger.info("创建LLM客户端")
|
||||
|
||||
# 初始化Embedding库
|
||||
embed_manager = EmbeddingManager()
|
||||
logger.info("正在从文件加载Embedding库")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.error(f"从文件加载Embedding库时发生错误:{e}")
|
||||
if "嵌入模型与本地存储不一致" in str(e):
|
||||
logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
||||
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
|
||||
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
||||
sys.exit(1)
|
||||
if "不存在" in str(e):
|
||||
logger.error("如果你是第一次导入知识,请忽略此错误")
|
||||
logger.info("Embedding库加载完成")
|
||||
# 初始化KG
|
||||
kg_manager = KGManager()
|
||||
logger.info("正在从文件加载KG")
|
||||
try:
|
||||
kg_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.error(f"从文件加载KG时发生错误:{e}")
|
||||
logger.error("如果你是第一次导入知识,请忽略此错误")
|
||||
logger.info("KG加载完成")
|
||||
|
||||
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
||||
logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
|
||||
|
||||
# 数据比对:Embedding库与KG的段落hash集合
|
||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||
# 使用与EmbeddingStore中一致的命名空间格式:namespace-hash
|
||||
key = f"paragraph-{pg_hash}"
|
||||
if key not in embed_manager.stored_pg_hashes:
|
||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||
|
||||
logger.info("正在导入OpenIE数据文件")
|
||||
try:
|
||||
openie_data = OpenIE.load()
|
||||
except Exception as e:
|
||||
logger.error(f"导入OpenIE数据文件时发生错误:{e}")
|
||||
return False
|
||||
if handle_import_openie(openie_data, embed_manager, kg_manager) is False:
|
||||
logger.error("处理OpenIE数据时发生错误")
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数 - 设置新的事件循环并运行异步主函数"""
|
||||
# 检查是否有现有的事件循环
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop.is_closed():
|
||||
# 如果事件循环已关闭,创建新的
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,创建新的
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步主函数
|
||||
loop.run_until_complete(main_async())
|
||||
finally:
|
||||
# 确保事件循环被正确关闭
|
||||
if not loop.is_closed():
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# logger.info(f"111111111111111111111111{ROOT_PATH}")
|
||||
main()
|
||||
@@ -1,218 +0,0 @@
|
||||
import orjson
|
||||
import os
|
||||
import signal
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from threading import Lock, Event
|
||||
import sys
|
||||
import datetime
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
# 添加项目根目录到 sys.path
|
||||
|
||||
from rich.progress import Progress # 替换为 rich 进度条
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# from src.chat.knowledge.lpmmconfig import global_config
|
||||
from src.chat.knowledge.ie_process import info_extract_from_str
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("LPMM知识库-信息提取")
|
||||
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
|
||||
def ensure_dirs():
|
||||
"""确保临时目录和输出目录存在"""
|
||||
if not os.path.exists(TEMP_DIR):
|
||||
os.makedirs(TEMP_DIR)
|
||||
logger.info(f"已创建临时目录: {TEMP_DIR}")
|
||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
||||
if not os.path.exists(RAW_DATA_PATH):
|
||||
os.makedirs(RAW_DATA_PATH)
|
||||
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
|
||||
|
||||
|
||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||
file_lock = Lock()
|
||||
open_ie_doc_lock = Lock()
|
||||
|
||||
# 创建一个事件标志,用于控制程序终止
|
||||
shutdown_event = Event()
|
||||
|
||||
lpmm_entity_extract_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||
)
|
||||
lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||
|
||||
|
||||
def process_single_text(pg_hash, raw_data):
|
||||
"""处理单个文本的函数,用于线程池"""
|
||||
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
||||
|
||||
# 使用文件锁检查和读取缓存文件
|
||||
with file_lock:
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
# 存在对应的提取结果
|
||||
logger.info(f"找到缓存的提取结果:{pg_hash}")
|
||||
with open(temp_file_path, "r", encoding="utf-8") as f:
|
||||
return orjson.loads(f.read()), None
|
||||
except orjson.JSONDecodeError:
|
||||
# 如果JSON文件损坏,删除它并重新处理
|
||||
logger.warning(f"缓存文件损坏,重新处理:{pg_hash}")
|
||||
os.remove(temp_file_path)
|
||||
|
||||
entity_list, rdf_triple_list = info_extract_from_str(
|
||||
lpmm_entity_extract_llm,
|
||||
lpmm_rdf_build_llm,
|
||||
raw_data,
|
||||
)
|
||||
if entity_list is None or rdf_triple_list is None:
|
||||
return None, pg_hash
|
||||
doc_item = {
|
||||
"idx": pg_hash,
|
||||
"passage": raw_data,
|
||||
"extracted_entities": entity_list,
|
||||
"extracted_triples": rdf_triple_list,
|
||||
}
|
||||
# 保存临时提取结果
|
||||
with file_lock:
|
||||
try:
|
||||
with open(temp_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(doc_item, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
|
||||
# 如果保存失败,确保不会留下损坏的文件
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
sys.exit(0)
|
||||
return None, pg_hash
|
||||
return doc_item, None
|
||||
|
||||
|
||||
def signal_handler(_signum, _frame):
|
||||
"""处理Ctrl+C信号"""
|
||||
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||
# 设置信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
ensure_dirs() # 确保目录存在
|
||||
# 新增用户确认提示
|
||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||
print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。")
|
||||
print("建议使用硅基流动的非Pro模型")
|
||||
print("或者使用可以用赠金抵扣的Pro模型")
|
||||
print("请确保账户余额充足,并且在执行前确认无误。")
|
||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||
if confirm != "y":
|
||||
logger.info("用户取消操作")
|
||||
print("操作已取消")
|
||||
sys.exit(1)
|
||||
print("\n" + "=" * 40 + "\n")
|
||||
ensure_dirs() # 确保目录存在
|
||||
logger.info("--------进行信息提取--------\n")
|
||||
|
||||
# 加载原始数据
|
||||
logger.info("正在加载原始数据")
|
||||
all_sha256_list, all_raw_datas = load_raw_data()
|
||||
|
||||
failed_sha256 = []
|
||||
open_ie_doc = []
|
||||
|
||||
workers = global_config.lpmm_knowledge.info_extraction_workers
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
future_to_hash = {
|
||||
executor.submit(process_single_text, pg_hash, raw_data): pg_hash
|
||||
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
|
||||
}
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
MofNCompleteColumn(),
|
||||
"•",
|
||||
TimeElapsedColumn(),
|
||||
"<",
|
||||
TimeRemainingColumn(),
|
||||
transient=False,
|
||||
) as progress:
|
||||
task = progress.add_task("正在进行提取:", total=len(future_to_hash))
|
||||
try:
|
||||
for future in as_completed(future_to_hash):
|
||||
if shutdown_event.is_set():
|
||||
for f in future_to_hash:
|
||||
if not f.done():
|
||||
f.cancel()
|
||||
break
|
||||
|
||||
doc_item, failed_hash = future.result()
|
||||
if failed_hash:
|
||||
failed_sha256.append(failed_hash)
|
||||
logger.error(f"提取失败:{failed_hash}")
|
||||
elif doc_item:
|
||||
with open_ie_doc_lock:
|
||||
open_ie_doc.append(doc_item)
|
||||
progress.update(task, advance=1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
||||
shutdown_event.set()
|
||||
for f in future_to_hash:
|
||||
if not f.done():
|
||||
f.cancel()
|
||||
|
||||
# 合并所有文件的提取结果并保存
|
||||
if open_ie_doc:
|
||||
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
|
||||
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
|
||||
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
|
||||
openie_obj = OpenIE(
|
||||
open_ie_doc,
|
||||
round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
|
||||
round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
|
||||
)
|
||||
# 输出文件名格式:MM-DD-HH-ss-openie.json
|
||||
now = datetime.datetime.now()
|
||||
filename = now.strftime("%m-%d-%H-%S-openie.json")
|
||||
output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
orjson.dumps(
|
||||
openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
|
||||
option=orjson.OPT_INDENT_2,
|
||||
).decode("utf-8")
|
||||
)
|
||||
logger.info(f"信息提取结果已保存到: {output_path}")
|
||||
else:
|
||||
logger.warning("没有可保存的信息提取结果")
|
||||
|
||||
logger.info("--------信息提取完成--------")
|
||||
logger.info(f"提取失败的文段SHA256:{failed_sha256}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
267
scripts/lpmm_learning_tool.py
Normal file
267
scripts/lpmm_learning_tool.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
import orjson
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
# 将项目根目录添加到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
|
||||
logger = get_logger("LPMM_LearningTool")
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
|
||||
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
|
||||
file_lock = Lock()
|
||||
|
||||
# --- 模块一:数据预处理 ---
|
||||
|
||||
def process_text_file(file_path):
    """Read a UTF-8 text file and split it into non-empty paragraphs.

    Paragraphs are delimited by double newlines ("\n\n"); each returned
    paragraph is stripped of surrounding whitespace, and empty chunks
    are discarded.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        content = handle.read()

    paragraphs = []
    for chunk in content.split("\n\n"):
        stripped = chunk.strip()
        if stripped:
            paragraphs.append(stripped)
    return paragraphs
|
||||
|
||||
def preprocess_raw_data():
    """Step 1: collect and deduplicate raw text paragraphs.

    Scans RAW_DATA_PATH for .txt files, splits each into paragraphs, and
    deduplicates them by SHA256.

    Returns:
        dict[str, str]: mapping of paragraph SHA256 hash -> paragraph text.
        Empty dict when no .txt files are found.

    Note:
        Previously this returned a list ([]) on the empty path but a dict
        otherwise; callers only truth-test the result, so returning {} is
        backward-compatible and keeps the return type consistent.
    """
    logger.info("--- 步骤 1: 开始数据预处理 ---")
    os.makedirs(RAW_DATA_PATH, exist_ok=True)
    raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
    if not raw_files:
        logger.warning(f"警告: 在 '{RAW_DATA_PATH}' 中没有找到任何 .txt 文件")
        # Consistent empty-dict return (was [] before; both are falsy).
        return {}

    all_paragraphs = []
    for file in raw_files:
        logger.info(f"正在处理文件: {file.name}")
        all_paragraphs.extend(process_text_file(file))

    # Dedupe by content hash; later duplicates simply overwrite identical text.
    unique_paragraphs = {get_sha256(p): p for p in all_paragraphs}
    logger.info(f"共找到 {len(all_paragraphs)} 个段落,去重后剩余 {len(unique_paragraphs)} 个。")
    logger.info("--- 数据预处理完成 ---")
    return unique_paragraphs
|
||||
|
||||
# --- 模块二:信息提取 ---
|
||||
|
||||
def get_extraction_prompt(paragraph: str) -> str:
    """Build the LLM prompt asking for entity and triple extraction.

    The prompt instructs the model to return strict JSON with "entities"
    and "triples" keys for the given paragraph.
    """
    prompt_text = f"""
请从以下段落中提取关键信息。你需要提取两种类型的信息:
1. **实体 (Entities)**: 识别并列出段落中所有重要的名词或名词短语。
2. **三元组 (Triples)**: 以 [主语, 谓语, 宾语] 的格式,提取段落中描述关系或事实的核心信息。

请严格按照以下 JSON 格式返回结果,不要添加任何额外的解释或注释:
{{
"entities": ["实体1", "实体2"],
"triples": [["主语1", "谓语1", "宾语1"]]
}}

这是你需要处理的段落:
---
{paragraph}
---
"""
    return prompt_text
|
||||
|
||||
async def extract_info_async(pg_hash, paragraph, llm_api):
    """Extract entities/triples for one paragraph, with an on-disk cache.

    Checks TEMP_DIR for a cached result keyed by the paragraph hash before
    calling the LLM. On success the result is cached for future runs.

    Args:
        pg_hash: SHA256 hash of the paragraph (used as cache key).
        paragraph: raw paragraph text to analyze.
        llm_api: LLM request object exposing generate_response_async().

    Returns:
        (doc_item, None) on success (doc_item is the extraction dict),
        or (None, pg_hash) on failure so the caller can record the hash.
    """
    temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
    # Cache hit: reuse a previous extraction; a corrupt cache file is deleted
    # so this paragraph falls through to a fresh LLM call.
    with file_lock:
        if os.path.exists(temp_file_path):
            try:
                with open(temp_file_path, "rb") as f:
                    return orjson.loads(f.read()), None
            except orjson.JSONDecodeError:
                os.remove(temp_file_path)

    prompt = get_extraction_prompt(paragraph)
    try:
        # NOTE(review): this unpack expects (content, (a, b, c)) — other call
        # sites in this repo use `content, _ = ...`; confirm the API shape.
        content, (_, _, _) = await llm_api.generate_response_async(prompt)
        # Assumes the model returned bare JSON; a fenced/markdown response
        # would raise here and be reported as a failure.
        extracted_data = orjson.loads(content)
        doc_item = {
            "idx": pg_hash, "passage": paragraph,
            "extracted_entities": extracted_data.get("entities", []),
            "extracted_triples": extracted_data.get("triples", []),
        }
        # Persist to the cache under the lock (multiple threads run this).
        with file_lock:
            with open(temp_file_path, "wb") as f:
                f.write(orjson.dumps(doc_item))
        return doc_item, None
    except Exception as e:
        logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
        return None, pg_hash
|
||||
|
||||
def extract_info_sync(pg_hash, paragraph, llm_api):
    """Synchronous bridge for thread-pool workers.

    Runs the async extraction coroutine to completion in a fresh event
    loop so it can be submitted to a ThreadPoolExecutor.
    """
    coro = extract_info_async(pg_hash, paragraph, llm_api)
    return asyncio.run(coro)
|
||||
|
||||
def extract_information(paragraphs_dict, model_set):
    """Step 2: run LLM-based OpenIE extraction over all paragraphs.

    Fans the paragraphs out to a small thread pool, shows a Rich progress
    bar, aggregates successful extractions into an OpenIE object, and
    writes it to OPENIE_OUTPUT_DIR as a timestamped JSON file.

    Args:
        paragraphs_dict: mapping of paragraph SHA256 -> paragraph text.
        model_set: model task configuration passed to LLMRequest.
    """
    logger.info("--- 步骤 2: 开始信息提取 ---")
    os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
    os.makedirs(TEMP_DIR, exist_ok=True)

    llm_api = LLMRequest(model_set=model_set)
    failed_hashes, open_ie_docs = [], []

    # max_workers=5 bounds concurrent LLM calls; each worker runs the sync
    # bridge which spins up its own event loop.
    with ThreadPoolExecutor(max_workers=5) as executor:
        f_to_hash = {executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()}
        with Progress(SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn()) as progress:
            task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
            for future in as_completed(f_to_hash):
                # Each worker returns (doc_item, None) or (None, failed_hash).
                doc_item, failed_hash = future.result()
                if failed_hash: failed_hashes.append(failed_hash)
                elif doc_item: open_ie_docs.append(doc_item)
                progress.update(task, advance=1)

    if open_ie_docs:
        # Aggregate simple entity statistics for the OpenIE metadata.
        all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
        num_entities = len(all_entities)
        avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
        avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
        openie_obj = OpenIE(docs=open_ie_docs, avg_ent_chars=avg_ent_chars, avg_ent_words=avg_ent_words)

        # Output filename format: YYYY-MM-DD-HH-MM-SS-openie.json
        now = datetime.datetime.now()
        filename = now.strftime("%Y-%m-%d-%H-%M-%S-openie.json")
        output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
        with open(output_path, "wb") as f:
            # NOTE(review): relies on the private OpenIE._to_dict() — confirm
            # there is no public serialization method to prefer here.
            f.write(orjson.dumps(openie_obj._to_dict()))
        logger.info(f"信息提取结果已保存到: {output_path}")

    if failed_hashes: logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
    logger.info("--- 信息提取完成 ---")
|
||||
|
||||
# --- 模块三:数据导入 ---
|
||||
|
||||
async def import_data(openie_obj: Optional[OpenIE] = None):
    """Step 3: import OpenIE data into the knowledge base (embeddings + KG).

    Args:
        openie_obj (Optional[OpenIE], optional): when provided, this OpenIE
            object is used directly; otherwise the latest OpenIE file is
            loaded from the default folder. Defaults to None.
    """
    logger.info("--- 步骤 3: 开始数据导入 ---")
    embed_manager, kg_manager = EmbeddingManager(), KGManager()

    # Best-effort loads: a missing/corrupt store only logs a warning so a
    # first-time import still proceeds with empty managers.
    logger.info("正在加载现有的 Embedding 库...")
    try: embed_manager.load_from_file()
    except Exception as e: logger.warning(f"加载 Embedding 库失败: {e}。")

    logger.info("正在加载现有的 KG...")
    try: kg_manager.load_from_file()
    except Exception as e: logger.warning(f"加载 KG 失败: {e}。")

    try:
        if openie_obj:
            openie_data = openie_obj
            logger.info("已使用指定的 OpenIE 对象。")
        else:
            openie_data = OpenIE.load()
    except Exception as e:
        # Without OpenIE data there is nothing to import; bail out early.
        logger.error(f"加载OpenIE数据文件失败: {e}")
        return

    raw_paragraphs = openie_data.extract_raw_paragraph_dict()
    triple_list_data = openie_data.extract_triple_dict()

    # Keep only paragraphs not already present in either store, so re-running
    # the import is idempotent.
    new_raw_paragraphs, new_triple_list_data = {}, {}
    stored_embeds = embed_manager.stored_pg_hashes
    stored_kgs = kg_manager.stored_paragraph_hashes

    for p_hash, raw_p in raw_paragraphs.items():
        if p_hash not in stored_embeds and p_hash not in stored_kgs:
            new_raw_paragraphs[p_hash] = raw_p
            new_triple_list_data[p_hash] = triple_list_data.get(p_hash, [])

    if not new_raw_paragraphs:
        logger.info("没有新的段落需要处理。")
    else:
        logger.info(f"去重完成,发现 {len(new_raw_paragraphs)} 个新段落。")
        # Embeddings first: the KG build below consumes the embedding manager.
        logger.info("开始生成 Embedding...")
        embed_manager.store_new_data_set(new_raw_paragraphs, new_triple_list_data)
        embed_manager.rebuild_faiss_index()
        embed_manager.save_to_file()
        logger.info("Embedding 处理完成!")

        logger.info("开始构建 KG...")
        kg_manager.build_kg(new_triple_list_data, embed_manager)
        kg_manager.save_to_file()
        logger.info("KG 构建完成!")

    logger.info("--- 数据导入完成 ---")
|
||||
|
||||
def import_from_specific_file():
    """Prompt for a user-specified openie.json file and import it."""
    file_path = input("请输入 openie.json 文件的完整路径: ").strip()

    # Guard: the path must exist before we check its extension, so the
    # error messages keep their original precedence.
    if not os.path.exists(file_path):
        logger.error(f"文件路径不存在: {file_path}")
        return

    if not file_path.endswith(".json"):
        logger.error("请输入一个有效的 .json 文件路径。")
        return

    try:
        logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
        loaded = OpenIE.load(filepath=file_path)
        asyncio.run(import_data(openie_obj=loaded))
    except Exception as e:
        logger.error(f"从指定文件导入数据时发生错误: {e}")
|
||||
|
||||
# --- 主函数 ---
|
||||
|
||||
def main():
    """Interactive menu entry point for the LPMM learning tool.

    Presents a numbered menu and dispatches to preprocessing, extraction,
    import, the full pipeline, file-specific import, or exit.
    """
    # Friendly paths relative to the directory above the project root,
    # purely for display in the menu.
    rel_raw = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
    rel_openie = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))

    print("=== LPMM 知识库学习工具 ===")
    print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{rel_raw}/)")
    print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{rel_openie}/)")
    print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识")
    print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3")
    print("5. [指定导入] -> 从特定的 openie.json 文件导入知识")
    print("0. [退出]")
    print("-" * 30)
    selection = input("请输入你的选择 (0-5): ").strip()

    if selection == "0":
        sys.exit(0)
    elif selection == "1":
        preprocess_raw_data()
    elif selection == "2":
        # Extraction needs fresh paragraphs; skip when there is nothing to do.
        paragraphs = preprocess_raw_data()
        if paragraphs:
            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
    elif selection == "3":
        asyncio.run(import_data())
    elif selection == "4":
        # Full pipeline: preprocess -> extract -> import.
        paragraphs = preprocess_raw_data()
        if paragraphs:
            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
            asyncio.run(import_data())
    elif selection == "5":
        import_from_specific_file()
    else:
        print("无效输入,请重新运行脚本。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,78 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys # 新增系统模块导入
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("lpmm")
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||
|
||||
|
||||
def _process_text_file(file_path):
|
||||
"""处理单个文本文件,返回段落列表"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
raw = f.read()
|
||||
|
||||
paragraphs = []
|
||||
paragraph = ""
|
||||
for line in raw.split("\n"):
|
||||
if line.strip() == "":
|
||||
if paragraph != "":
|
||||
paragraphs.append(paragraph.strip())
|
||||
paragraph = ""
|
||||
else:
|
||||
paragraph += line + "\n"
|
||||
|
||||
if paragraph != "":
|
||||
paragraphs.append(paragraph.strip())
|
||||
|
||||
return paragraphs
|
||||
|
||||
|
||||
def _process_multi_files() -> list:
    """Collect paragraphs from every .txt file under RAW_DATA_PATH.

    Exits the process with status 1 when no .txt files are present.
    """
    raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
    if not raw_files:
        logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
        sys.exit(1)

    # Flatten the per-file paragraph lists into one list.
    all_paragraphs = []
    for file in raw_files:
        logger.info(f"正在处理文件: {file.name}")
        all_paragraphs.extend(_process_text_file(file))
    return all_paragraphs
|
||||
|
||||
|
||||
def load_raw_data() -> tuple[list[str], list[str]]:
    """Load raw data files and deduplicate them by SHA256.

    Reads all raw .txt data into memory, drops non-string items and
    duplicate paragraphs, and returns parallel lists of hashes and text.

    Returns:
        - sha256_list: SHA256 hash of each kept paragraph, in order.
        - deduped_data: the kept paragraph texts, in the same order.

    Bug fix: the previous version appended kept items back onto the list it
    was iterating (`raw_data.append(item)` inside `for item in raw_data`),
    which doubled every entry in the returned list and inflated the logged
    count. A separate output list is used instead.
    """
    raw_data = _process_multi_files()
    sha256_list = []
    sha256_set = set()
    deduped_data = []
    for item in raw_data:
        if not isinstance(item, str):
            logger.warning(f"数据类型错误:{item}")
            continue
        pg_hash = get_sha256(item)
        if pg_hash in sha256_set:
            logger.warning(f"重复数据:{item}")
            continue
        sha256_set.add(pg_hash)
        sha256_list.append(pg_hash)
        deduped_data.append(item)
    logger.info(f"共读取到{len(deduped_data)}条数据")

    return sha256_list, deduped_data
|
||||
@@ -205,6 +205,13 @@ class CycleProcessor:
|
||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成")
|
||||
with Timer("规划器", cycle_timers):
|
||||
actions, _ = await self.action_planner.plan(mode=mode)
|
||||
|
||||
# 在这里添加日志,清晰地显示最终选择的动作
|
||||
if actions:
|
||||
chosen_actions = [a.get("action_type", "unknown") for a in actions]
|
||||
logger.info(f"{self.log_prefix} LLM最终选择的动作: {chosen_actions}")
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} LLM最终没有选择任何动作")
|
||||
|
||||
async def execute_action(action_info):
|
||||
"""执行单个动作的通用函数"""
|
||||
@@ -229,11 +236,13 @@ class CycleProcessor:
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
elif action_info["action_type"] != "reply" and action_info["action_type"] != "no_action":
|
||||
# 执行普通动作
|
||||
# 记录并执行普通动作
|
||||
reason = action_info.get("reasoning", f"执行动作 {action_info['action_type']}")
|
||||
logger.info(f"{self.log_prefix} 决定执行动作 '{action_info['action_type']}',内心思考: {reason}")
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_info["action_type"],
|
||||
action_info["reasoning"],
|
||||
reason, # 使用已获取的reason
|
||||
action_info["action_data"],
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
@@ -248,6 +257,8 @@ class CycleProcessor:
|
||||
else:
|
||||
# 生成回复
|
||||
try:
|
||||
reason = action_info.get("reasoning", "决定进行回复")
|
||||
logger.info(f"{self.log_prefix} 决定进行回复,内心思考: {reason}")
|
||||
success, response_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=self.context.chat_stream,
|
||||
reply_message=action_info["action_message"],
|
||||
@@ -302,8 +313,18 @@ class CycleProcessor:
|
||||
if not action_message:
|
||||
logger.warning(f"{self.log_prefix} reply 动作缺少 action_message,跳过")
|
||||
continue
|
||||
|
||||
# 检查是否是空的DatabaseMessages对象
|
||||
if hasattr(action_message, 'chat_info') and hasattr(action_message.chat_info, 'user_info'):
|
||||
target_user_id = action_message.chat_info.user_info.user_id
|
||||
else:
|
||||
# 如果是字典格式,使用原来的方式
|
||||
target_user_id = action_message.get("chat_info_user_id", "")
|
||||
|
||||
if not target_user_id:
|
||||
logger.warning(f"{self.log_prefix} reply 动作的 action_message 缺少用户ID,跳过")
|
||||
continue
|
||||
|
||||
target_user_id = action_message.get("chat_info_user_id","")
|
||||
if target_user_id == global_config.bot.qq_account and not global_config.chat.allow_reply_self:
|
||||
logger.warning("选取的reply的目标为bot自己,跳过reply action")
|
||||
continue
|
||||
|
||||
@@ -159,27 +159,60 @@ class ProactiveThinker:
|
||||
|
||||
news_block = "暂时没有获取到最新资讯。"
|
||||
if trigger_event.source != "reminder_system":
|
||||
try:
|
||||
web_search_tool = tool_api.get_tool_instance("web_search")
|
||||
if web_search_tool:
|
||||
try:
|
||||
search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10})
|
||||
except TypeError:
|
||||
try:
|
||||
search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10})
|
||||
except TypeError:
|
||||
logger.warning(f"{self.context.log_prefix} 网络搜索工具参数不匹配,跳过搜索")
|
||||
news_block = "跳过网络搜索。"
|
||||
search_result_dict = None
|
||||
|
||||
if search_result_dict and not search_result_dict.get("error"):
|
||||
news_block = search_result_dict.get("content", "未能提取有效资讯。")
|
||||
elif search_result_dict:
|
||||
logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}")
|
||||
else:
|
||||
logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}")
|
||||
# 升级决策模型
|
||||
should_search_prompt = f"""
|
||||
# 搜索决策
|
||||
|
||||
## 任务
|
||||
分析话题“{topic}”,判断它的展开更依赖于“外部信息”还是“内部信息”,并决定是否需要进行网络搜索。
|
||||
|
||||
## 判断原则
|
||||
- **需要搜索 (SEARCH)**:当话题的有效讨论**必须**依赖于现实世界的、客观的、可被检索的外部信息时。这包括但不限于:
|
||||
- 新闻时事、公共事件
|
||||
- 专业知识、科学概念
|
||||
- 天气、股价等实时数据
|
||||
- 对具体实体(如电影、书籍、地点)的客观描述查询
|
||||
|
||||
- **无需搜索 (SKIP)**:当话题的展开主要依赖于**已有的对话上下文、个人情感、主观体验或社交互动**时。这包括但不限于:
|
||||
- 延续之前的对话、追问细节
|
||||
- 表达关心、问候或个人感受
|
||||
- 分享主观看法或经历
|
||||
- 纯粹的社交性互动
|
||||
|
||||
## 你的决策
|
||||
根据以上原则,对“{topic}”这个话题进行分析,并严格输出`SEARCH`或`SKIP`。
|
||||
"""
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
decision_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.planner,
|
||||
request_type="planner"
|
||||
)
|
||||
|
||||
decision, _ = await decision_llm.generate_response_async(prompt=should_search_prompt)
|
||||
|
||||
if "SEARCH" in decision:
|
||||
try:
|
||||
if topic and topic.strip():
|
||||
web_search_tool = tool_api.get_tool_instance("web_search")
|
||||
if web_search_tool:
|
||||
try:
|
||||
search_result_dict = await web_search_tool.execute(
|
||||
function_args={"query": topic, "max_results": 10}
|
||||
)
|
||||
if search_result_dict and not search_result_dict.get("error"):
|
||||
news_block = search_result_dict.get("content", "未能提取有效资讯。")
|
||||
elif search_result_dict:
|
||||
logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 网络搜索执行失败: {e}")
|
||||
else:
|
||||
logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。")
|
||||
else:
|
||||
logger.warning(f"{self.context.log_prefix} 主题为空,跳过网络搜索。")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}")
|
||||
message_list = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.context.stream_id,
|
||||
timestamp=time.time(),
|
||||
@@ -201,15 +234,17 @@ class ProactiveThinker:
|
||||
{chat_context_block}
|
||||
|
||||
## 合理判断标准
|
||||
请检查以下条件,如果**大部分条件都合理**就可以回复:
|
||||
请检查以下条件,如果**所有条件都合理**就可以回复:
|
||||
|
||||
1. **时间合理性**:当前时间是否在深夜(凌晨2点-6点)这种不适合主动聊天的时段?
|
||||
2. **内容价值**:这个话题"{topic}"是否有意义,不是完全无关紧要的内容?
|
||||
3. **重复避免**:你准备说的话题是否与最近2条消息明显重复?
|
||||
4. **自然性**:在当前上下文中主动提起这个话题是否自然合理?
|
||||
1. **回应检查**:检查你({bot_name})发送的最后一条消息之后,是否有其他人发言。如果没有,则大概率应该保持沉默。
|
||||
2. **话题补充**:只有当你认为准备发起的话题是对上一条无人回应消息的**有价值的补充**时,才可以在上一条消息无人回应的情况下继续发言。
|
||||
3. **时间合理性**:当前时间是否在深夜(凌晨2点-6点)这种不适合主动聊天的时段?
|
||||
4. **内容价值**:这个话题"{topic}"是否有意义,不是完全无关紧要的内容?
|
||||
5. **重复避免**:你准备说的话题是否与你自己的上一条消息明显重复?
|
||||
6. **自然性**:在当前上下文中主动提起这个话题是否自然合理?
|
||||
|
||||
## 输出要求
|
||||
如果判断应该跳过(比如深夜时段、完全无意义话题、明显重复内容),输出:SKIP_PROACTIVE_REPLY
|
||||
如果判断应该跳过(比如上一条消息无人回应、深夜时段、无意义话题、重复内容),输出:SKIP_PROACTIVE_REPLY
|
||||
其他情况都应该输出:PROCEED_TO_REPLY
|
||||
|
||||
请严格按照上述格式输出,不要添加任何解释。"""
|
||||
@@ -259,6 +294,8 @@ class ProactiveThinker:
|
||||
- 如果有什么想分享的想法,就自然地开启话题
|
||||
- 如果只是想闲聊,就随意地说些什么
|
||||
|
||||
**重要**:如果获取到了最新的网络信息(news_block不为空),请**自然地**将这些信息融入你的回复中,作为话题的补充或引子,而不是生硬地复述。
|
||||
|
||||
## 要求
|
||||
- 像真正的朋友一样,自然地表达关心或好奇
|
||||
- 不要过于正式,要口语化和亲切
|
||||
|
||||
59
src/chat/emoji_system/emoji_history.py
Normal file
59
src/chat/emoji_system/emoji_history.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
表情包发送历史记录模块
|
||||
"""
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from collections import deque
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("EmojiHistory")
|
||||
|
||||
MAX_HISTORY_SIZE = 5 # 每个聊天会话最多保留最近5条表情历史
|
||||
|
||||
# 使用一个全局字典在内存中存储历史记录
|
||||
# 键是 chat_id,值是一个 deque 对象
|
||||
_history_cache: Dict[str, deque] = {}
|
||||
|
||||
|
||||
def add_emoji_to_history(chat_id: str, emoji_description: str):
    """
    Record a sent emoji in the in-memory history for a chat.

    :param chat_id: chat session id (e.g. "private_12345" or "group_67890")
    :param emoji_description: description of the emoji that was sent
    """
    # Ignore calls with missing identifiers or empty descriptions.
    if not chat_id or not emoji_description:
        return

    # Lazily create a bounded deque per chat; old entries are evicted
    # automatically once MAX_HISTORY_SIZE is reached.
    bucket = _history_cache.setdefault(chat_id, deque(maxlen=MAX_HISTORY_SIZE))
    bucket.append(emoji_description)

    logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中")
|
||||
|
||||
|
||||
def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]:
    """
    Fetch the most recently sent emoji descriptions for a chat from memory.

    :param chat_id: chat session id
    :param limit: maximum number of descriptions to return
    :return: list of descriptions, most recent first
    """
    if not chat_id or chat_id not in _history_cache:
        return []

    history = _history_cache[chat_id]

    # Walk backwards from the newest entry; a non-positive limit yields [].
    count = min(limit, len(history))
    recent = [history[-(offset + 1)] for offset in range(count)]

    logger.debug(f"为聊天 {chat_id} 从内存中获取到最近 {len(recent)} 个表情: {recent}")
    return recent
|
||||
@@ -439,105 +439,103 @@ class EmojiManager:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
|
||||
async def get_emoji_for_text(self, text_emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据文本内容获取相关表情包
|
||||
"""
|
||||
根据文本内容,使用LLM选择一个合适的表情包。
|
||||
|
||||
Args:
|
||||
text_emotion: 输入的情感描述文本
|
||||
text_emotion (str): LLM希望表达的情感或意图的文本描述。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str]]: (表情包完整文件路径, 表情包描述),如果没有找到则返回None
|
||||
Optional[Tuple[str, str, str]]: 返回一个元组,包含所选表情包的 (文件路径, 描述, 匹配的情感描述),
|
||||
如果未找到合适的表情包,则返回 None。
|
||||
"""
|
||||
try:
|
||||
_time_start = time.time()
|
||||
|
||||
# 获取所有表情包 (从内存缓存中获取)
|
||||
all_emojis = self.emoji_objects
|
||||
|
||||
# 1. 从内存中获取所有可用的表情包对象
|
||||
all_emojis = [emoji for emoji in self.emoji_objects if not emoji.is_deleted and emoji.description]
|
||||
if not all_emojis:
|
||||
logger.warning("内存中没有任何表情包对象")
|
||||
logger.warning("内存中没有任何可用的表情包对象")
|
||||
return None
|
||||
|
||||
# 计算每个表情包与输入文本的最大情感相似度
|
||||
emoji_similarities = []
|
||||
for emoji in all_emojis:
|
||||
# 跳过已标记为删除的对象
|
||||
if emoji.is_deleted:
|
||||
continue
|
||||
# 2. 根据全局配置决定候选表情包的数量
|
||||
max_candidates = global_config.emoji.max_emoji_for_llm_select
|
||||
|
||||
emotions = emoji.emotion
|
||||
if not emotions:
|
||||
continue
|
||||
# 如果配置为0或者大于等于总数,则选择所有表情包
|
||||
if max_candidates <= 0 or max_candidates >= len(all_emojis):
|
||||
candidate_emojis = all_emojis
|
||||
else:
|
||||
# 否则,从所有表情包中随机抽取指定数量
|
||||
candidate_emojis = random.sample(all_emojis, max_candidates)
|
||||
|
||||
# 计算与每个emotion标签的相似度,取最大值
|
||||
max_similarity = 0
|
||||
best_matching_emotion = ""
|
||||
for emotion in emotions:
|
||||
# 使用编辑距离计算相似度
|
||||
distance = self._levenshtein_distance(text_emotion, emotion)
|
||||
max_len = max(len(text_emotion), len(emotion))
|
||||
similarity = 1 - (distance / max_len if max_len > 0 else 0)
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_matching_emotion = emotion
|
||||
|
||||
if best_matching_emotion:
|
||||
emoji_similarities.append((emoji, max_similarity, best_matching_emotion))
|
||||
|
||||
# 按相似度降序排序
|
||||
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 获取前10个最相似的表情包
|
||||
top_emojis = emoji_similarities[:10] if len(emoji_similarities) > 10 else emoji_similarities
|
||||
|
||||
if not top_emojis:
|
||||
logger.warning("未找到匹配的表情包")
|
||||
# 确保候选列表不为空
|
||||
if not candidate_emojis:
|
||||
logger.warning("未能选出任何候选表情包")
|
||||
return None
|
||||
|
||||
# 从前几个中随机选择一个
|
||||
selected_emoji, similarity, matched_emotion = random.choice(top_emojis)
|
||||
# 3. 构建用于LLM决策的prompt
|
||||
emoji_options_str = ""
|
||||
for i, emoji in enumerate(candidate_emojis):
|
||||
# 为每个表情包创建一个编号和它的详细描述
|
||||
emoji_options_str += f"编号: {i+1}\n描述: {emoji.description}\n\n"
|
||||
|
||||
# 更新使用次数
|
||||
# 精心设计的prompt,引导LLM做出选择
|
||||
prompt = f"""
|
||||
你是一个聊天机器人,你需要根据你想要表达的情感,从一个表情包列表中选择最合适的一个。
|
||||
|
||||
# 你的任务
|
||||
根据下面提供的“你想表达的描述”,在“表情包选项”中选择一个最符合该描述的表情包。
|
||||
|
||||
# 你想表达的描述
|
||||
{text_emotion}
|
||||
|
||||
# 表情包选项
|
||||
{emoji_options_str}
|
||||
|
||||
# 规则
|
||||
1. 仔细阅读“你想表达的描述”和每一个“表情包选项”的详细描述。
|
||||
2. 选择一个编号,该编号对应的表情包必须最贴切地反映出你想表达的情感、内容或网络文化梗。
|
||||
3. 你的回答必须且只能是一个格式为 "选择编号:X" 的字符串,其中X是你选择的表情包编号。
|
||||
4. 不要输出任何其他解释或无关内容。
|
||||
|
||||
现在,请做出你的选择:
|
||||
"""
|
||||
|
||||
# 4. 调用LLM进行决策
|
||||
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.5, max_tokens=20)
|
||||
logger.info(f"LLM选择的描述: {text_emotion}")
|
||||
logger.info(f"LLM决策结果: {decision}")
|
||||
|
||||
# 5. 解析LLM的决策结果
|
||||
match = re.search(r"(\d+)", decision)
|
||||
if not match:
|
||||
logger.error(f"无法从LLM的决策中解析出编号: {decision}")
|
||||
return None
|
||||
|
||||
selected_index = int(match.group(1)) - 1
|
||||
|
||||
# 6. 验证选择的编号是否有效
|
||||
if not (0 <= selected_index < len(candidate_emojis)):
|
||||
logger.error(f"LLM返回了无效的表情包编号: {selected_index + 1}")
|
||||
return None
|
||||
|
||||
# 7. 获取选中的表情包并更新使用记录
|
||||
selected_emoji = candidate_emojis[selected_index]
|
||||
self.record_usage(selected_emoji.hash)
|
||||
|
||||
_time_end = time.time()
|
||||
|
||||
logger.info(
|
||||
f"为[{text_emotion}]找到表情包: {matched_emotion} ({selected_emoji.filename}), Similarity: {similarity:.4f}"
|
||||
f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s"
|
||||
)
|
||||
# 返回完整文件路径和描述
|
||||
return selected_emoji.full_path, f"[ {selected_emoji.description} ]", matched_emotion
|
||||
|
||||
# 8. 返回选中的表情包信息
|
||||
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 获取表情包失败: {str(e)}")
|
||||
logger.error(f"使用LLM获取表情包时发生错误: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def _levenshtein_distance(self, s1: str, s2: str) -> int:
|
||||
# sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison
|
||||
"""计算两个字符串的编辑距离
|
||||
|
||||
Args:
|
||||
s1: 第一个字符串
|
||||
s2: 第二个字符串
|
||||
|
||||
Returns:
|
||||
int: 编辑距离
|
||||
"""
|
||||
if len(s1) < len(s2):
|
||||
return self._levenshtein_distance(s2, s1)
|
||||
|
||||
if len(s2) == 0:
|
||||
return len(s1)
|
||||
|
||||
previous_row = range(len(s2) + 1)
|
||||
for i, c1 in enumerate(s1):
|
||||
current_row = [i + 1]
|
||||
for j, c2 in enumerate(s2):
|
||||
insertions = previous_row[j + 1] + 1
|
||||
deletions = current_row[j] + 1
|
||||
substitutions = previous_row[j] + (c1 != c2)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
async def check_emoji_file_integrity(self) -> None:
|
||||
"""检查表情包文件完整性
|
||||
遍历self.emoji_objects中的所有对象,检查文件是否存在
|
||||
@@ -627,11 +625,10 @@ class EmojiManager:
|
||||
await asyncio.sleep(global_config.emoji.check_interval * 60)
|
||||
continue
|
||||
|
||||
# 检查是否需要处理表情包(数量超过最大值或不足)
|
||||
if global_config.emoji.steal_emoji and (
|
||||
(self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace)
|
||||
or (self.emoji_num < self.emoji_num_max)
|
||||
):
|
||||
# 无论steal_emoji是否开启,都检查emoji文件夹以支持手动注册
|
||||
# 只有在需要腾出空间或填充表情库时,才真正执行注册
|
||||
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \
|
||||
(self.emoji_num < self.emoji_num_max):
|
||||
try:
|
||||
# 获取目录下所有图片文件
|
||||
files_to_process = [
|
||||
@@ -646,7 +643,7 @@ class EmojiManager:
|
||||
# 尝试注册表情包
|
||||
success = await self.register_emoji_by_filename(filename)
|
||||
if success:
|
||||
# 注册成功则跳出循环
|
||||
# 注册成功则跳出循环,等待下一个检查周期
|
||||
break
|
||||
|
||||
# 注册失败则删除对应文件
|
||||
@@ -914,110 +911,114 @@ class EmojiManager:
|
||||
return False
|
||||
|
||||
async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
|
||||
"""获取表情包描述和情感列表,优化复用已有描述
|
||||
"""
|
||||
获取表情包的详细描述和情感关键词列表。
|
||||
|
||||
该函数首先使用VLM(视觉语言模型)对图片进行深入分析,生成一份包含文化、Meme内涵的详细描述。
|
||||
然后,它会调用另一个LLM,基于这份详细描述,提炼出几个核心的、简洁的情感关键词。
|
||||
最终返回详细描述和关键词列表,为后续的表情包选择提供丰富且精准的信息。
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
image_base64 (str): 图片的Base64编码字符串。
|
||||
|
||||
Returns:
|
||||
Tuple[str, list]: 返回表情包描述和情感列表
|
||||
Tuple[str, List[str]]: 返回一个元组,第一个元素是详细描述,第二个元素是情感关键词列表。
|
||||
如果处理失败,则返回空的描述和列表。
|
||||
"""
|
||||
try:
|
||||
# 解码图片并获取格式
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
# 1. 解码图片,计算哈希值,并获取格式
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg"
|
||||
|
||||
# 尝试从Images表获取已有的详细描述(可能在收到表情包时已生成)
|
||||
|
||||
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
|
||||
existing_description = None
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# from src.common.database.database_model_compat import Images
|
||||
|
||||
existing_image = (
|
||||
session.query(Images)
|
||||
.filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
.one_or_none()
|
||||
)
|
||||
existing_image = session.query(Images).filter(
|
||||
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
|
||||
).one_or_none()
|
||||
if existing_image and existing_image.description:
|
||||
existing_description = existing_image.description
|
||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||
except Exception as e:
|
||||
logger.debug(f"查询已有描述时出错: {e}")
|
||||
logger.debug(f"查询已有表情包描述时出错: {e}")
|
||||
|
||||
# 第一步:VLM视觉分析(如果没有已有描述才调用)
|
||||
# 3. 如果没有现有描述,则调用VLM生成新的详细描述
|
||||
if existing_description:
|
||||
description = existing_description
|
||||
logger.info("[优化] 复用已有的详细描述,跳过VLM调用")
|
||||
else:
|
||||
logger.info("[VLM分析] 生成新的详细描述")
|
||||
logger.info("[VLM分析] 开始为新表情包生成详细描述")
|
||||
# 为动态图(GIF)和静态图构建不同的、要求简洁的prompt
|
||||
if image_format in ["gif", "GIF"]:
|
||||
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||
if not image_base64:
|
||||
image_base64_frames = get_image_manager().transform_gif(image_base64)
|
||||
if not image_base64_frames:
|
||||
raise RuntimeError("GIF表情包转换失败")
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
prompt = "这是一个GIF动图表情包的关键帧。请用不超过250字,详细描述它的核心内容:1. 动态画面展现了什么变化?2. 它传达了什么核心情绪或玩的是什么梗?3. 通常在什么场景下使用?请确保描述既包含关键信息,又能充分展现其内涵。"
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, "jpeg", temperature=0.3, max_tokens=1000
|
||||
prompt, image_base64_frames, "jpeg", temperature=0.3, max_tokens=600
|
||||
)
|
||||
else:
|
||||
prompt = (
|
||||
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
)
|
||||
prompt = "这是一个表情包。请用不超过250字,详细描述它的核心内容:1. 画面描绘了什么?2. 它传达了什么核心情绪或玩的是什么梗?3. 通常在什么场景下使用?请确保描述既包含关键信息,又能充分展现其内涵。"
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
|
||||
prompt, image_base64, image_format, temperature=0.3, max_tokens=600
|
||||
)
|
||||
|
||||
# 审核表情包
|
||||
# 4. 内容审核,确保表情包符合规定
|
||||
if global_config.emoji.content_filtration:
|
||||
prompt = f'''
|
||||
这是一个表情包,请对这个表情包进行审核,标准如下:
|
||||
1. 必须符合"{global_config.emoji.filtration_prompt}"的要求
|
||||
2. 不能是色情、暴力、等违法违规内容,必须符合公序良俗
|
||||
3. 不能是任何形式的截图,聊天记录或视频截图
|
||||
4. 不要出现5个以上文字
|
||||
请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
|
||||
请根据以下标准审核这个表情包:
|
||||
1. 主题必须符合:"{global_config.emoji.filtration_prompt}"。
|
||||
2. 内容健康,不含色情、暴力、政治敏感等元素。
|
||||
3. 必须是表情包,而不是普通的聊天截图或视频截图。
|
||||
4. 表情包中的文字数量(如果有)不能超过5个。
|
||||
这个表情包是否完全满足以上所有要求?请只回答“是”或“否”。
|
||||
'''
|
||||
content, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
|
||||
prompt, image_base64, image_format, temperature=0.1, max_tokens=10
|
||||
)
|
||||
if content == "否":
|
||||
if "否" in content:
|
||||
logger.warning(f"表情包审核未通过,内容: {description[:50]}...")
|
||||
return "", []
|
||||
|
||||
# 第二步:LLM情感分析 - 基于详细描述生成情感标签列表(可选)
|
||||
# 5. 基于VLM的详细描述,调用LLM提炼情感关键词
|
||||
emotions = []
|
||||
if global_config.emoji.enable_emotion_analysis:
|
||||
logger.info("[情感分析] 启用表情包感情关键词二次识别")
|
||||
logger.info("[情感分析] 开始提炼表情包的情感关键词")
|
||||
emotion_prompt = f"""
|
||||
请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字
|
||||
这是一个基于这个表情包的描述:'{description}'
|
||||
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
|
||||
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
|
||||
你是一个互联网“梗”学家和情感分析师。
|
||||
这里有一份关于某个表情包的详细描述:
|
||||
---
|
||||
{description}
|
||||
---
|
||||
请你基于这份描述,提炼出这个表情包最核心的含义和适用场景。
|
||||
|
||||
你的任务是:
|
||||
1. 分析并总结出3到5个最能代表这个表情包的关键词或短语。
|
||||
2. 这些关键词应该非常凝练,比如“表达无语”、“有点小得意”、“求夸奖”、“猫猫疑惑”等。
|
||||
3. 每个关键词不要超过15个字。
|
||||
4. 请直接输出这些关键词,并用逗号分隔,不要添加任何其他解释。
|
||||
"""
|
||||
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
|
||||
emotion_prompt, temperature=0.7, max_tokens=600
|
||||
emotion_prompt, temperature=0.6, max_tokens=150
|
||||
)
|
||||
|
||||
# 处理情感列表
|
||||
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
|
||||
|
||||
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||
if len(emotions) > 5:
|
||||
emotions = random.sample(emotions, 3)
|
||||
elif len(emotions) > 2:
|
||||
emotions = random.sample(emotions, 2)
|
||||
else:
|
||||
logger.info("[情感分析] 表情包感情关键词二次识别已禁用")
|
||||
emotions = []
|
||||
logger.info("[情感分析] 表情包感情关键词二次识别已禁用,跳过此步骤")
|
||||
|
||||
logger.info(f"[注册分析] 详细描述: {description[:50]}... -> 情感标签: {emotions}")
|
||||
# 6. 格式化最终的描述,并返回结果
|
||||
final_description = f"表情包,关键词:[{','.join(emotions)}]。详细描述:{description}"
|
||||
logger.info(f"[注册分析] VLM描述: {description} -> 提炼出的情感标签: {emotions}")
|
||||
|
||||
return f"[表情包:{description}]", emotions
|
||||
return final_description, emotions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败: {str(e)}")
|
||||
logger.error(f"构建表情包描述时发生严重错误: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "", []
|
||||
|
||||
async def register_emoji_by_filename(self, filename: str) -> bool:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
from typing import Tuple, List, Dict, Optional, Any
|
||||
|
||||
from .global_logger import logger
|
||||
from .embedding_store import EmbeddingManager
|
||||
@@ -98,30 +98,46 @@ class QAManager:
|
||||
|
||||
return result, ppr_node_weights
|
||||
|
||||
async def get_knowledge(self, question: str) -> Optional[str]:
|
||||
"""获取知识"""
|
||||
# 处理查询
|
||||
processed_result = await self.process_query(question)
|
||||
if processed_result is not None:
|
||||
query_res = processed_result[0]
|
||||
# 检查查询结果是否为空
|
||||
if not query_res:
|
||||
logger.debug("知识库查询结果为空,可能是知识库中没有相关内容")
|
||||
return None
|
||||
async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取知识,返回结构化字典
|
||||
|
||||
Args:
|
||||
question: 用户提出的问题
|
||||
|
||||
knowledge = [
|
||||
(
|
||||
self.embed_manager.paragraphs_embedding_store.store[res[0]].str,
|
||||
res[1],
|
||||
)
|
||||
for res in query_res
|
||||
]
|
||||
found_knowledge = "\n".join(
|
||||
[f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(knowledge)]
|
||||
)
|
||||
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
|
||||
found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n"
|
||||
return found_knowledge
|
||||
else:
|
||||
logger.debug("LPMM知识库并未初始化,可能是从未导入过知识...")
|
||||
Returns:
|
||||
一个包含 'knowledge_items' 和 'summary' 的字典,或者在没有结果时返回 None
|
||||
"""
|
||||
processed_result = await self.process_query(question)
|
||||
if not processed_result or not processed_result[0]:
|
||||
logger.debug("知识库查询结果为空。")
|
||||
return None
|
||||
|
||||
query_res = processed_result[0]
|
||||
|
||||
knowledge_items = []
|
||||
for res_hash, relevance, *_ in query_res:
|
||||
if store_item := self.embed_manager.paragraphs_embedding_store.store.get(res_hash):
|
||||
knowledge_items.append({
|
||||
"content": store_item.str,
|
||||
"source": "内部知识库",
|
||||
"relevance": f"{relevance:.4f}"
|
||||
})
|
||||
|
||||
if not knowledge_items:
|
||||
return None
|
||||
|
||||
# 使用LLM生成总结
|
||||
knowledge_text_for_summary = "\n\n".join([item['content'] for item in knowledge_items[:5]]) # 最多总结前5条
|
||||
summary_prompt = f"根据以下信息,为问题 '{question}' 生成一个简洁的、不超过50字的摘要:\n\n{knowledge_text_for_summary}"
|
||||
|
||||
try:
|
||||
summary, (_, _, _) = await self.qa_model.generate_response_async(summary_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"生成知识摘要失败: {e}")
|
||||
summary = "无法生成摘要。"
|
||||
|
||||
return {
|
||||
"knowledge_items": knowledge_items,
|
||||
"summary": summary.strip() if summary else "没有可用的摘要。"
|
||||
}
|
||||
|
||||
@@ -263,7 +263,15 @@ class PlanFilter:
|
||||
target_message_dict = self._get_latest_message(message_id_list)
|
||||
|
||||
if target_message_dict:
|
||||
# 直接使用字典作为action_message,避免DatabaseMessages对象创建失败
|
||||
target_message_obj = target_message_dict
|
||||
else:
|
||||
# 如果找不到目标消息,对于reply动作来说这是必需的,应该记录警告
|
||||
if action == "reply":
|
||||
logger.warning(f"reply动作找不到目标消息,target_message_id: {action_json.get('target_message_id')}")
|
||||
# 将reply动作改为no_action,避免后续执行时出错
|
||||
action = "no_action"
|
||||
reasoning = f"找不到目标消息进行回复。原始理由: {reasoning}"
|
||||
|
||||
available_action_names = list(plan.available_actions.keys())
|
||||
if action not in ["no_action", "no_reply", "reply", "do_nothing", "proactive_reply"] and action not in available_action_names:
|
||||
|
||||
@@ -82,6 +82,15 @@ class ActionPlanner:
|
||||
)
|
||||
|
||||
final_actions_dict = [asdict(act) for act in final_actions]
|
||||
final_target_message_dict = asdict(final_target_message) if final_target_message else None
|
||||
# action_message现在可能是字典而不是dataclass实例,需要特殊处理
|
||||
if final_target_message:
|
||||
if hasattr(final_target_message, '__dataclass_fields__'):
|
||||
# 如果是dataclass实例,使用asdict转换
|
||||
final_target_message_dict = asdict(final_target_message)
|
||||
else:
|
||||
# 如果已经是字典,直接使用
|
||||
final_target_message_dict = final_target_message
|
||||
else:
|
||||
final_target_message_dict = None
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
|
||||
@@ -18,7 +18,6 @@ def init_prompts():
|
||||
# 并要求模型以 JSON 格式输出一个或多个动作组合。
|
||||
Prompt(
|
||||
"""
|
||||
{schedule_block}
|
||||
{mood_block}
|
||||
{time_block}
|
||||
{identity_block}
|
||||
@@ -36,10 +35,25 @@ def init_prompts():
|
||||
2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。
|
||||
|
||||
**决策流程:**
|
||||
1. 首先,决定是否要进行 `reply`。
|
||||
2. 然后,评估当前的对话气氛和用户情绪,判断是否需要一个**辅助动作**来让你的回应更生动、更符合你的性格。
|
||||
3. 如果需要,选择一个最合适的辅助动作与 `reply` 组合。
|
||||
4. 如果用户明确要求了某个动作,请务必优先满足。
|
||||
1. **最高优先级检查**: 首先,检查是否有由 **关键词** 或 **LLM判断** 激活的特定动作(除了通用的 `reply`, `emoji` 等)。这些动作代表了用户的明确意图。
|
||||
2. **执行明确意图**: 如果存在这类特定动作,你 **必须** 优先选择它作为主要响应。这比常规的文本回复 (`reply`) 更重要。
|
||||
3. **常规回复**: 如果没有被特定意图激活的动作,再决定是否要进行 `reply`。
|
||||
4. **辅助动作**: 在确定了主要动作后(无论是特定动作还是 `reply`),再评估是否需要 `emoji` 或 `poke_user` 等辅助动作来增强表达效果。
|
||||
5. **互斥原则**: 当你选择了一个由明确意图激活的特定动作(如 `set_reminder`)时,你 **绝不能** 再选择 `reply` 动作,因为特定动作的执行结果(例如,设置提醒后的确认消息)本身就是一种回复。这是必须遵守的规则。
|
||||
|
||||
**重要概念:将“理由”作为“内心思考”的体现**
|
||||
`reason` 字段是本次决策的核心。它并非一个简单的“理由”,而是 **一个模拟人类在回应前,头脑中自然浮现的、未经修饰的思绪流**。你需要完全代入 {identity_block} 的角色,将那一刻的想法自然地记录下来。
|
||||
|
||||
**内心思考的要点:**
|
||||
* **自然流露**: 不要使用“决定”、“所以”、“因此”等结论性或汇报式的词语。你的思考应该像日记一样,是给自己看的,充满了不确定性和情绪的自然流动。
|
||||
* **展现过程**: 重点在于展现 **思考的过程**,而不是 **决策的结果**。描述你看到了什么,想到了什么,感受到了什么。
|
||||
* **人设核心**: 你的每一丝想法,都应该源于你的人设。思考“如果我是这个角色,我此刻会想些什么?”
|
||||
* **通用模板**: 这是一套通用模板,请 **不要** 在示例中出现特定的人名或个性化内容,以确保其普适性。
|
||||
|
||||
**思考过程示例 (通用模板):**
|
||||
* "用户好像在说一件开心的事,语气听起来很兴奋。这让我想起了……嗯,我也觉得很开心,很想分享这份喜悦。"
|
||||
* "感觉气氛有点低落……他说的话让我有点担心。也许我该说点什么安慰一下?"
|
||||
* "哦?这个话题真有意思,我以前好像也想过类似的事情。不知道他会怎么看呢……"
|
||||
|
||||
**可用动作:**
|
||||
{actions_before_now_block}
|
||||
@@ -55,7 +69,7 @@ def init_prompts():
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id": "触发action的消息id",
|
||||
"reason": "回复的原因"
|
||||
"reason": "在这里详细记录你的内心思考过程。例如:‘用户看起来很开心,我想回复一些积极的内容,分享这份喜悦。’"
|
||||
}}
|
||||
|
||||
{action_options_text}
|
||||
@@ -69,7 +83,7 @@ def init_prompts():
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id": "m123",
|
||||
"reason": "回答用户的问题"
|
||||
"reason": "感觉气氛有点低落……他说的话让我有点担心。也许我该说点什么安慰一下?"
|
||||
}}
|
||||
]
|
||||
|
||||
@@ -78,15 +92,31 @@ def init_prompts():
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id": "m123",
|
||||
"reason": "回答用户的问题"
|
||||
"reason": "[观察与感受] 用户分享了一件开心的事,语气里充满了喜悦! [分析与联想] 看到他这么开心,我的心情也一下子变得像棉花糖一样甜~ [动机与决策] 我要由衷地为他感到高兴,决定回复一些赞美和祝福的话,把这份快乐的气氛推向高潮!"
|
||||
}},
|
||||
{{
|
||||
"action": "emoji",
|
||||
"target_message_id": "m123",
|
||||
"reason": "用一个可爱的表情来缓和气氛"
|
||||
"reason": "光用文字还不够表达我激动的心情!加个表情包的话,这份喜悦的气氛应该会更浓厚一点吧!"
|
||||
}}
|
||||
]
|
||||
|
||||
**单动作示例 (特定动作):**
|
||||
[
|
||||
{{
|
||||
"action": "set_reminder",
|
||||
"target_message_id": "m456",
|
||||
"reason": "用户说‘提醒维尔薇下午三点去工坊’,这是一个非常明确的指令。根据决策流程,我必须优先执行这个特定动作,而不是进行常规回复。",
|
||||
"user_name": "维尔薇",
|
||||
"remind_time": "下午三点",
|
||||
"event_details": "去工坊"
|
||||
}}
|
||||
]
|
||||
|
||||
**重要规则:**
|
||||
**重要规则:**
|
||||
当 `reply` 和 `emoji` 动作同时被选择时,`emoji` 动作的 `reason` 字段也应该体现出你的思考过程,并与 `reply` 的思考保持连贯。
|
||||
|
||||
不要输出markdown格式```json等内容,直接输出且仅包含 JSON 列表内容:
|
||||
""",
|
||||
"planner_prompt",
|
||||
@@ -101,7 +131,6 @@ def init_prompts():
|
||||
## 你的内部状态
|
||||
{time_block}
|
||||
{identity_block}
|
||||
{schedule_block}
|
||||
{mood_block}
|
||||
|
||||
## 长期记忆摘要
|
||||
@@ -115,6 +144,7 @@ def init_prompts():
|
||||
|
||||
## 任务
|
||||
你现在要决定是否主动说些什么。就像一个真实的人一样,有时候会突然想起之前聊到的话题,或者对朋友的近况感到好奇,想主动询问或关心一下。
|
||||
**重要提示**:你的日程安排仅供你个人参考,不应作为主动聊天话题的主要来源。请更多地从聊天内容和朋友的动态中寻找灵感。
|
||||
|
||||
请基于聊天内容,用你的判断力来决定是否要主动发言。不要按照固定规则,而是像人类一样自然地思考:
|
||||
- 是否想起了什么之前提到的事情,想问问后来怎么样了?
|
||||
|
||||
@@ -594,6 +594,9 @@ class DefaultReplyer:
|
||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||
"""解析回复目标消息 - 使用共享工具"""
|
||||
from src.chat.utils.prompt import Prompt
|
||||
if target_message is None:
|
||||
logger.warning("target_message为None,返回默认值")
|
||||
return "未知用户", "(无消息内容)"
|
||||
return Prompt.parse_reply_target(target_message)
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
@@ -827,6 +830,13 @@ class DefaultReplyer:
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
# 如果person_name为None,使用fallback值
|
||||
if person_name is None:
|
||||
# 尝试从reply_message获取用户名
|
||||
fallback_name = reply_message.get("user_nickname") or reply_message.get("user_id", "未知用户")
|
||||
logger.warning(f"无法获取person_name,使用fallback: {fallback_name}")
|
||||
person_name = str(fallback_name)
|
||||
|
||||
# 检查是否是bot自己的名字,如果是则替换为"(你)"
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
@@ -839,6 +849,14 @@ class DefaultReplyer:
|
||||
sender = person_name
|
||||
target = reply_message.get("processed_plain_text")
|
||||
|
||||
# 最终的空值检查,确保sender和target不为None
|
||||
if sender is None:
|
||||
logger.warning("sender为None,使用默认值'未知用户'")
|
||||
sender = "未知用户"
|
||||
if target is None:
|
||||
logger.warning("target为None,使用默认值'(无消息内容)'")
|
||||
target = "(无消息内容)"
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
platform = chat_stream.platform
|
||||
@@ -1049,16 +1067,18 @@ class DefaultReplyer:
|
||||
# --- 动态添加分割指令 ---
|
||||
if global_config.response_splitter.enable and global_config.response_splitter.split_mode == "llm":
|
||||
split_instruction = """
|
||||
## 消息分段艺术
|
||||
为了模仿真实人类的聊天节奏,你可以在需要时将一条回复分成几段发送。
|
||||
## 消息分段指导
|
||||
为了模仿人类自然的聊天节奏,你需要将回复模拟成多段发送,就像在打字时进行思考和停顿一样。
|
||||
|
||||
**核心原则**: 只有当分段能**增强表达效果**或**控制信息节奏**时,才在断句处使用 `[SPLIT]` 标记。
|
||||
**核心指导**:
|
||||
- **逻辑断点**: 在一个想法说完,准备开始下一个想法时,是分段的好时机。
|
||||
- **情绪转折**: 当情绪发生变化,比如从开心到担忧时,可以通过分段来体现。
|
||||
- **强调信息**: 在需要强调某段关键信息前后,可以使用分段来突出它。
|
||||
- **控制节奏**: 保持分段的平衡,避免过长或过碎。如果一句话很短或逻辑紧密,则不应分段。
|
||||
- **长度倾向**: 尽量将每段回复的长度控制在20-30字左右。但这只是一个参考,**内容的完整性和自然性永远是第一位的**,只有在不影响表达的前提下才考虑长度。
|
||||
|
||||
**参考场景**:
|
||||
- 当你想表达一个转折或停顿时。
|
||||
- 当你想先说结论,再补充说明时。
|
||||
|
||||
**任务**: 请结合你的智慧和人设,自然地决定是否需要分段。如果需要,请在最恰当的位置插入 `[SPLIT]` 标记。
|
||||
**任务**:
|
||||
请基于以上指导,并结合你的智慧和人设,像一个真人在聊天一样,自然地决定在哪里插入 `[SPLIT]` 标记以进行分段。
|
||||
"""
|
||||
# 将分段指令添加到提示词顶部
|
||||
prompt_text = f"{split_instruction}\n{prompt_text}"
|
||||
@@ -1082,6 +1102,14 @@ class DefaultReplyer:
|
||||
else:
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
|
||||
# 添加空值检查,确保sender和target不为None
|
||||
if sender is None:
|
||||
logger.warning("build_rewrite_context: sender为None,使用默认值'未知用户'")
|
||||
sender = "未知用户"
|
||||
if target is None:
|
||||
logger.warning("build_rewrite_context: target为None,使用默认值'(无消息内容)'")
|
||||
target = "(无消息内容)"
|
||||
|
||||
# 添加情绪状态获取
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
|
||||
@@ -369,7 +369,7 @@ class Prompt:
|
||||
task_names.append("cross_context")
|
||||
|
||||
# 性能优化
|
||||
base_timeout = 10.0
|
||||
base_timeout = 20.0
|
||||
task_timeout = 2.0
|
||||
timeout_seconds = min(
|
||||
max(base_timeout, len(tasks) * task_timeout),
|
||||
@@ -676,22 +676,21 @@ class Prompt:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
try:
|
||||
from src.chat.knowledge.knowledge_lib import QAManager
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
|
||||
# 获取问题文本(当前消息)
|
||||
question = self.parameters.target or ""
|
||||
if not question:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
# 创建QA管理器
|
||||
qa_manager = QAManager()
|
||||
# 检查QA管理器是否已成功初始化
|
||||
if not qa_manager:
|
||||
logger.warning("QA管理器未初始化 (可能lpmm_knowledge被禁用),跳过知识库搜索。")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
# 搜索相关知识
|
||||
knowledge_results = await qa_manager.get_knowledge(
|
||||
question=question,
|
||||
chat_id=self.parameters.chat_id,
|
||||
max_results=5,
|
||||
min_similarity=0.5
|
||||
question=question
|
||||
)
|
||||
|
||||
# 构建知识块
|
||||
@@ -704,13 +703,10 @@ class Prompt:
|
||||
relevance = item.get("relevance", 0.0)
|
||||
|
||||
if content:
|
||||
if source:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})")
|
||||
else:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content}")
|
||||
knowledge_parts.append(f"- [相关度: {relevance}] {content}")
|
||||
|
||||
if knowledge_results.get("summary"):
|
||||
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
|
||||
if summary := knowledge_results.get("summary"):
|
||||
knowledge_parts.append(f"\n知识总结: {summary}")
|
||||
|
||||
knowledge_prompt = "\n".join(knowledge_parts)
|
||||
else:
|
||||
@@ -757,7 +753,7 @@ class Prompt:
|
||||
"cross_context_block": context_data.get("cross_context_block", ""),
|
||||
"identity": self.parameters.identity_block or context_data.get("identity", ""),
|
||||
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
|
||||
"sender_name": self.parameters.sender,
|
||||
"sender_name": self.parameters.sender or "未知用户",
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
|
||||
"time_block": context_data.get("time_block", ""),
|
||||
|
||||
@@ -15,41 +15,67 @@ logger = get_logger("maibot_statistic")
|
||||
|
||||
|
||||
# 同步包装器函数,用于在非异步环境中调用异步数据库API
|
||||
# 全局存储主事件循环引用
|
||||
_main_event_loop = None
|
||||
|
||||
def _get_main_loop():
|
||||
"""获取主事件循环的引用"""
|
||||
global _main_event_loop
|
||||
if _main_event_loop is None:
|
||||
try:
|
||||
_main_event_loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# 如果没有运行的循环,尝试获取默认循环
|
||||
try:
|
||||
_main_event_loop = asyncio.get_event_loop_policy().get_event_loop()
|
||||
except Exception:
|
||||
pass
|
||||
return _main_event_loop
|
||||
|
||||
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
|
||||
"""同步版本的db_get,用于在线程池中调用"""
|
||||
import asyncio
|
||||
|
||||
import threading
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# 如果事件循环正在运行,创建新的事件循环
|
||||
import threading
|
||||
|
||||
result = None
|
||||
exception = None
|
||||
|
||||
def run_in_thread():
|
||||
nonlocal result, exception
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
result = new_loop.run_until_complete(db_get(model_class, filters, limit, order_by, single_result))
|
||||
new_loop.close()
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
thread = threading.Thread(target=run_in_thread)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
if exception:
|
||||
raise exception
|
||||
return result
|
||||
else:
|
||||
return loop.run_until_complete(db_get(model_class, filters, limit, order_by, single_result))
|
||||
except RuntimeError:
|
||||
# 没有事件循环,创建一个新的
|
||||
# 优先尝试获取预存的主事件循环
|
||||
main_loop = _get_main_loop()
|
||||
|
||||
# 如果在子线程中且有主循环可用
|
||||
if threading.current_thread() is not threading.main_thread() and main_loop:
|
||||
try:
|
||||
if not main_loop.is_closed():
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
db_get(model_class, filters, limit, order_by, single_result), main_loop
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as e:
|
||||
# 如果使用主循环失败,才在子线程创建新循环
|
||||
logger.debug(f"使用主事件循环失败({e}),在子线程中创建新循环")
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
# 如果在主线程中,直接运行
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
try:
|
||||
# 检查是否有当前运行的循环
|
||||
current_loop = asyncio.get_running_loop()
|
||||
if current_loop.is_running():
|
||||
# 主循环正在运行,返回空结果避免阻塞
|
||||
logger.debug("在运行中的主事件循环中跳过同步数据库查询")
|
||||
return []
|
||||
except RuntimeError:
|
||||
# 没有运行的循环,可以安全创建
|
||||
pass
|
||||
|
||||
# 创建新循环运行查询
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
# 最后的兜底方案:在子线程创建新循环
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"_sync_db_get 执行过程中发生错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# 统计数据的键
|
||||
|
||||
@@ -175,7 +175,7 @@ class ImageManager:
|
||||
|
||||
# 查询ImageDescriptions表的缓存描述
|
||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
# === 二步走识别流程 ===
|
||||
@@ -236,54 +236,56 @@ class ImageManager:
|
||||
if len(emotions) > 1 and emotions[1] != emotions[0]:
|
||||
final_emotion = f"{emotions[0]},{emotions[1]}"
|
||||
|
||||
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||
logger.info(f"[emoji识别] 详细描述: {detailed_description}... -> 情感标签: {final_emotion}")
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
# 保存表情包文件和元数据(用于可能的后续分析)
|
||||
logger.debug(f"保存表情包: {image_hash}")
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
|
||||
os.makedirs(emoji_dir, exist_ok=True)
|
||||
file_path = os.path.join(emoji_dir, filename)
|
||||
# 只有在开启“偷表情包”功能时,才将接收到的表情包保存到待注册目录
|
||||
if global_config.emoji.steal_emoji:
|
||||
logger.debug(f"偷取表情包功能已开启,保存表情包: {image_hash}")
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
emoji_dir = os.path.join(self.IMAGE_DIR, "emoji")
|
||||
os.makedirs(emoji_dir, exist_ok=True)
|
||||
file_path = os.path.join(emoji_dir, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
with get_db_session() as session:
|
||||
existing_img = session.execute(
|
||||
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
|
||||
).scalar()
|
||||
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
with get_db_session() as session:
|
||||
existing_img = session.execute(
|
||||
select(Images).where(and_(Images.emoji_hash == image_hash, Images.type == "emoji"))
|
||||
).scalar()
|
||||
|
||||
if existing_img:
|
||||
existing_img.path = file_path
|
||||
existing_img.description = detailed_description # 保存详细描述
|
||||
existing_img.timestamp = current_timestamp
|
||||
else:
|
||||
new_img = Images(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # 保存详细描述
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"保存到Images表失败: {str(e)}")
|
||||
|
||||
if existing_img:
|
||||
existing_img.path = file_path
|
||||
existing_img.description = detailed_description # 保存详细描述
|
||||
existing_img.timestamp = current_timestamp
|
||||
else:
|
||||
new_img = Images(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # 保存详细描述
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
# 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存到Images表失败: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
else:
|
||||
logger.debug("偷取表情包功能已关闭,跳过保存。")
|
||||
|
||||
# 保存最终的情感标签到缓存 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, final_emotion, "emoji")
|
||||
@@ -315,11 +317,11 @@ class ImageManager:
|
||||
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
@@ -377,7 +379,7 @@ class ImageManager:
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
logger.info(f"[VLM完成] 图片描述生成: {description[:50]}...")
|
||||
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
||||
return f"[图片:{description}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {str(e)}")
|
||||
|
||||
@@ -766,10 +766,37 @@ class ModuleColoredConsoleRenderer:
|
||||
event_content = str(event)
|
||||
|
||||
# 在full模式下为消息内容着色
|
||||
if self._colors and self._enable_full_content_colors and module_color:
|
||||
event_content = f"{module_color}{event_content}{RESET_COLOR}"
|
||||
|
||||
parts.append(event_content)
|
||||
if self._colors and self._enable_full_content_colors:
|
||||
# 检查是否包含“内心思考:”
|
||||
if "内心思考:" in event_content:
|
||||
# 使用明亮的粉色
|
||||
thought_color = "\033[38;5;218m"
|
||||
# 分割消息内容
|
||||
prefix, thought = event_content.split("内心思考:", 1)
|
||||
|
||||
# 前缀部分(“决定进行回复,”)使用模块颜色
|
||||
if module_color:
|
||||
prefix_colored = f"{module_color}{prefix.strip()}{RESET_COLOR}"
|
||||
else:
|
||||
prefix_colored = prefix.strip()
|
||||
|
||||
# “内心思考”部分换行并使用专属颜色
|
||||
thought_colored = f"\n\n{thought_color}内心思考:{thought.strip()}{RESET_COLOR}\n"
|
||||
|
||||
# 重新组合
|
||||
# parts.append(prefix_colored + thought_colored)
|
||||
# 将前缀和思考内容作为独立的part添加,避免它们之间出现多余的空格
|
||||
if prefix_colored:
|
||||
parts.append(prefix_colored)
|
||||
parts.append(thought_colored)
|
||||
|
||||
elif module_color:
|
||||
event_content = f"{module_color}{event_content}{RESET_COLOR}"
|
||||
parts.append(event_content)
|
||||
else:
|
||||
parts.append(event_content)
|
||||
else:
|
||||
parts.append(event_content)
|
||||
|
||||
# 处理其他字段
|
||||
extras = []
|
||||
|
||||
@@ -362,7 +362,7 @@ class EmojiConfig(ValidatedConfigBase):
|
||||
emoji_activate_type: str = Field(default="random", description="表情包激活类型")
|
||||
max_reg_num: int = Field(default=200, description="最大表情包数量")
|
||||
do_replace: bool = Field(default=True, description="是否替换表情包")
|
||||
check_interval: int = Field(default=120, description="检查间隔")
|
||||
check_interval: float = Field(default=1.0, ge=0.01, description="检查间隔")
|
||||
steal_emoji: bool = Field(default=True, description="是否偷取表情包")
|
||||
content_filtration: bool = Field(default=False, description="内容过滤")
|
||||
filtration_prompt: str = Field(default="符合公序良俗", description="过滤提示")
|
||||
|
||||
@@ -12,6 +12,7 @@ from src.plugin_system.apis import llm_api, message_api
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager, MaiEmoji
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
from src.config.config import global_config
|
||||
from src.chat.emoji_system.emoji_history import get_recent_emojis, add_emoji_to_history
|
||||
|
||||
|
||||
logger = get_logger("emoji")
|
||||
@@ -74,9 +75,22 @@ class EmojiAction(BaseAction):
|
||||
logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包")
|
||||
return False, "无法获取任何带有描述的有效表情包"
|
||||
|
||||
# 3. 准备情感数据和后备列表
|
||||
# 3. 根据历史记录筛选表情
|
||||
try:
|
||||
recent_emojis_desc = get_recent_emojis(self.chat_id, limit=10)
|
||||
if recent_emojis_desc:
|
||||
filtered_emojis = [emoji for emoji in all_emojis_obj if emoji.description not in recent_emojis_desc]
|
||||
if filtered_emojis:
|
||||
all_emojis_obj = filtered_emojis
|
||||
logger.info(f"{self.log_prefix} 根据历史记录过滤后,剩余 {len(all_emojis_obj)} 个表情可用")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 过滤后没有可用的表情包,将使用所有表情包")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 获取或处理表情发送历史时出错: {e}")
|
||||
|
||||
# 4. 准备情感数据和后备列表
|
||||
emotion_map = {}
|
||||
all_emojis_data = []
|
||||
all_emojis_data = []
|
||||
|
||||
for emoji in all_emojis_obj:
|
||||
b64 = image_path_to_base64(emoji.full_path)
|
||||
@@ -236,13 +250,20 @@ class EmojiAction(BaseAction):
|
||||
logger.error(f"{self.log_prefix} 无效的表情选择模式: {global_config.emoji.emoji_selection_mode}")
|
||||
return False, "无效的表情选择模式"
|
||||
|
||||
# 7. 发送表情包
|
||||
# 7. 发送表情包并记录历史
|
||||
success = await self.send_emoji(emoji_base64)
|
||||
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix} 表情包发送失败")
|
||||
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False)
|
||||
return False, "表情包发送失败"
|
||||
|
||||
# 发送成功后,记录到历史
|
||||
try:
|
||||
add_emoji_to_history(self.chat_id, emoji_description)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")
|
||||
|
||||
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True)
|
||||
|
||||
return True, f"发送表情包: {emoji_description}"
|
||||
|
||||
@@ -43,10 +43,16 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
|
||||
logger.debug(f"知识库查询结果: {knowledge_info}")
|
||||
|
||||
if knowledge_info:
|
||||
content = f"你知道这些知识: {knowledge_info}"
|
||||
if knowledge_info and knowledge_info.get("knowledge_items"):
|
||||
knowledge_parts = []
|
||||
for i, item in enumerate(knowledge_info["knowledge_items"]):
|
||||
knowledge_parts.append(f"- {item.get('content', 'N/A')}")
|
||||
|
||||
knowledge_text = "\n".join(knowledge_parts)
|
||||
summary = knowledge_info.get('summary', '无总结')
|
||||
content = f"关于 '{query}', 你知道以下信息:\n{knowledge_text}\n\n总结: {summary}"
|
||||
else:
|
||||
content = f"你不太了解有关{query}的知识"
|
||||
content = f"关于 '{query}',你的知识库里好像没有相关的信息呢"
|
||||
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
||||
except Exception as e:
|
||||
# 捕获异常并记录错误
|
||||
|
||||
@@ -64,9 +64,15 @@ async def message_recv(server_connection: Server.ServerConnection):
|
||||
|
||||
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
|
||||
post_type = decoded_raw_message.get("post_type")
|
||||
|
||||
# 兼容没有 post_type 的普通消息
|
||||
if not post_type and "message_type" in decoded_raw_message:
|
||||
decoded_raw_message["post_type"] = "message"
|
||||
post_type = "message"
|
||||
|
||||
if post_type in ["meta_event", "message", "notice"]:
|
||||
await message_queue.put(decoded_raw_message)
|
||||
elif post_type is None:
|
||||
else:
|
||||
await put_response(decoded_raw_message)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
|
||||
@@ -859,6 +859,43 @@ class MessageHandler:
|
||||
data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}",
|
||||
)
|
||||
|
||||
# 检查是否是音乐分享
|
||||
elif nested_data.get("view") == "music" and "music" in nested_data.get("meta", {}):
|
||||
logger.debug("检测到音乐分享消息,开始提取信息")
|
||||
music_info = nested_data["meta"]["music"]
|
||||
title = music_info.get("title", "未知歌曲")
|
||||
desc = music_info.get("desc", "未知艺术家")
|
||||
jump_url = music_info.get("jumpUrl", "")
|
||||
preview_url = music_info.get("preview", "")
|
||||
source = music_info.get("tag", "未知来源")
|
||||
|
||||
# 优化文本结构,使其更像卡片
|
||||
text_parts = [
|
||||
"--- 音乐分享 ---",
|
||||
f"歌曲:{title}",
|
||||
f"歌手:{desc}",
|
||||
f"来源:{source}"
|
||||
]
|
||||
if jump_url:
|
||||
text_parts.append(f"链接:{jump_url}")
|
||||
text_parts.append("----------------")
|
||||
|
||||
text_content = "\n".join(text_parts)
|
||||
|
||||
# 如果有预览图,创建一个seglist包含文本和图片
|
||||
if preview_url:
|
||||
try:
|
||||
image_base64 = await get_image_base64(preview_url)
|
||||
if image_base64:
|
||||
return Seg(type="seglist", data=[
|
||||
Seg(type="text", data=text_content + "\n"),
|
||||
Seg(type="image", data=image_base64)
|
||||
])
|
||||
except Exception as e:
|
||||
logger.error(f"下载音乐预览图失败: {e}")
|
||||
|
||||
return Seg(type="text", data=text_content)
|
||||
|
||||
# 如果没有提取到关键信息,返回None
|
||||
return None
|
||||
|
||||
|
||||
340
src/plugins/built_in/reminder_plugin/plugin.py
Normal file
340
src/plugins/built_in/reminder_plugin/plugin.py
Normal file
@@ -0,0 +1,340 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Type, Optional
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system import (
|
||||
BaseAction,
|
||||
ActionInfo,
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
ActionActivationType,
|
||||
)
|
||||
from src.plugin_system.apis import send_api, llm_api, generator_api
|
||||
from src.plugin_system.base.component_types import ChatType, ComponentType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# ============================ AsyncTask ============================
|
||||
|
||||
class ReminderTask(AsyncTask):
|
||||
def __init__(self, delay: float, stream_id: str, group_id: Optional[str], is_group: bool, target_user_id: str, target_user_name: str, event_details: str, creator_name: str, chat_stream: "ChatStream"):
|
||||
super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}")
|
||||
self.delay = delay
|
||||
self.stream_id = stream_id
|
||||
self.group_id = group_id
|
||||
self.is_group = is_group
|
||||
self.target_user_id = target_user_id
|
||||
self.target_user_name = target_user_name
|
||||
self.event_details = event_details
|
||||
self.creator_name = creator_name
|
||||
self.chat_stream = chat_stream
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
if self.delay > 0:
|
||||
logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...")
|
||||
await asyncio.sleep(self.delay)
|
||||
|
||||
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
|
||||
|
||||
extra_info = f"现在是提醒时间,请你以一种符合你人设的、俏皮的方式提醒 {self.target_user_name}。\n提醒内容: {self.event_details}\n设置提醒的人: {self.creator_name}"
|
||||
success, reply_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
extra_info=extra_info,
|
||||
reply_message=self.chat_stream.context.get_last_message().to_dict(),
|
||||
request_type="plugin.reminder.remind_message"
|
||||
)
|
||||
|
||||
if success and reply_set:
|
||||
for i, (_, text) in enumerate(reply_set):
|
||||
if self.is_group:
|
||||
message_payload = []
|
||||
if i == 0:
|
||||
message_payload.append({"type": "at", "data": {"qq": self.target_user_id}})
|
||||
message_payload.append({"type": "text", "data": {"text": f" {text}"}})
|
||||
await send_api.adapter_command_to_stream(
|
||||
action="send_group_msg",
|
||||
params={"group_id": self.group_id, "message": message_payload},
|
||||
stream_id=self.stream_id
|
||||
)
|
||||
else:
|
||||
await send_api.text_to_stream(text=text, stream_id=self.stream_id)
|
||||
else:
|
||||
# Fallback message
|
||||
reminder_text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}"
|
||||
if self.is_group:
|
||||
message_payload = [
|
||||
{"type": "at", "data": {"qq": self.target_user_id}},
|
||||
{"type": "text", "data": {"text": f" {reminder_text}"}}
|
||||
]
|
||||
await send_api.adapter_command_to_stream(
|
||||
action="send_group_msg",
|
||||
params={"group_id": self.group_id, "message": message_payload},
|
||||
stream_id=self.stream_id
|
||||
)
|
||||
else:
|
||||
await send_api.text_to_stream(text=reminder_text, stream_id=self.stream_id)
|
||||
|
||||
logger.info(f"提醒任务 {self.task_name} 成功完成。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行提醒任务 {self.task_name} 时出错: {e}", exc_info=True)
|
||||
|
||||
|
||||
# =============================== Actions ===============================
|
||||
|
||||
class RemindAction(BaseAction):
    """一个能从对话中智能识别并设置定时提醒的动作。

    Workflow:
      1. Ask the LLM to extract (user_name, remind_time, event_details)
         from the triggering message.
      2. Resolve the natural-language time to a concrete future datetime
         (direct dateutil parse first, LLM conversion as fallback).
      3. Resolve the target user (exact -> substring -> fuzzy name match).
      4. Schedule a ReminderTask and send a confirmation reply.
    """

    # === 基本信息 ===
    action_name = "set_reminder"
    action_description = "根据用户的对话内容,智能地设置一个未来的提醒事项。"

    @staticmethod
    def get_action_info() -> ActionInfo:
        """Describe this action for the plugin system (keyword activation)."""
        return ActionInfo(
            name="set_reminder",
            component_type=ComponentType.ACTION,
            activation_type=ActionActivationType.KEYWORD,
            activation_keywords=["提醒", "叫我", "记得", "别忘了"]
        )

    # === LLM 判断与参数提取 ===
    llm_judge_prompt = ""
    action_parameters = {}
    action_require = [
        "当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
        "适用于包含明确时间信息和事件描述的对话",
        "例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'"
    ]

    async def execute(self) -> Tuple[bool, str]:
        """执行设置提醒的动作

        Returns:
            (success, message): ``message`` is a short human-readable
            description of the outcome for the caller / planner.
        """
        # Imports used only by this action; hoisted here once instead of
        # being scattered mid-function (difflib was previously imported
        # inside a loop on every iteration).
        import difflib
        import json
        import re

        # --- 1. 用 LLM 从消息中提取提醒参数 ---
        try:
            available_models = llm_api.get_available_models()
            if "planner" not in available_models:
                raise ValueError("未找到 'planner' 决策模型配置,无法解析时间")
            model_to_use = available_models["planner"]

            bot_name = self.chat_stream.user_info.user_nickname

            prompt = f"""
从以下用户输入中提取提醒事件的关键信息。
用户输入: "{self.chat_stream.context.message.processed_plain_text}"
Bot的名字是: "{bot_name}"

请仔细分析句子结构,以确定谁是提醒的真正目标。Bot自身不应被视为被提醒人。
请以JSON格式返回提取的信息,包含以下字段:
- "user_name": 需要被提醒的人的姓名。如果未指定,则默认为"自己"。
- "remind_time": 描述提醒时间的自然语言字符串。
- "event_details": 需要提醒的具体事件内容。

示例:
- 用户输入: "提醒我十分钟后开会" -> {{"user_name": "自己", "remind_time": "十分钟后", "event_details": "开会"}}
- 用户输入: "{bot_name},提醒一闪一分钟后睡觉" -> {{"user_name": "一闪", "remind_time": "一分钟后", "event_details": "睡觉"}}

如果无法提取完整信息,请返回一个包含空字符串的JSON对象,例如:{{"user_name": "", "remind_time": "", "event_details": ""}}
"""

            success, response, _, _ = await llm_api.generate_with_model(
                prompt,
                model_config=model_to_use,
                request_type="plugin.reminder.parameter_extractor"
            )

            if not success or not response:
                raise ValueError(f"LLM未能返回有效的参数: {response}")

            try:
                # 提取JSON部分(LLM 可能在 JSON 外包裹说明文字)
                json_match = re.search(r"\{.*\}", response, re.DOTALL)
                if not json_match:
                    raise ValueError("LLM返回的内容中不包含JSON")
                action_data = json.loads(json_match.group(0))
            except json.JSONDecodeError:
                logger.error(f"[ReminderPlugin] LLM返回的不是有效的JSON: {response}")
                return False, "LLM返回的不是有效的JSON"
            user_name = action_data.get("user_name")
            remind_time_str = action_data.get("remind_time")
            event_details = action_data.get("event_details")

        except Exception as e:
            logger.error(f"[ReminderPlugin] 解析参数时出错: {e}", exc_info=True)
            return False, "解析参数时出错"

        if not all([user_name, remind_time_str, event_details]):
            missing_params = [p for p, v in {"user_name": user_name, "remind_time": remind_time_str, "event_details": event_details}.items() if not v]
            error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}"
            logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}")
            return False, error_msg

        # --- 2. 解析时间 ---
        # 显式转换替代 assert:assert 在 -O 模式下会被剥离,不能用于运行期校验。
        remind_time_str = str(remind_time_str)
        user_name = str(user_name)
        try:
            # 优先尝试直接解析
            try:
                target_time = parse_datetime(remind_time_str, fuzzy=True)
            except Exception:
                # 直接解析失败时,调用 LLM 将自然语言时间转换为标准格式。
                logger.info(f"[ReminderPlugin] 直接解析时间 '{remind_time_str}' 失败,尝试使用 LLM 进行转换...")

                # 复用上文已获取的 'planner' 模型配置,避免重复查询。
                # 在执行时动态获取当前时间
                current_time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                prompt = (
                    f"请将以下自然语言时间短语转换为一个未来的、标准的 'YYYY-MM-DD HH:MM:SS' 格式。"
                    f"请只输出转换后的时间字符串,不要包含任何其他说明或文字。\n"
                    f"作为参考,当前时间是: {current_time_str}\n"
                    f"需要转换的时间短语是: '{remind_time_str}'\n"
                    f"规则:\n"
                    f"- 如果用户没有明确指出是上午还是下午,请根据当前时间判断。例如,如果当前是上午,用户说‘8点’,则应理解为今天的8点;如果当前是下午,用户说‘8点’,则应理解为今天的20点。\n"
                    f"- 如果转换后的时间早于当前时间,则应理解为第二天的时间。\n"
                    f"示例:\n"
                    f"- 当前时间: 2025-09-16 10:00:00, 用户说: '8点' -> '2025-09-17 08:00:00'\n"
                    f"- 当前时间: 2025-09-16 14:00:00, 用户说: '8点' -> '2025-09-16 20:00:00'\n"
                    f"- 当前时间: 2025-09-16 23:00:00, 用户说: '晚上10点' -> '2025-09-17 22:00:00'"
                )

                success, response, _, _ = await llm_api.generate_with_model(
                    prompt,
                    model_config=model_to_use,
                    request_type="plugin.reminder.time_parser"
                )

                if not success or not response:
                    raise ValueError(f"LLM未能返回有效的时间字符串: {response}")

                converted_time_str = response.strip()
                logger.info(f"[ReminderPlugin] LLM 转换结果: '{converted_time_str}'")
                target_time = parse_datetime(converted_time_str, fuzzy=False)

        except Exception as e:
            logger.error(f"[ReminderPlugin] 无法解析或转换时间字符串 '{remind_time_str}': {e}", exc_info=True)
            await self.send_text(f"抱歉,我无法理解您说的时间 '{remind_time_str}',提醒设置失败。")
            return False, f"无法解析时间 '{remind_time_str}'"

        now = datetime.now()
        if target_time <= now:
            await self.send_text("提醒时间必须是一个未来的时间点哦,提醒设置失败。")
            return False, "提醒时间必须在未来"

        delay_seconds = (target_time - now).total_seconds()

        # --- 3. 解析被提醒的用户 ---
        person_manager = get_person_info_manager()
        user_id_to_remind = None
        user_name_to_remind = ""

        if user_name.strip() in ["自己", "我", "me"]:
            # 指向发起人自己
            user_id_to_remind = self.user_id
            user_name_to_remind = self.user_nickname
        else:
            # 1) 精确匹配
            user_info = await person_manager.get_person_info_by_name(user_name)

            # 2) 包含匹配
            if not user_info:
                for person_id, name in person_manager.person_name_list.items():
                    if user_name in name:
                        user_info = await person_manager.get_values(person_id, ["user_id", "user_nickname"])
                        break

            # 3) 模糊匹配 (此处简化为字符串相似度)
            if not user_info:
                best_match = None
                highest_similarity = 0
                for person_id, name in person_manager.person_name_list.items():
                    similarity = difflib.SequenceMatcher(None, user_name, name).ratio()
                    if similarity > highest_similarity:
                        highest_similarity = similarity
                        best_match = person_id

                if best_match and highest_similarity > 0.6:  # 相似度阈值
                    user_info = await person_manager.get_values(best_match, ["user_id", "user_nickname"])

            if not user_info or not user_info.get("user_id"):
                logger.warning(f"[ReminderPlugin] 找不到名为 '{user_name}' 的用户")
                await self.send_text(f"抱歉,我的联系人里找不到叫做 '{user_name}' 的人,提醒设置失败。")
                return False, f"用户 '{user_name}' 不存在"
            user_id_to_remind = user_info.get("user_id")
            user_name_to_remind = user_info.get("user_nickname") or user_name

        # --- 4. 创建并调度异步任务,随后发送确认回执 ---
        try:
            reminder_task = ReminderTask(
                delay=delay_seconds,
                stream_id=self.chat_stream.stream_id,
                group_id=self.chat_stream.group_info.group_id if self.is_group and self.chat_stream.group_info else None,
                is_group=self.is_group,
                target_user_id=str(user_id_to_remind),
                target_user_name=str(user_name_to_remind),
                event_details=str(event_details),
                creator_name=str(self.user_nickname),
                chat_stream=self.chat_stream
            )
            await async_task_manager.add_task(reminder_task)

            # 生成并发送确认消息(生成失败时退回固定话术)
            extra_info = f"你已经成功设置了一个提醒,请以一种符合你人设的、俏皮的方式回复用户。\n提醒时间: {target_time.strftime('%Y-%m-%d %H:%M:%S')}\n提醒对象: {user_name_to_remind}\n提醒内容: {event_details}"
            last_message = self.chat_stream.context.get_last_message()
            success, reply_set, _ = await generator_api.generate_reply(
                chat_stream=self.chat_stream,
                extra_info=extra_info,
                reply_message=last_message.to_dict(),
                request_type="plugin.reminder.confirm_message"
            )
            if success and reply_set:
                for _, text in reply_set:
                    await self.send_text(text)
            else:
                # Fallback message
                fallback_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}:\n{event_details}"
                await self.send_text(fallback_message)

            return True, "提醒设置成功"
        except Exception as e:
            logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True)
            await self.send_text("抱歉,设置提醒时发生了一点内部错误。")
            return False, "设置提醒时发生内部错误"
|
||||
|
||||
|
||||
# =============================== Plugin ===============================
|
||||
|
||||
@register_plugin
class ReminderPlugin(BasePlugin):
    """一个能从对话中智能识别并设置定时提醒的插件。"""

    # --- 插件基础信息 ---
    plugin_name = "reminder_plugin"
    enable_plugin = True
    dependencies = []           # 无其他插件依赖
    python_dependencies = []    # 无额外 pip 依赖
    config_file_name = "config.toml"
    config_schema = {}

    def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]:
        """注册插件的所有功能组件。"""
        # 当前仅包含一个组件:设置提醒的 Action。
        components = [
            (RemindAction.get_action_info(), RemindAction),
        ]
        return components
|
||||
@@ -1,198 +0,0 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Type
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system import (
|
||||
BaseAction,
|
||||
ActionInfo,
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
ActionActivationType,
|
||||
)
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# ============================ AsyncTask ============================
|
||||
|
||||
class ReminderTask(AsyncTask):
    """到点后向指定会话发送提醒消息的一次性异步任务。"""

    def __init__(self, delay: float, stream_id: str, is_group: bool, target_user_id: str, target_user_name: str, event_details: str, creator_name: str):
        # Task name embeds target and creation timestamp for traceability.
        super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}")
        self.delay = delay                      # seconds to wait before firing
        self.stream_id = stream_id              # chat stream to deliver into
        self.is_group = is_group                # group chat vs. private chat
        self.target_user_id = target_user_id    # who gets @-mentioned in groups
        self.target_user_name = target_user_name
        self.event_details = event_details      # what to remind about
        self.creator_name = creator_name        # who asked for the reminder

    async def run(self):
        """Sleep for the configured delay, then deliver the reminder."""
        try:
            if self.delay > 0:
                logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...")
                await asyncio.sleep(self.delay)

            logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")

            text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}"

            if not self.is_group:
                # 私聊:直接发送文本即可
                await send_api.text_to_stream(text=text, stream_id=self.stream_id)
            else:
                # 群聊:@目标用户后附带提醒文本。
                # 取 stream_id 最后一个 '_' 之后的部分作为群号
                # (无 '_' 时 rsplit 返回整个字符串,与原实现一致)。
                group_id = self.stream_id.rsplit('_', 1)[-1]
                segments = [
                    {"type": "at", "data": {"qq": self.target_user_id}},
                    {"type": "text", "data": {"text": f" {text}"}},
                ]
                await send_api.adapter_command_to_stream(
                    action="send_group_msg",
                    params={"group_id": group_id, "message": segments},
                    stream_id=self.stream_id,
                )

            logger.info(f"提醒任务 {self.task_name} 成功完成。")

        except Exception as e:
            logger.error(f"执行提醒任务 {self.task_name} 时出错: {e}", exc_info=True)
|
||||
|
||||
|
||||
# =============================== Actions ===============================
|
||||
|
||||
class RemindAction(BaseAction):
    """一个能从对话中智能识别并设置定时提醒的动作。

    由 LLM 判定触发,从 ``self.action_data`` 读取已提取的参数,
    解析时间与目标用户后调度一个 ReminderTask。
    """

    # === 基本信息 ===
    action_name = "set_reminder"
    action_description = "根据用户的对话内容,智能地设置一个未来的提醒事项。"
    activation_type = ActionActivationType.LLM_JUDGE
    chat_type_allow = ChatType.ALL

    # === LLM 判断与参数提取 ===
    llm_judge_prompt = """
判断用户是否意图设置一个未来的提醒。
- 必须包含明确的时间点或时间段(如“十分钟后”、“明天下午3点”、“周五”)。
- 必须包含一个需要被提醒的事件。
- 可能会包含需要提醒的特定人物。
- 如果只是普通的聊天或询问时间,则不应触发。

示例:
- "半小时后提醒我开会" -> 是
- "明天下午三点叫张三来一下" -> 是
- "别忘了周五把报告交了" -> 是
- "现在几点了?" -> 否
- "我明天下午有空" -> 否

请只回答"是"或"否"。
"""
    action_parameters = {
        "user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'",
        "remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后'或'明天下午3点'",
        "event_details": "需要提醒的具体事件内容"
    }
    action_require = [
        "当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
        "适用于包含明确时间信息和事件描述的对话",
        "例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'"
    ]

    async def execute(self) -> Tuple[bool, str]:
        """执行设置提醒的动作

        Returns:
            (success, message): ``message`` 为结果的简短说明。
        """
        user_name = self.action_data.get("user_name")
        remind_time_str = self.action_data.get("remind_time")
        event_details = self.action_data.get("event_details")

        if not all([user_name, remind_time_str, event_details]):
            missing_params = [p for p, v in {"user_name": user_name, "remind_time": remind_time_str, "event_details": event_details}.items() if not v]
            error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}"
            logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}")
            return False, error_msg

        # 显式转换替代 assert:assert 在 -O 模式下会被剥离,不能用于运行期校验。
        remind_time_str = str(remind_time_str)
        user_name = str(user_name)

        # 1. 解析时间
        try:
            target_time = parse_datetime(remind_time_str, fuzzy=True)
        except Exception as e:
            logger.error(f"[ReminderPlugin] 无法解析时间字符串 '{remind_time_str}': {e}")
            await self.send_text(f"抱歉,我无法理解您说的时间 '{remind_time_str}',提醒设置失败。")
            return False, f"无法解析时间 '{remind_time_str}'"

        now = datetime.now()
        if target_time <= now:
            await self.send_text("提醒时间必须是一个未来的时间点哦,提醒设置失败。")
            return False, "提醒时间必须在未来"

        delay_seconds = (target_time - now).total_seconds()

        # 2. 解析用户
        person_manager = get_person_info_manager()

        if user_name.strip() in ["自己", "我", "me"]:
            # 提醒发起人自己
            user_id_to_remind = self.user_id
            user_name_to_remind = self.user_nickname
        else:
            # 按名字精确查找联系人
            user_info = await person_manager.get_person_info_by_name(user_name)
            if not user_info or not user_info.get("user_id"):
                logger.warning(f"[ReminderPlugin] 找不到名为 '{user_name}' 的用户")
                await self.send_text(f"抱歉,我的联系人里找不到叫做 '{user_name}' 的人,提醒设置失败。")
                return False, f"用户 '{user_name}' 不存在"
            user_id_to_remind = user_info.get("user_id")
            user_name_to_remind = user_name

        # 3. 创建并调度异步任务
        try:
            reminder_task = ReminderTask(
                delay=delay_seconds,
                stream_id=self.chat_id,
                is_group=self.is_group,
                target_user_id=str(user_id_to_remind),
                target_user_name=str(user_name_to_remind),
                event_details=str(event_details),
                creator_name=str(self.user_nickname)
            )
            await async_task_manager.add_task(reminder_task)

            # 4. 发送确认消息
            confirm_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}:\n{event_details}"
            await self.send_text(confirm_message)

            return True, "提醒设置成功"
        except Exception as e:
            logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True)
            await self.send_text("抱歉,设置提醒时发生了一点内部错误。")
            return False, "设置提醒时发生内部错误"
|
||||
|
||||
|
||||
# =============================== Plugin ===============================
|
||||
|
||||
@register_plugin
class ReminderPlugin(BasePlugin):
    """一个能从对话中智能识别并设置定时提醒的插件。"""

    # --- 插件基础信息 ---
    plugin_name = "reminder_plugin"
    enable_plugin = True
    dependencies = []           # 无其他插件依赖
    python_dependencies = []    # 无额外 pip 依赖
    config_file_name = "config.toml"
    config_schema = {}

    def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]:
        """注册插件的所有功能组件。"""
        # NOTE(review): RemindAction 本身未定义 get_action_info,
        # 此调用应由 BaseAction 基类提供 — 需确认。
        remind_component = (RemindAction.get_action_info(), RemindAction)
        return [remind_component]
|
||||
Reference in New Issue
Block a user