fix: PFC不读取聊天记录
This commit is contained in:
@@ -45,7 +45,19 @@ class ActionPlanner:
|
||||
|
||||
# 构建对话目标
|
||||
if conversation_info.goal_list:
|
||||
goal, reasoning = conversation_info.goal_list[-1]
|
||||
last_goal = conversation_info.goal_list[-1]
|
||||
print(last_goal)
|
||||
# 处理字典或元组格式
|
||||
if isinstance(last_goal, tuple) and len(last_goal) == 2:
|
||||
goal, reasoning = last_goal
|
||||
elif isinstance(last_goal, dict) and 'goal' in last_goal and 'reasoning' in last_goal:
|
||||
# 处理字典格式
|
||||
goal = last_goal.get('goal', "目前没有明确对话目标")
|
||||
reasoning = last_goal.get('reasoning', "目前没有明确对话目标,最好思考一个对话目标")
|
||||
else:
|
||||
# 处理未知格式
|
||||
goal = "目前没有明确对话目标"
|
||||
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
||||
else:
|
||||
goal = "目前没有明确对话目标"
|
||||
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from src.common.logger import get_module_logger
|
||||
from ..message.message_base import UserInfo
|
||||
@@ -44,18 +45,14 @@ class ChatObserver:
|
||||
self.stream_id = stream_id
|
||||
self.message_storage = message_storage or MongoDBMessageStorage()
|
||||
|
||||
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
|
||||
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
|
||||
self.last_check_time: float = time.time() # 上次查看聊天记录时间
|
||||
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
|
||||
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
|
||||
# self.last_check_time: float = time.time() # 上次查看聊天记录时间
|
||||
self.last_message_read: Optional[str] = None # 最后读取的消息ID
|
||||
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
|
||||
|
||||
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
|
||||
|
||||
# 消息历史记录
|
||||
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
|
||||
self.last_message_id: Optional[str] = None # 最后一条消息的ID
|
||||
self.message_count: int = 0 # 消息计数
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
@@ -72,7 +69,7 @@ class ChatObserver:
|
||||
self.is_cold_chat_state: bool = False
|
||||
|
||||
self.update_event = asyncio.Event()
|
||||
self.update_interval = 5 # 更新间隔(秒)
|
||||
self.update_interval = 2 # 更新间隔(秒)
|
||||
self.message_cache = []
|
||||
self.update_running = False
|
||||
|
||||
@@ -98,21 +95,17 @@ class ChatObserver:
|
||||
Args:
|
||||
message: 消息数据
|
||||
"""
|
||||
self.message_history.append(message)
|
||||
self.last_message_id = message["message_id"]
|
||||
self.last_message_time = message["time"] # 更新最后消息时间
|
||||
self.message_count += 1
|
||||
|
||||
# 更新说话时间
|
||||
user_info = UserInfo.from_dict(message.get("user_info", {}))
|
||||
if user_info.user_id == global_config.BOT_QQ:
|
||||
self.last_bot_speak_time = message["time"]
|
||||
else:
|
||||
self.last_user_speak_time = message["time"]
|
||||
try:
|
||||
|
||||
# 发送新消息通知
|
||||
notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message)
|
||||
# logger.info(f"发送新ccchandleer消息通知: {message}")
|
||||
notification = create_new_message_notification(sender="chat_observer", target="observation_info", message=message)
|
||||
# logger.info(f"发送新消ddddd息通知: {notification}")
|
||||
# print(self.notification_manager)
|
||||
await self.notification_manager.send_notification(notification)
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到历史记录时出错: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
# 检查并更新冷场状态
|
||||
await self._check_cold_chat()
|
||||
@@ -156,9 +149,6 @@ class ChatObserver:
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
if time_point is None:
|
||||
logger.warning("time_point 为 None,返回 False")
|
||||
return False
|
||||
|
||||
if self.last_message_time is None:
|
||||
logger.debug("没有最后消息时间,返回 False")
|
||||
@@ -215,6 +205,8 @@ class ChatObserver:
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
print(f"获取111111111122222222新消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
|
||||
@@ -231,26 +223,32 @@ class ChatObserver:
|
||||
if new_messages:
|
||||
self.last_message_read = new_messages[-1]["message_id"]
|
||||
|
||||
logger.debug(f"获取指定时间点111之前的消息: {new_messages}")
|
||||
|
||||
return new_messages
|
||||
|
||||
"""主要观察循环"""
|
||||
|
||||
async def _update_loop(self):
|
||||
"""更新循环"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
messages = await self._fetch_new_messages_before(start_time)
|
||||
for message in messages:
|
||||
await self._add_message_to_history(message)
|
||||
except Exception as e:
|
||||
logger.error(f"缓冲消息出错: {e}")
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# messages = await self._fetch_new_messages_before(start_time)
|
||||
# for message in messages:
|
||||
# await self._add_message_to_history(message)
|
||||
# logger.debug(f"缓冲消息: {messages}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"缓冲消息出错: {e}")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# 等待事件或超时(1秒)
|
||||
try:
|
||||
# print("等待事件")
|
||||
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# print("超时")
|
||||
pass # 超时后也执行一次检查
|
||||
|
||||
self._update_event.clear() # 重置触发事件
|
||||
@@ -355,51 +353,6 @@ class ChatObserver:
|
||||
|
||||
return time_info
|
||||
|
||||
def start_periodic_update(self):
|
||||
"""启动观察器的定期更新"""
|
||||
if not self.update_running:
|
||||
self.update_running = True
|
||||
asyncio.create_task(self._periodic_update())
|
||||
|
||||
async def _periodic_update(self):
|
||||
"""定期更新消息历史"""
|
||||
try:
|
||||
while self.update_running:
|
||||
await self._update_message_history()
|
||||
await asyncio.sleep(self.update_interval)
|
||||
except Exception as e:
|
||||
logger.error(f"定期更新消息历史时出错: {str(e)}")
|
||||
|
||||
async def _update_message_history(self) -> bool:
|
||||
"""更新消息历史
|
||||
|
||||
Returns:
|
||||
bool: 是否有新消息
|
||||
"""
|
||||
try:
|
||||
messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50)
|
||||
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
# 检查是否有新消息
|
||||
has_new_messages = False
|
||||
if messages and (
|
||||
not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]
|
||||
):
|
||||
has_new_messages = True
|
||||
|
||||
self.message_cache = messages
|
||||
|
||||
if has_new_messages:
|
||||
self.update_event.set()
|
||||
self.update_event.clear()
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息历史时出错: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""获取缓存的消息历史
|
||||
@@ -421,3 +374,6 @@ class ChatObserver:
|
||||
if not self.message_cache:
|
||||
return None
|
||||
return self.message_cache[0]
|
||||
|
||||
def __str__(self):
|
||||
return f"ChatObserver for {self.stream_id}"
|
||||
|
||||
@@ -98,11 +98,17 @@ class NotificationManager:
|
||||
notification_type: 要处理的通知类型
|
||||
handler: 处理器实例
|
||||
"""
|
||||
print(1145145511114445551111444)
|
||||
if target not in self._handlers:
|
||||
print("没11有target")
|
||||
self._handlers[target] = {}
|
||||
if notification_type not in self._handlers[target]:
|
||||
print("没11有notification_type")
|
||||
self._handlers[target][notification_type] = []
|
||||
print(self._handlers[target][notification_type])
|
||||
print(f"注册1111111111111111111111处理器: {target} {notification_type} {handler}")
|
||||
self._handlers[target][notification_type].append(handler)
|
||||
print(self._handlers[target][notification_type])
|
||||
|
||||
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||
"""注销通知处理器
|
||||
@@ -126,6 +132,7 @@ class NotificationManager:
|
||||
async def send_notification(self, notification: Notification):
|
||||
"""发送通知"""
|
||||
self._notification_history.append(notification)
|
||||
# print("kaishichul-----------------------------------i")
|
||||
|
||||
# 如果是状态通知,更新活跃状态
|
||||
if isinstance(notification, StateNotification):
|
||||
@@ -134,11 +141,15 @@ class NotificationManager:
|
||||
else:
|
||||
self._active_states.discard(notification.type)
|
||||
|
||||
|
||||
# 调用目标接收者的处理器
|
||||
target = notification.target
|
||||
if target in self._handlers:
|
||||
handlers = self._handlers[target].get(notification.type, [])
|
||||
# print(1111111)
|
||||
print(handlers)
|
||||
for handler in handlers:
|
||||
print(f"调用处理器: {handler}")
|
||||
await handler.handle_notification(notification)
|
||||
|
||||
def get_active_states(self) -> Set[NotificationType]:
|
||||
@@ -171,6 +182,13 @@ class NotificationManager:
|
||||
|
||||
return history
|
||||
|
||||
def __str__(self):
|
||||
str = ""
|
||||
for target, handlers in self._handlers.items():
|
||||
for notification_type, handler_list in handlers.items():
|
||||
str += f"NotificationManager for {target} {notification_type} {handler_list}"
|
||||
return str
|
||||
|
||||
|
||||
# 一些常用的通知创建函数
|
||||
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
||||
@@ -182,8 +200,9 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str,
|
||||
target=target,
|
||||
data={
|
||||
"message_id": message.get("message_id"),
|
||||
"content": message.get("content"),
|
||||
"sender": message.get("sender"),
|
||||
"processed_plain_text": message.get("processed_plain_text"),
|
||||
"detailed_plain_text": message.get("detailed_plain_text"),
|
||||
"user_info": message.get("user_info"),
|
||||
"time": message.get("time"),
|
||||
},
|
||||
)
|
||||
@@ -276,3 +295,5 @@ class ChatStateManager:
|
||||
|
||||
current_time = datetime.now().timestamp()
|
||||
return (current_time - self.state_info.last_message_time) <= threshold
|
||||
|
||||
|
||||
|
||||
@@ -60,9 +60,10 @@ class Conversation:
|
||||
self.chat_observer = ChatObserver.get_instance(self.stream_id)
|
||||
self.chat_observer.start()
|
||||
self.observation_info = ObservationInfo()
|
||||
self.observation_info.bind_to_chat_observer(self.stream_id)
|
||||
self.observation_info.bind_to_chat_observer(self.chat_observer)
|
||||
# print(self.chat_observer.get_cached_messages(limit=)
|
||||
|
||||
|
||||
# 对话信息
|
||||
self.conversation_info = ConversationInfo()
|
||||
except Exception as e:
|
||||
logger.error(f"初始化对话实例:注册信息组件失败: {e}")
|
||||
@@ -140,6 +141,7 @@ class Conversation:
|
||||
if action == "direct_reply":
|
||||
self.state = ConversationState.GENERATING
|
||||
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
|
||||
print(f"生成回复: {self.generated_reply}")
|
||||
|
||||
# # 检查回复是否合适
|
||||
# is_suitable, reason, need_replan = await self.reply_generator.check_reply(
|
||||
@@ -148,6 +150,7 @@ class Conversation:
|
||||
# )
|
||||
|
||||
if self._check_new_messages_after_planning():
|
||||
logger.info("333333发现新消息,重新考虑行动")
|
||||
return None
|
||||
|
||||
await self._send_reply()
|
||||
@@ -212,15 +215,9 @@ class Conversation:
|
||||
logger.warning("没有生成回复")
|
||||
return
|
||||
|
||||
messages = self.chat_observer.get_cached_messages(limit=1)
|
||||
if not messages:
|
||||
logger.warning("没有最近的消息可以回复")
|
||||
return
|
||||
|
||||
latest_message = self._convert_to_message(messages[0])
|
||||
try:
|
||||
await self.direct_sender.send_message(
|
||||
chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
|
||||
chat_stream=self.chat_stream, content=self.generated_reply
|
||||
)
|
||||
self.chat_observer.trigger_update() # 触发立即更新
|
||||
if not await self.chat_observer.wait_for_update():
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from src.common.logger import get_module_logger
|
||||
from .chat_states import NotificationHandler, Notification, NotificationType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .conversation import Conversation
|
||||
|
||||
logger = get_module_logger("notification_handler")
|
||||
|
||||
|
||||
class PFCNotificationHandler(NotificationHandler):
|
||||
"""PFC通知处理器"""
|
||||
|
||||
def __init__(self, conversation: "Conversation"):
|
||||
"""初始化PFC通知处理器
|
||||
|
||||
Args:
|
||||
conversation: 对话实例
|
||||
"""
|
||||
self.conversation = conversation
|
||||
|
||||
async def handle_notification(self, notification: Notification):
|
||||
"""处理通知
|
||||
|
||||
Args:
|
||||
notification: 通知对象
|
||||
"""
|
||||
logger.debug(f"收到通知: {notification.type.name}, 数据: {notification.data}")
|
||||
|
||||
# 根据通知类型执行不同的处理
|
||||
if notification.type == NotificationType.NEW_MESSAGE:
|
||||
# 新消息通知
|
||||
await self._handle_new_message(notification)
|
||||
elif notification.type == NotificationType.COLD_CHAT:
|
||||
# 冷聊天通知
|
||||
await self._handle_cold_chat(notification)
|
||||
elif notification.type == NotificationType.COMMAND:
|
||||
# 命令通知
|
||||
await self._handle_command(notification)
|
||||
else:
|
||||
logger.warning(f"未知的通知类型: {notification.type.name}")
|
||||
|
||||
async def _handle_new_message(self, notification: Notification):
|
||||
"""处理新消息通知
|
||||
|
||||
Args:
|
||||
notification: 通知对象
|
||||
"""
|
||||
|
||||
# 更新决策信息
|
||||
observation_info = self.conversation.observation_info
|
||||
observation_info.last_message_time = notification.data.get("time", 0)
|
||||
observation_info.add_unprocessed_message(notification.data)
|
||||
|
||||
# 手动触发观察器更新
|
||||
self.conversation.chat_observer.trigger_update()
|
||||
|
||||
async def _handle_cold_chat(self, notification: Notification):
|
||||
"""处理冷聊天通知
|
||||
|
||||
Args:
|
||||
notification: 通知对象
|
||||
"""
|
||||
# 获取冷聊天信息
|
||||
cold_duration = notification.data.get("duration", 0)
|
||||
|
||||
# 更新决策信息
|
||||
observation_info = self.conversation.observation_info
|
||||
observation_info.conversation_cold_duration = cold_duration
|
||||
|
||||
logger.info(f"对话已冷: {cold_duration}秒")
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
from dataclasses import dataclass, field
|
||||
from src.common.logger import get_module_logger
|
||||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler
|
||||
from .chat_states import NotificationHandler, NotificationType
|
||||
|
||||
logger = get_module_logger("observation_info")
|
||||
|
||||
@@ -22,63 +22,70 @@ class ObservationInfoHandler(NotificationHandler):
|
||||
"""
|
||||
self.observation_info = observation_info
|
||||
|
||||
async def handle_notification(self, notification: Dict[str, Any]):
|
||||
"""处理通知
|
||||
async def handle_notification(self, notification):
|
||||
# 获取通知类型和数据
|
||||
notification_type = notification.type
|
||||
data = notification.data
|
||||
|
||||
Args:
|
||||
notification: 通知数据
|
||||
"""
|
||||
notification_type = notification.get("type")
|
||||
data = notification.get("data", {})
|
||||
|
||||
if notification_type == "NEW_MESSAGE":
|
||||
if notification_type == NotificationType.NEW_MESSAGE:
|
||||
# 处理新消息通知
|
||||
logger.debug(f"收到新消息通知data: {data}")
|
||||
message = data.get("message", {})
|
||||
self.observation_info.update_from_message(message)
|
||||
# self.observation_info.has_unread_messages = True
|
||||
# self.observation_info.new_unread_message.append(message.get("processed_plain_text", ""))
|
||||
message_id = data.get("message_id")
|
||||
processed_plain_text = data.get("processed_plain_text")
|
||||
detailed_plain_text = data.get("detailed_plain_text")
|
||||
user_info = data.get("user_info")
|
||||
time_value = data.get("time")
|
||||
|
||||
elif notification_type == "COLD_CHAT":
|
||||
message = {
|
||||
"message_id": message_id,
|
||||
"processed_plain_text": processed_plain_text,
|
||||
"detailed_plain_text": detailed_plain_text,
|
||||
"user_info": user_info,
|
||||
"time": time_value
|
||||
}
|
||||
|
||||
self.observation_info.update_from_message(message)
|
||||
|
||||
elif notification_type == NotificationType.COLD_CHAT:
|
||||
# 处理冷场通知
|
||||
is_cold = data.get("is_cold", False)
|
||||
self.observation_info.update_cold_chat_status(is_cold, time.time())
|
||||
|
||||
elif notification_type == "ACTIVE_CHAT":
|
||||
elif notification_type == NotificationType.ACTIVE_CHAT:
|
||||
# 处理活跃通知
|
||||
is_active = data.get("is_active", False)
|
||||
self.observation_info.is_cold = not is_active
|
||||
|
||||
elif notification_type == "BOT_SPEAKING":
|
||||
elif notification_type == NotificationType.BOT_SPEAKING:
|
||||
# 处理机器人说话通知
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_bot_speak_time = time.time()
|
||||
|
||||
elif notification_type == "USER_SPEAKING":
|
||||
elif notification_type == NotificationType.USER_SPEAKING:
|
||||
# 处理用户说话通知
|
||||
self.observation_info.is_typing = False
|
||||
self.observation_info.last_user_speak_time = time.time()
|
||||
|
||||
elif notification_type == "MESSAGE_DELETED":
|
||||
elif notification_type == NotificationType.MESSAGE_DELETED:
|
||||
# 处理消息删除通知
|
||||
message_id = data.get("message_id")
|
||||
self.observation_info.unprocessed_messages = [
|
||||
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
|
||||
]
|
||||
|
||||
elif notification_type == "USER_JOINED":
|
||||
elif notification_type == NotificationType.USER_JOINED:
|
||||
# 处理用户加入通知
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.add(user_id)
|
||||
|
||||
elif notification_type == "USER_LEFT":
|
||||
elif notification_type == NotificationType.USER_LEFT:
|
||||
# 处理用户离开通知
|
||||
user_id = data.get("user_id")
|
||||
if user_id:
|
||||
self.observation_info.active_users.discard(user_id)
|
||||
|
||||
elif notification_type == "ERROR":
|
||||
elif notification_type == NotificationType.ERROR:
|
||||
# 处理错误通知
|
||||
error_msg = data.get("error", "")
|
||||
logger.error(f"收到错误通知: {error_msg}")
|
||||
@@ -100,6 +107,7 @@ class ObservationInfo:
|
||||
last_message_content: str = ""
|
||||
last_message_sender: Optional[str] = None
|
||||
bot_id: Optional[str] = None
|
||||
chat_history_count: int = 0
|
||||
new_messages_count: int = 0
|
||||
cold_chat_duration: float = 0.0
|
||||
|
||||
@@ -117,28 +125,37 @@ class ObservationInfo:
|
||||
self.chat_observer = None
|
||||
self.handler = ObservationInfoHandler(self)
|
||||
|
||||
def bind_to_chat_observer(self, stream_id: str):
|
||||
def bind_to_chat_observer(self, chat_observer: ChatObserver):
|
||||
"""绑定到指定的chat_observer
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
"""
|
||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||
self.chat_observer = chat_observer
|
||||
print(f"1919810----------------------绑定-----------------------------")
|
||||
print(self.chat_observer)
|
||||
print(f"1919810--------------------绑定-----------------------------")
|
||||
print(self.chat_observer.notification_manager)
|
||||
print(f"1919810-------------------绑定-----------------------------")
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.register_handler(
|
||||
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
print("1919810------------------------绑定-----------------------------")
|
||||
print(f"1919810--------------------绑定-----------------------------")
|
||||
print(self.chat_observer.notification_manager)
|
||||
print(f"1919810-------------------绑定-----------------------------")
|
||||
|
||||
def unbind_from_chat_observer(self):
|
||||
"""解除与chat_observer的绑定"""
|
||||
if self.chat_observer:
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
|
||||
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
|
||||
)
|
||||
self.chat_observer.notification_manager.unregister_handler(
|
||||
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
|
||||
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
|
||||
)
|
||||
self.chat_observer = None
|
||||
|
||||
@@ -148,8 +165,11 @@ class ObservationInfo:
|
||||
Args:
|
||||
message: 消息数据
|
||||
"""
|
||||
print("1919810-----------------------------------------------------")
|
||||
logger.debug(f"更新信息from_message: {message}")
|
||||
self.last_message_time = message["time"]
|
||||
self.last_message_id = message["message_id"]
|
||||
|
||||
self.last_message_content = message.get("processed_plain_text", "")
|
||||
|
||||
user_info = UserInfo.from_dict(message.get("user_info", {}))
|
||||
@@ -169,7 +189,6 @@ class ObservationInfo:
|
||||
def update_changed(self):
|
||||
"""更新changed状态"""
|
||||
self.changed = True
|
||||
# self.meta_plan_trigger = True
|
||||
|
||||
def update_cold_chat_status(self, is_cold: bool, current_time: float):
|
||||
"""更新冷场状态
|
||||
@@ -223,17 +242,3 @@ class ObservationInfo:
|
||||
self.unprocessed_messages.clear()
|
||||
self.new_messages_count = 0
|
||||
|
||||
def add_unprocessed_message(self, message: Dict[str, Any]):
|
||||
"""添加未处理的消息
|
||||
|
||||
Args:
|
||||
message: 消息数据
|
||||
"""
|
||||
# 防止重复添加同一消息
|
||||
message_id = message.get("message_id")
|
||||
if message_id and not any(m.get("message_id") == message_id for m in self.unprocessed_messages):
|
||||
self.unprocessed_messages.append(message)
|
||||
self.new_messages_count += 1
|
||||
|
||||
# 同时更新其他消息相关信息
|
||||
self.update_from_message(message)
|
||||
|
||||
@@ -99,11 +99,12 @@ class GoalAnalyzer:
|
||||
3. 添加新目标
|
||||
4. 删除不再相关的目标
|
||||
|
||||
请以JSON格式输出当前的所有对话目标,包含以下字段:
|
||||
请以JSON数组格式输出当前的所有对话目标,每个目标包含以下字段:
|
||||
1. goal: 对话目标(简短的一句话)
|
||||
2. reasoning: 对话原因,为什么设定这个目标(简要解释)
|
||||
|
||||
输出格式示例:
|
||||
[
|
||||
{{
|
||||
"goal": "回答用户关于Python编程的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
@@ -111,7 +112,8 @@ class GoalAnalyzer:
|
||||
{{
|
||||
"goal": "回答用户关于python安装的具体问题",
|
||||
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
|
||||
}}"""
|
||||
}}
|
||||
]"""
|
||||
|
||||
logger.debug(f"发送到LLM的提示词: {prompt}")
|
||||
try:
|
||||
@@ -120,13 +122,37 @@ class GoalAnalyzer:
|
||||
except Exception as e:
|
||||
logger.error(f"分析对话目标时出错: {str(e)}")
|
||||
content = ""
|
||||
# 使用简化函数提取JSON内容
|
||||
success, result = get_items_from_json(
|
||||
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
|
||||
)
|
||||
# TODO
|
||||
|
||||
conversation_info.goal_list.append(result)
|
||||
# 使用改进后的get_items_from_json函数处理JSON数组
|
||||
success, result = get_items_from_json(
|
||||
content, "goal", "reasoning",
|
||||
required_types={"goal": str, "reasoning": str},
|
||||
allow_array=True
|
||||
)
|
||||
|
||||
if success:
|
||||
# 判断结果是单个字典还是字典列表
|
||||
if isinstance(result, list):
|
||||
# 清空现有目标列表并添加新目标
|
||||
conversation_info.goal_list = []
|
||||
for item in result:
|
||||
goal = item.get("goal", "")
|
||||
reasoning = item.get("reasoning", "")
|
||||
conversation_info.goal_list.append((goal, reasoning))
|
||||
|
||||
# 返回第一个目标作为当前主要目标(如果有)
|
||||
if result:
|
||||
first_goal = result[0]
|
||||
return (first_goal.get("goal", ""), "", first_goal.get("reasoning", ""))
|
||||
else:
|
||||
# 单个目标的情况
|
||||
goal = result.get("goal", "")
|
||||
reasoning = result.get("reasoning", "")
|
||||
conversation_info.goal_list.append((goal, reasoning))
|
||||
return (goal, "", reasoning)
|
||||
|
||||
# 如果解析失败,返回默认值
|
||||
return ("", "", "")
|
||||
|
||||
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
|
||||
"""更新目标列表
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from typing import Dict, Any, Optional, Tuple, List, Union
|
||||
from src.common.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger("pfc_utils")
|
||||
@@ -11,7 +11,8 @@ def get_items_from_json(
|
||||
*items: str,
|
||||
default_values: Optional[Dict[str, Any]] = None,
|
||||
required_types: Optional[Dict[str, type]] = None,
|
||||
) -> Tuple[bool, Dict[str, Any]]:
|
||||
allow_array: bool = True,
|
||||
) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
|
||||
"""从文本中提取JSON内容并获取指定字段
|
||||
|
||||
Args:
|
||||
@@ -19,9 +20,10 @@ def get_items_from_json(
|
||||
*items: 要提取的字段名
|
||||
default_values: 字段的默认值,格式为 {字段名: 默认值}
|
||||
required_types: 字段的必需类型,格式为 {字段名: 类型}
|
||||
allow_array: 是否允许解析JSON数组
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
|
||||
Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表)
|
||||
"""
|
||||
content = content.strip()
|
||||
result = {}
|
||||
@@ -30,7 +32,57 @@ def get_items_from_json(
|
||||
if default_values:
|
||||
result.update(default_values)
|
||||
|
||||
# 尝试解析JSON
|
||||
# 首先尝试解析为JSON数组
|
||||
if allow_array:
|
||||
try:
|
||||
# 尝试找到文本中的JSON数组
|
||||
array_pattern = r"\[[\s\S]*\]"
|
||||
array_match = re.search(array_pattern, content)
|
||||
if array_match:
|
||||
array_content = array_match.group()
|
||||
json_array = json.loads(array_content)
|
||||
|
||||
# 确认是数组类型
|
||||
if isinstance(json_array, list):
|
||||
# 验证数组中的每个项目是否包含所有必需字段
|
||||
valid_items = []
|
||||
for item in json_array:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 检查是否有所有必需字段
|
||||
if all(field in item for field in items):
|
||||
# 验证字段类型
|
||||
if required_types:
|
||||
type_valid = True
|
||||
for field, expected_type in required_types.items():
|
||||
if field in item and not isinstance(item[field], expected_type):
|
||||
type_valid = False
|
||||
break
|
||||
|
||||
if not type_valid:
|
||||
continue
|
||||
|
||||
# 验证字符串字段不为空
|
||||
string_valid = True
|
||||
for field in items:
|
||||
if isinstance(item[field], str) and not item[field].strip():
|
||||
string_valid = False
|
||||
break
|
||||
|
||||
if not string_valid:
|
||||
continue
|
||||
|
||||
valid_items.append(item)
|
||||
|
||||
if valid_items:
|
||||
return True, valid_items
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("JSON数组解析失败,尝试解析单个JSON对象")
|
||||
except Exception as e:
|
||||
logger.debug(f"尝试解析JSON数组时出错: {str(e)}")
|
||||
|
||||
# 尝试解析JSON对象
|
||||
try:
|
||||
json_data = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
|
||||
@@ -26,11 +26,11 @@ logger = get_module_logger("llm_generator", config=llm_config)
|
||||
class ResponseGenerator:
|
||||
def __init__(self):
|
||||
self.model_normal = LLM_request(
|
||||
model=global_config.llm_normal, temperature=0.3, max_tokens=256, request_type="response_heartflow"
|
||||
model=global_config.llm_normal, temperature=0.15, max_tokens=256, request_type="response_heartflow"
|
||||
)
|
||||
|
||||
self.model_sum = LLM_request(
|
||||
model=global_config.llm_summary_by_topic, temperature=0.7, max_tokens=2000, request_type="relation"
|
||||
model=global_config.llm_summary_by_topic, temperature=0.6, max_tokens=2000, request_type="relation"
|
||||
)
|
||||
self.current_model_type = "r1" # 默认使用 R1
|
||||
self.current_model_name = "unknown model"
|
||||
|
||||
Reference in New Issue
Block a user