This commit is contained in:
tcmofashi
2025-05-23 12:14:31 +08:00
74 changed files with 726 additions and 1446 deletions

View File

@@ -1,5 +1,5 @@
"""
MaiMBot插件系统
MaiBot模块系统
包含聊天、情绪、记忆、日程等功能模块
"""

View File

@@ -17,7 +17,7 @@ from src.manager.mood_manager import mood_manager
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
from src.individuality.individuality import Individuality
from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
@@ -281,7 +281,6 @@ class DefaultExpressor:
in_mind_reply,
target_message,
) -> str:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=0, level=2)
# Determine if it's a group chat
@@ -294,7 +293,7 @@ class DefaultExpressor:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=global_config.chat.observation_context_size,
limit=global_config.focus_chat.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,

View File

@@ -36,24 +36,6 @@ def init_prompt() -> None:
"""
Prompt(learn_style_prompt, "learn_style_prompt")
personality_expression_prompt = """
{personality}
请从以上人设中总结出这个角色可能的语言风格
思考回复的特殊内容和情感
思考有没有特殊的梗,一并总结成语言风格
总结成如下格式的规律,总结的内容要详细,但具有概括性:
"xxx"时,可以"xxx", xxx不超过10个字
例如:
"表示十分惊叹"时,使用"我嘞个xxxx"
"表示讽刺的赞同,不想讲道理"时,使用"对对对"
"想说明某个观点,但懒得明说",使用"懂的都懂"
现在请你概括
"""
Prompt(personality_expression_prompt, "personality_expression_prompt")
learn_grammar_prompt = """
{chat_str}
@@ -278,44 +260,6 @@ class ExpressionLearner:
expressions.append((chat_id, situation, style))
return expressions
async def extract_and_store_personality_expressions(self):
"""
检查data/expression/personality目录不存在则创建。
用peronality变量作为chat_str调用LLM生成表达风格解析后count=100存储到expressions.json。
"""
dir_path = os.path.join("data", "expression", "personality")
os.makedirs(dir_path, exist_ok=True)
file_path = os.path.join(dir_path, "expressions.json")
# 构建prompt
prompt = await global_prompt_manager.format_prompt(
"personality_expression_prompt",
personality=global_config.personality.expression_style,
)
# logger.info(f"个性表达方式提取prompt: {prompt}")
try:
response, _ = await self.express_learn_model.generate_response_async(prompt)
except Exception as e:
logger.error(f"个性表达方式提取失败: {e}")
return
logger.info(f"个性表达方式提取response: {response}")
# chat_id用personality
expressions = self.parse_expression_response(response, "personality")
# 转为dict并count=100
result = []
for _, situation, style in expressions:
result.append({"situation": situation, "style": style, "count": 100})
# 超过50条时随机删除多余的只保留50条
if len(result) > 50:
remove_count = len(result) - 50
remove_indices = set(random.sample(range(len(result)), remove_count))
result = [item for idx, item in enumerate(result) if idx not in remove_indices]
with open(file_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
logger.info(f"已写入{len(result)}条表达到{file_path}")
init_prompt()

View File

@@ -3,7 +3,7 @@ import contextlib
import time
import traceback
from collections import deque
from typing import List, Optional, Dict, Any, Deque, Callable, Coroutine
from typing import List, Optional, Dict, Any, Deque
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_stream import chat_manager
from rich.traceback import install
@@ -26,10 +26,22 @@ from src.chat.focus_chat.info_processors.self_processor import SelfProcessor
from src.chat.focus_chat.planners.planner import ActionPlanner
from src.chat.focus_chat.planners.action_manager import ActionManager
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from src.config.config import global_config
install(extra_lines=3)
# 定义处理器映射:键是处理器名称,值是 (处理器类, 可选的配置键名)
# 如果配置键名为 None则该处理器默认启用且不能通过 focus_chat_processor 配置禁用
PROCESSOR_CLASSES = {
"ChattingInfoProcessor": (ChattingInfoProcessor, None),
"MindProcessor": (MindProcessor, None),
"ToolProcessor": (ToolProcessor, "tool_use_processor"),
"WorkingMemoryProcessor": (WorkingMemoryProcessor, "working_memory_processor"),
"SelfProcessor": (SelfProcessor, "self_identify_processor"),
}
WAITING_TIME_THRESHOLD = 300 # 等待新消息时间阈值,单位秒
EMOJI_SEND_PRO = 0.3 # 设置一个概率,比如 30% 才真的发
@@ -90,6 +102,21 @@ class HeartFChatting:
observe_id=self.stream_id, working_memory=self.working_memory
)
# 根据配置文件和默认规则确定启用的处理器
self.enabled_processor_names: List[str] = []
config_processor_settings = global_config.focus_chat_processor
for proc_name, (_proc_class, config_key) in PROCESSOR_CLASSES.items():
if config_key: # 此处理器可通过配置控制
if getattr(config_processor_settings, config_key, True): # 默认启用 (如果配置中未指定该键)
self.enabled_processor_names.append(proc_name)
else: # 此处理器不在配置映射中 (config_key is None),默认启用
self.enabled_processor_names.append(proc_name)
logger.info(f"{self.log_prefix} 将启用的处理器: {self.enabled_processor_names}")
self.processors: List[BaseProcessor] = []
self._register_default_processors()
self.expressor = DefaultExpressor(chat_id=self.stream_id)
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager)
@@ -97,9 +124,6 @@ class HeartFChatting:
self.hfcloop_observation.set_action_manager(self.action_manager)
self.all_observations = observations
# --- 处理器列表 ---
self.processors: List[BaseProcessor] = []
self._register_default_processors()
# 初始化状态控制
self._initialized = False
@@ -150,13 +174,40 @@ class HeartFChatting:
return True
def _register_default_processors(self):
"""注册默认的信息处理器"""
self.processors.append(ChattingInfoProcessor())
self.processors.append(MindProcessor(subheartflow_id=self.stream_id))
self.processors.append(ToolProcessor(subheartflow_id=self.stream_id))
self.processors.append(WorkingMemoryProcessor(subheartflow_id=self.stream_id))
self.processors.append(SelfProcessor(subheartflow_id=self.stream_id))
logger.info(f"{self.log_prefix} 已注册默认处理器: {[p.__class__.__name__ for p in self.processors]}")
"""根据 self.enabled_processor_names 注册信息处理器"""
self.processors = [] # 清空已有的
for name in self.enabled_processor_names: # 'name' is "ChattingInfoProcessor", etc.
processor_info = PROCESSOR_CLASSES.get(name) # processor_info is (ProcessorClass, config_key)
if processor_info:
processor_actual_class = processor_info[0] # 获取实际的类定义
# 根据处理器类名判断是否需要 subheartflow_id
if name in ["MindProcessor", "ToolProcessor", "WorkingMemoryProcessor", "SelfProcessor"]:
self.processors.append(processor_actual_class(subheartflow_id=self.stream_id))
elif name == "ChattingInfoProcessor":
self.processors.append(processor_actual_class())
else:
# 对于PROCESSOR_CLASSES中定义但此处未明确处理构造的处理器
# (例如, 新增了一个处理器到PROCESSOR_CLASSES, 它不需要id, 也不叫ChattingInfoProcessor)
try:
self.processors.append(processor_actual_class()) # 尝试无参构造
logger.debug(f"{self.log_prefix} 注册处理器 {name} (尝试无参构造).")
except TypeError:
logger.error(
f"{self.log_prefix} 处理器 {name} 构造失败。它可能需要参数(如 subheartflow_id但未在注册逻辑中明确处理。"
)
else:
# 这理论上不应该发生,因为 enabled_processor_names 是从 PROCESSOR_CLASSES 的键生成的
logger.warning(
f"{self.log_prefix} 在 PROCESSOR_CLASSES 中未找到名为 '{name}' 的处理器定义,将跳过注册。"
)
if self.processors:
logger.info(
f"{self.log_prefix} 已根据配置和默认规则注册处理器: {[p.__class__.__name__ for p in self.processors]}"
)
else:
logger.warning(f"{self.log_prefix} 没有注册任何处理器。这可能是由于配置错误或所有处理器都被禁用了。")
async def start(self):
"""
@@ -260,6 +311,8 @@ class HeartFChatting:
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
await asyncio.sleep(global_config.focus_chat.think_interval)
except asyncio.CancelledError:
# 设置了关闭标志位后被取消是正常流程
if not self._shutting_down:

View File

@@ -5,7 +5,6 @@ from ...config.config import global_config
from ..message_receive.message import MessageRecv
from ..message_receive.storage import MessageStorage
from ..utils.utils import is_mentioned_bot_in_message
from maim_message import Seg
from src.chat.heart_flow.heartflow import heartflow
from src.common.logger_manager import get_logger
from ..message_receive.chat_stream import chat_manager
@@ -79,26 +78,26 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
return interested_rate, is_mentioned
def _get_message_type(message: MessageRecv) -> str:
"""获取消息类型
# def _get_message_type(message: MessageRecv) -> str:
# """获取消息类型
Args:
message: 消息对象
# Args:
# message: 消息对象
Returns:
str: 消息类型
"""
if message.message_segment.type != "seglist":
return message.message_segment.type
# Returns:
# str: 消息类型
# """
# if message.message_segment.type != "seglist":
# return message.message_segment.type
if (
isinstance(message.message_segment.data, list)
and all(isinstance(x, Seg) for x in message.message_segment.data)
and len(message.message_segment.data) == 1
):
return message.message_segment.data[0].type
# if (
# isinstance(message.message_segment.data, list)
# and all(isinstance(x, Seg) for x in message.message_segment.data)
# and len(message.message_segment.data) == 1
# ):
# return message.message_segment.data[0].type
return "seglist"
# return "seglist"
def _check_ban_words(text: str, chat, userinfo) -> bool:
@@ -112,7 +111,7 @@ def _check_ban_words(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否包含过滤词
"""
for word in global_config.chat.ban_words:
for word in global_config.message_receive.ban_words:
if word in text:
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
@@ -132,7 +131,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool:
Returns:
bool: 是否匹配过滤正则
"""
for pattern in global_config.chat.ban_msgs_regex:
for pattern in global_config.message_receive.ban_msgs_regex:
if pattern.search(text):
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
@@ -141,7 +140,7 @@ def _check_ban_regex(text: str, chat, userinfo) -> bool:
return False
class HeartFCProcessor:
class HeartFCMessageReceiver:
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
def __init__(self):

View File

@@ -1,20 +1,16 @@
from src.config.config import global_config
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.utils import get_embedding
import time
from typing import Union, Optional
from typing import Optional
from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager
import random
import json
import math
from src.common.database.database_model import Knowledges
logger = get_logger("prompt")
@@ -103,7 +99,6 @@ class PromptBuilder:
return None
async def _build_prompt_normal(self, chat_stream, message_txt: str, sender_name: str = "某人") -> str:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=2)
is_group_chat = bool(chat_stream.group_info)
@@ -112,7 +107,7 @@ class PromptBuilder:
who_chat_in_group = get_recent_group_speaker(
chat_stream.stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
limit=global_config.chat.observation_context_size,
limit=global_config.focus_chat.observation_context_size,
)
elif chat_stream.user_info:
who_chat_in_group.append(
@@ -161,7 +156,7 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=global_config.chat.observation_context_size,
limit=global_config.focus_chat.observation_context_size,
)
chat_talking_prompt = await build_readable_messages(
message_list_before_now,
@@ -265,129 +260,6 @@ class PromptBuilder:
return prompt
async def get_prompt_info_old(self, message: str, threshold: float):
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.info("未能提取到任何主题,使用整个消息进行查询")
embedding = await get_embedding(message, request_type="prompt_build")
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
return related_info
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# 优化批量获取嵌入向量减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
try:
embedding = await get_embedding(text, request_type="prompt_build")
if embedding:
embeddings[text] = embedding
else:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
if original_results:
for result in original_results:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
# 添加主题标记
for result in topic_results:
result["topic"] = topic
all_results.extend(topic_results)
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
filtered_results = []
for result in all_results:
content = result["content"]
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# 6. 限制总数量最多10条
filtered_results = filtered_results[:10]
logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
grouped_results = {}
for result in filtered_results:
topic = result["topic"]
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}\n"
for _i, result in enumerate(results, 1):
_similarity = result["similarity"]
content = result["content"].strip()
related_info += f"{content}\n"
related_info += "\n"
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info
async def get_prompt_info(self, message: str, threshold: float):
related_info = ""
start_time = time.time()
@@ -407,93 +279,11 @@ class PromptBuilder:
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
else:
logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索")
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
logger.debug("从LPMM知识库获取知识失败可能是从未导入过知识,返回空知识...")
return "未检索到知识"
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
try:
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old
logger.debug(
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
)
return related_info
except Exception as e2:
logger.error(f"使用旧版数据库获取知识时也发生异常: {str(e2)}")
return ""
@staticmethod
def get_info_from_db(
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
results_with_similarity = []
try:
# Fetch all knowledge entries
# This might be inefficient for very large databases.
# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
all_knowledges = Knowledges.select()
if not all_knowledges:
return [] if return_raw else ""
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
if query_embedding_magnitude == 0: # Avoid division by zero
return "" if not return_raw else []
for knowledge_item in all_knowledges:
try:
db_embedding_str = knowledge_item.embedding
db_embedding = json.loads(db_embedding_str)
if len(db_embedding) != len(query_embedding):
logger.warning(
f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
)
continue
# Calculate Cosine Similarity
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
if db_embedding_magnitude == 0: # Avoid division by zero
similarity = 0.0
else:
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
if similarity >= threshold:
results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
except json.JSONDecodeError:
logger.error(
f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
)
except Exception as e:
logger.error(f"Error processing knowledge item: {e}")
# Sort by similarity in descending order
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
# Limit results
limited_results = results_with_similarity[:limit]
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
if not limited_results:
return "" if not return_raw else []
if return_raw:
return limited_results
else:
return "\n".join(str(result["content"]) for result in limited_results)
except Exception as e:
logger.error(f"Error querying Knowledges with Peewee: {e}")
return "" if not return_raw else []
return "未检索到知识"
init_prompt()

View File

@@ -80,4 +80,4 @@ class ActionInfo(InfoBase):
Returns:
bool: 如果有任何动作需要添加或移除则返回True
"""
return bool(self.get_add_actions() or self.get_remove_actions())
return bool(self.get_add_actions() or self.get_remove_actions())

View File

@@ -1,14 +1,10 @@
from typing import List, Optional, Any
from src.chat.focus_chat.info.obs_info import ObsInfo
from src.chat.heart_flow.observation.observation import Observation
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info.action_info import ActionInfo
from .base_processor import BaseProcessor
from src.common.logger_manager import get_logger
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.focus_chat.info.cycle_info import CycleInfo
from datetime import datetime
from typing import Dict
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
@@ -55,10 +51,7 @@ class ActionProcessor(BaseProcessor):
# 处理Observation对象
if observations:
for obs in observations:
if isinstance(obs, HFCloopObservation):
# 创建动作信息
action_info = ActionInfo()
action_changes = await self.analyze_loop_actions(obs)
@@ -75,7 +68,6 @@ class ActionProcessor(BaseProcessor):
return processed_infos
async def analyze_loop_actions(self, obs: HFCloopObservation) -> Dict[str, List[str]]:
"""分析最近的循环内容并决定动作的增减
@@ -87,29 +79,29 @@ class ActionProcessor(BaseProcessor):
}
"""
result = {"add": [], "remove": []}
# 获取最近10次循环
recent_cycles = obs.history_loop[-10:] if len(obs.history_loop) > 10 else obs.history_loop
if not recent_cycles:
return result
# 统计no_reply的数量
no_reply_count = 0
reply_sequence = [] # 记录最近的动作序列
for cycle in recent_cycles:
action_type = cycle.loop_plan_info["action_result"]["action_type"]
if action_type == "no_reply":
no_reply_count += 1
reply_sequence.append(action_type == "reply")
# 检查no_reply比例
if len(recent_cycles) >= 5 and (no_reply_count / len(recent_cycles)) >= 0.8:
result["add"].append("exit_focus_chat")
# 获取最近三次的reply状态
last_three = reply_sequence[-3:] if len(reply_sequence) >= 3 else reply_sequence
# 根据最近的reply情况决定是否移除reply动作
if len(last_three) >= 3 and all(last_three):
# 如果最近三次都是reply直接移除
@@ -122,5 +114,5 @@ class ActionProcessor(BaseProcessor):
# 如果最近一次是reply20%概率移除
if random.random() < 0.2:
result["remove"].append("reply")
return result

View File

@@ -5,7 +5,7 @@ from src.config.config import global_config
import time
import traceback
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import safe_json_dumps
from src.chat.message_receive.chat_stream import chat_manager
@@ -194,7 +194,6 @@ class MindProcessor(BaseProcessor):
# ---------- 3. 准备个性化数据 ----------
# 获取个性化信息
individuality = Individuality.get_instance()
relation_prompt = ""
for person in person_list:

View File

@@ -5,7 +5,7 @@ from src.config.config import global_config
import time
import traceback
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import chat_manager
from src.chat.person_info.relationship_manager import relationship_manager
@@ -21,8 +21,9 @@ logger = get_logger("processor")
def init_prompt():
indentify_prompt = """
你的名字是{bot_name},你是一个AI但是你伪装成了一个人类你的人格是{prompt_personality}
的头像形象是一只橙色的鱼,头上有绿色的树叶
{name_block}
是一个AI但是你伪装成了一个人类你的人格是{prompt_personality}
{indentify_block}
{relation_prompt}
{memory_str}
@@ -36,8 +37,8 @@ def init_prompt():
3. 你的自我认同是否有助于你的回答,如果你需要自我相关的信息来帮你参与聊天,请输出,否则请输出十个字以内的简短自我认同
4. 一般情况下不用输出自我认同,只需要输出十几个字的简短自我认同就好,除非有明显需要自我认同的场景
回复的平淡一些,简短一些,说中文,不要浮夸,平淡一些。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出内容。
思考的平淡一些,简短一些,说中文,不要浮夸,平淡一些。
请注意不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出自我认同内容。
"""
Prompt(indentify_prompt, "indentify_prompt")
@@ -125,16 +126,22 @@ class SelfProcessor(BaseProcessor):
# hfcloop_observe_info = observation.get_observe_info()
pass
individuality = Individuality.get_instance()
personality_block = individuality.get_prompt(x_person=2, level=2)
nickname_str = ""
for nicknames in global_config.bot.alias_names:
nickname_str += f"{nicknames},"
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
personality_block = individuality.get_personality_prompt(x_person=2, level=2)
identity_block = individuality.get_identity_prompt(x_person=2, level=2)
relation_prompt = ""
for person in person_list:
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(
bot_name=individuality.name,
name_block=name_block,
prompt_personality=personality_block,
indentify_block=identity_block,
memory_str=memory_str,
relation_prompt=relation_prompt,
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),

View File

@@ -3,7 +3,7 @@ from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
import time
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.tools.tool_use import ToolUser
from src.chat.utils.json_utils import process_llm_tool_calls
@@ -133,7 +133,7 @@ class ToolProcessor(BaseProcessor):
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
# 获取个性信息
individuality = Individuality.get_instance()
# prompt_personality = individuality.get_prompt(x_person=2, level=2)
# 获取时间信息

View File

@@ -1,9 +1,8 @@
from typing import Dict, List, Optional, Callable, Coroutine, Type, Any
from typing import Dict, List, Optional, Type, Any
from src.chat.focus_chat.planners.actions.base_action import BaseAction, _ACTION_REGISTRY
from src.chat.heart_flow.observation.observation import Observation
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from src.common.logger_manager import get_logger
import importlib
import pkgutil

View File

@@ -1,12 +1,9 @@
import asyncio
import traceback
from src.common.logger_manager import get_logger
from src.chat.utils.timer_calculator import Timer
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
from typing import Tuple, List, Callable, Coroutine
from typing import Tuple, List
from src.chat.heart_flow.observation.observation import Observation
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.heart_flow.sub_heartflow import SubHeartFlow
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.heart_flow.heartflow import heartflow
from src.chat.heart_flow.sub_heartflow import ChatState
@@ -61,8 +58,6 @@ class ExitFocusChatAction(BaseAction):
self._shutting_down = shutting_down
self.chat_id = chat_stream.stream_id
async def handle_action(self) -> Tuple[bool, str]:
"""
处理退出专注聊天的情况
@@ -83,7 +78,7 @@ class ExitFocusChatAction(BaseAction):
if self.sub_heartflow:
try:
# 转换为normal_chat状态
await self.sub_heartflow.change_chat_state(ChatState.NORMAL_CHAT)
await self.sub_heartflow.change_chat_state(ChatState.CHAT)
status_message = "已成功切换到普通聊天模式"
logger.info(f"{self.log_prefix} {status_message}")
except Exception as e:
@@ -95,7 +90,6 @@ class ExitFocusChatAction(BaseAction):
logger.warning(f"{self.log_prefix} {warning_msg}")
return False, warning_msg
return True, status_message
except asyncio.CancelledError:
@@ -105,4 +99,4 @@ class ExitFocusChatAction(BaseAction):
error_msg = f"处理 'exit_focus_chat' 时发生错误: {str(e)}"
logger.error(f"{self.log_prefix} {error_msg}")
logger.error(traceback.format_exc())
return False, error_msg
return False, error_msg

View File

@@ -3,7 +3,7 @@ import traceback
from src.common.logger_manager import get_logger
from src.chat.utils.timer_calculator import Timer
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
from typing import Tuple, List, Callable, Coroutine
from typing import Tuple, List
from src.chat.heart_flow.observation.observation import Observation
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp

View File

@@ -41,7 +41,7 @@ class PluginAction(BaseAction):
return platform, user_id
# 提供简化的API方法
async def send_message(self, text: str, target: Optional[str] = None) -> bool:
async def send_message(self, type: str, data: str, target: Optional[str] = "") -> bool:
"""发送消息的简化方法
Args:
@@ -60,7 +60,7 @@ class PluginAction(BaseAction):
return False
# 构造简化的动作数据
reply_data = {"text": text, "target": target or "", "emojis": []}
# reply_data = {"text": text, "target": target or "", "emojis": []}
# 获取锚定消息(如果有)
observations = self._services.get("observations", [])
@@ -68,7 +68,8 @@ class PluginAction(BaseAction):
chatting_observation: ChattingObservation = next(
obs for obs in observations if isinstance(obs, ChattingObservation)
)
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
anchor_message = chatting_observation.search_message_by_text(target)
# 如果没有找到锚点消息,创建一个占位符
if not anchor_message:
@@ -80,7 +81,7 @@ class PluginAction(BaseAction):
anchor_message.update_chat_stream(chat_stream)
response_set = [
("text", text),
(type, data),
]
# 调用内部方法发送消息

View File

@@ -12,7 +12,7 @@ from src.chat.focus_chat.info.action_info import ActionInfo
from src.chat.focus_chat.info.structured_info import StructuredInfo
from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.individuality.individuality import Individuality
from src.individuality.individuality import individuality
from src.chat.focus_chat.planners.action_manager import ActionManager
logger = get_logger("planner")
@@ -92,37 +92,37 @@ class ActionPlanner:
try:
# 获取观察信息
extra_info: list[str] = []
# 首先处理动作变更
for info in all_plan_info:
if isinstance(info, ActionInfo) and info.has_changes():
add_actions = info.get_add_actions()
remove_actions = info.get_remove_actions()
reason = info.get_reason()
# 处理动作的增加
for action_name in add_actions:
if action_name in self.action_manager.get_registered_actions():
self.action_manager.add_action_to_using(action_name)
logger.debug(f"{self.log_prefix}添加动作: {action_name}, 原因: {reason}")
# 处理动作的移除
for action_name in remove_actions:
self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}移除动作: {action_name}, 原因: {reason}")
# 如果当前选择的动作被移除了更新为no_reply
if action in remove_actions:
action = "no_reply"
reasoning = f"之前选择的动作{action}已被移除,原因: {reason}"
# 继续处理其他信息
for info in all_plan_info:
if isinstance(info, ObsInfo):
observed_messages = info.get_talking_message()
observed_messages_str = info.get_talking_message_str_truncate()
chat_type = info.get_chat_type()
is_group_chat = (chat_type == "group")
is_group_chat = chat_type == "group"
elif isinstance(info, MindInfo):
current_mind = info.get_current_mind()
elif isinstance(info, CycleInfo):
@@ -134,20 +134,16 @@ class ActionPlanner:
# 获取当前可用的动作
current_available_actions = self.action_manager.get_using_actions()
# 如果没有可用动作直接返回no_reply
if not current_available_actions:
logger.warning(f"{self.log_prefix}没有可用的动作将使用no_reply")
action = "no_reply"
reasoning = "没有可用的动作"
return {
"action_result": {
"action_type": action,
"action_data": action_data,
"reasoning": reasoning
},
"action_result": {"action_type": action, "action_data": action_data, "reasoning": reasoning},
"current_mind": current_mind,
"observed_messages": observed_messages
"observed_messages": observed_messages,
}
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
@@ -271,7 +267,6 @@ class ActionPlanner:
else:
mind_info_block = "你刚参与聊天"
individuality = Individuality.get_instance()
personality_block = individuality.get_prompt(x_person=2, level=2)
action_options_block = ""

View File

@@ -4,7 +4,7 @@ from typing import Optional, Coroutine, Callable, Any, List
from src.common.logger_manager import get_logger
from src.chat.heart_flow.mai_state_manager import MaiStateManager, MaiStateInfo
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
from src.config.config import global_config
logger = get_logger("background_tasks")
@@ -94,13 +94,6 @@ class BackgroundTaskManager:
f"清理任务已启动 间隔:{CLEANUP_INTERVAL_SECONDS}s",
"_cleanup_task",
),
# 新增兴趣评估任务配置
(
self._run_into_focus_cycle,
"debug", # 设为debug避免过多日志
f"专注评估任务已启动 间隔:{INTEREST_EVAL_INTERVAL_SECONDS}s",
"_into_focus_task",
),
# 新增私聊激活任务配置
(
# Use lambda to pass the interval to the runner function
@@ -111,6 +104,19 @@ class BackgroundTaskManager:
),
]
# 根据 chat_mode 条件添加专注评估任务
if not (global_config.chat.chat_mode == "normal"):
task_configs.append(
(
self._run_into_focus_cycle,
"debug", # 设为debug避免过多日志
f"专注评估任务已启动 间隔:{INTEREST_EVAL_INTERVAL_SECONDS}s",
"_into_focus_task",
)
)
else:
logger.info("聊天模式为 normal跳过启动专注评估任务")
# 统一启动所有任务
for task_func, log_level, log_msg, task_attr_name in task_configs:
# 检查任务变量是否存在且未完成
@@ -183,7 +189,6 @@ class BackgroundTaskManager:
logger.info("检测到离线,停用所有子心流")
await self.subheartflow_manager.deactivate_all_subflows()
async def _perform_cleanup_work(self):
"""执行子心流清理任务
1. 获取需要清理的不活跃子心流列表
@@ -209,18 +214,15 @@ class BackgroundTaskManager:
# 记录最终清理结果
logger.info(f"[清理任务] 清理完成, 共停止 {stopped_count}/{len(flows_to_stop)} 个子心流")
# --- 新增兴趣评估工作函数 ---
async def _perform_into_focus_work(self):
"""执行一轮子心流兴趣评估与提升检查。"""
# 直接调用 subheartflow_manager 的方法,并传递当前状态信息
await self.subheartflow_manager.sbhf_absent_into_focus()
await self.subheartflow_manager.sbhf_normal_into_focus()
async def _run_state_update_cycle(self, interval: int):
await _run_periodic_loop(task_name="State Update", interval=interval, task_func=self._perform_state_update_work)
async def _run_cleanup_cycle(self):
await _run_periodic_loop(
task_name="Subflow Cleanup", interval=CLEANUP_INTERVAL_SECONDS, task_func=self._perform_cleanup_work

View File

@@ -4,13 +4,13 @@ import enum
class ChatState(enum.Enum):
ABSENT = "没在看群"
CHAT = "随便水群"
NORMAL = "随便水群"
FOCUSED = "认真水群"
class ChatStateInfo:
def __init__(self):
self.chat_status: ChatState = ChatState.CHAT
self.chat_status: ChatState = ChatState.NORMAL
self.current_state_time = 120
self.mood_manager = mood_manager

View File

@@ -1,9 +1,6 @@
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.logger_manager import get_logger
from typing import Any, Optional
from src.tools.tool_use import ToolUser
from src.chat.heart_flow.mai_state_manager import MaiStateInfo, MaiStateManager
from src.chat.heart_flow.subheartflow_manager import SubHeartflowManager
from src.chat.heart_flow.background_tasks import BackgroundTaskManager # Import BackgroundTaskManager

View File

@@ -4,21 +4,10 @@ import random
from typing import List, Tuple, Optional
from src.common.logger_manager import get_logger
from src.manager.mood_manager import mood_manager
from src.config.config import global_config
logger = get_logger("mai_state")
# -- 状态相关的可配置参数 (可以从 glocal_config 加载) --
# The line `enable_unlimited_hfc_chat = False` is setting a configuration parameter that controls
# whether a specific debugging feature is enabled or not. When `enable_unlimited_hfc_chat` is set to
# `False`, it means that the debugging feature for unlimited focused chatting is disabled.
# enable_unlimited_hfc_chat = True # 调试用:无限专注聊天
enable_unlimited_hfc_chat = False
prevent_offline_state = True
# 目前默认不启用OFFLINE状
class MaiState(enum.Enum):
"""
聊天状态:
@@ -97,7 +86,6 @@ class MaiStateManager:
current_time = time.time()
current_status = current_state_info.mai_status
time_in_current_status = current_time - current_state_info.last_status_change_time
_time_since_last_min_check = current_time - current_state_info.last_min_check_time
next_state: Optional[MaiState] = None
def _resolve_offline(candidate_state: MaiState) -> MaiState:
@@ -141,10 +129,6 @@ class MaiStateManager:
)
next_state = resolved_candidate
if enable_unlimited_hfc_chat:
logger.debug("调试用:开挂了,强制切换到专注聊天")
next_state = MaiState.FOCUSED_CHAT
if next_state is not None and next_state != current_status:
return next_state
else:

View File

@@ -57,7 +57,7 @@ class ChattingObservation(Observation):
self.talking_message_str_truncate = ""
self.name = global_config.bot.nickname
self.nick_name = global_config.bot.alias_names
self.max_now_obs_len = global_config.chat.observation_context_size
self.max_now_obs_len = global_config.focus_chat.observation_context_size
self.overlap_len = global_config.focus_chat.compressed_length
self.mid_memories = []
self.max_mid_memory_len = global_config.focus_chat.compress_length_limit

View File

@@ -18,7 +18,7 @@ class HFCloopObservation:
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop: List[CycleDetail] = []
self.action_manager: ActionManager = None
self.all_actions = {}
def get_observe_info(self):

View File

@@ -2,7 +2,7 @@ from .observation.observation import Observation
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
import asyncio
import time
from typing import Optional, List, Dict, Tuple, Callable, Coroutine
from typing import Optional, List, Dict, Tuple
import traceback
from src.common.logger_manager import get_logger
from src.chat.message_receive.message import MessageRecv
@@ -13,6 +13,7 @@ from src.chat.heart_flow.mai_state_manager import MaiStateInfo
from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo
from .utils_chat import get_chat_type_and_target_info
from .interest_chatting import InterestChatting
from src.config.config import global_config
logger = get_logger("sub_heartflow")
@@ -23,7 +24,6 @@ class SubHeartflow:
self,
subheartflow_id,
mai_states: MaiStateInfo,
hfc_no_reply_callback: Callable[[], Coroutine[None, None, None]],
):
"""子心流初始化函数
@@ -35,7 +35,6 @@ class SubHeartflow:
# 基础属性,两个值是一样的
self.subheartflow_id = subheartflow_id
self.chat_id = subheartflow_id
self.hfc_no_reply_callback = hfc_no_reply_callback
# 麦麦的状态
self.mai_states = mai_states
@@ -89,13 +88,13 @@ class SubHeartflow:
await self.interest_chatting.initialize()
logger.debug(f"{self.log_prefix} InterestChatting 实例已初始化。")
# 创建并初始化 normal_chat_instance
chat_stream = chat_manager.get_stream(self.chat_id)
if chat_stream:
self.normal_chat_instance = NormalChat(chat_stream=chat_stream,interest_dict=self.get_interest_dict())
await self.normal_chat_instance.initialize()
await self.normal_chat_instance.start_chat()
logger.info(f"{self.log_prefix} NormalChat 实例已创建并启动。")
# 根据配置决定初始状态
if global_config.chat.chat_mode == "focus":
logger.info(f"{self.log_prefix} 配置为 focus 模式,将直接尝试进入 FOCUSED 状态。")
await self.change_chat_state(ChatState.FOCUSED)
else: # "auto" 或其他模式保持原有逻辑或默认为 NORMAL
logger.info(f"{self.log_prefix} 配置为 auto 或其他模式,将尝试进入 NORMAL 状态。")
await self.change_chat_state(ChatState.NORMAL)
def update_last_chat_state_time(self):
self.chat_state_last_time = time.time() - self.chat_state_changed_time
@@ -128,10 +127,9 @@ class SubHeartflow:
if not chat_stream:
logger.error(f"{log_prefix} 无法获取 chat_stream无法启动 NormalChat。")
return False
if rewind:
# 在 rewind 为 True 或 NormalChat 实例尚未创建时,创建新实例
if rewind or not self.normal_chat_instance:
self.normal_chat_instance = NormalChat(chat_stream=chat_stream, interest_dict=self.get_interest_dict())
else:
self.normal_chat_instance = NormalChat(chat_stream=chat_stream)
# 进行异步初始化
await self.normal_chat_instance.initialize()
@@ -187,9 +185,10 @@ class SubHeartflow:
logger.info(f"{log_prefix} 麦麦准备开始专注聊天...")
try:
# 创建 HeartFChatting 实例,并传递 从构造函数传入的 回调函数
self.heart_fc_instance = HeartFChatting(
chat_id=self.subheartflow_id,
observations=self.observations,
observations=self.observations,
)
# 初始化并启动 HeartFChatting
@@ -216,7 +215,7 @@ class SubHeartflow:
state_changed = False
log_prefix = f"[{self.log_prefix}]"
if new_state == ChatState.CHAT:
if new_state == ChatState.NORMAL:
logger.debug(f"{log_prefix} 准备进入或保持 普通聊天 状态")
if await self._start_normal_chat():
logger.debug(f"{log_prefix} 成功进入或保持 NormalChat 状态。")
@@ -260,20 +259,6 @@ class SubHeartflow:
f"{log_prefix} 尝试将状态从 {current_state.value} 变为 {new_state.value},但未成功或未执行更改。"
)
async def subheartflow_start_working(self):
"""启动子心流的后台任务
功能说明:
- 负责子心流的主要后台循环
- 每30秒检查一次停止标志
"""
logger.trace(f"{self.log_prefix} 子心流开始工作...")
while not self.should_stop:
await asyncio.sleep(30) # 30秒检查一次停止标志
logger.info(f"{self.log_prefix} 子心流后台任务已停止。")
def add_observation(self, observation: Observation):
for existing_obs in self.observations:
if existing_obs.observe_id == observation.observe_id:

View File

@@ -2,13 +2,11 @@ import asyncio
import time
import random
from typing import Dict, Any, Optional, List
import functools
from src.common.logger_manager import get_logger
from src.chat.message_receive.chat_stream import chat_manager
from src.chat.heart_flow.sub_heartflow import SubHeartflow, ChatState
from src.chat.heart_flow.mai_state_manager import MaiStateInfo
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.config.config import global_config
# 初始化日志记录器
@@ -62,7 +60,6 @@ class SubHeartflowManager:
self._lock = asyncio.Lock() # 用于保护 self.subheartflows 的访问
self.mai_state_info: MaiStateInfo = mai_state_info # 存储传入的 MaiStateInfo 实例
async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool:
"""强制改变指定子心流的状态"""
async with self._lock:
@@ -101,35 +98,25 @@ class SubHeartflowManager:
return subflow
try:
# --- 使用 functools.partial 创建 HFC 回调 --- #
# 将 manager 的 _handle_hfc_no_reply 方法与当前的 subheartflow_id 绑定
hfc_callback = functools.partial(self._handle_hfc_no_reply, subheartflow_id)
# --- 结束创建回调 --- #
# 初始化子心流, 传入 mai_state_info 和 partial 创建的回调
# 初始化子心流, 传入 mai_state_info
new_subflow = SubHeartflow(
subheartflow_id,
self.mai_state_info,
hfc_callback, # <-- 传递 partial 创建的回调
)
# 异步初始化
await new_subflow.initialize()
# 添加聊天观察者
# 首先创建并添加聊天观察者
observation = ChattingObservation(chat_id=subheartflow_id)
await observation.initialize()
new_subflow.add_observation(observation)
# 然后再进行异步初始化,此时 SubHeartflow 内部若需启动 HeartFChatting就能拿到 observation
await new_subflow.initialize()
# 注册子心流
self.subheartflows[subheartflow_id] = new_subflow
heartflow_name = chat_manager.get_stream_name(subheartflow_id) or subheartflow_id
logger.info(f"[{heartflow_name}] 开始接收消息")
# 启动后台任务
asyncio.create_task(new_subflow.subheartflow_start_working())
return new_subflow
except Exception as e:
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
@@ -199,22 +186,14 @@ class SubHeartflowManager:
f"{log_prefix} 完成,共处理 {processed_count} 个子心流,成功将 {changed_count} 个非 ABSENT 子心流的状态更改为 ABSENT。"
)
async def sbhf_absent_into_focus(self):
async def sbhf_normal_into_focus(self):
"""评估子心流兴趣度满足条件则提升到FOCUSED状态基于start_hfc_probability"""
try:
current_state = self.mai_state_info.get_current_state()
# 检查是否允许进入 FOCUS 模式
if not global_config.chat.allow_focus_mode:
if int(time.time()) % 60 == 0: # 每60秒输出一次日志避免刷屏
logger.trace("未开启 FOCUSED 状态 (allow_focus_mode=False)")
return
for sub_hf in list(self.subheartflows.values()):
flow_id = sub_hf.subheartflow_id
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
# 跳过非CHAT状态或已经是FOCUSED状态的子心流
# 跳过已经是FOCUSED状态的子心流
if sub_hf.chat_state.chat_status == ChatState.FOCUSED:
continue
@@ -225,13 +204,6 @@ class SubHeartflowManager:
f"{stream_name},现在状态: {sub_hf.chat_state.chat_status.value},进入专注概率: {sub_hf.interest_chatting.start_hfc_probability}"
)
# 调试用
from .mai_state_manager import enable_unlimited_hfc_chat
if not enable_unlimited_hfc_chat:
if sub_hf.chat_state.chat_status != ChatState.CHAT:
continue
if random.random() >= sub_hf.interest_chatting.start_hfc_probability:
continue
@@ -250,12 +222,11 @@ class SubHeartflowManager:
except Exception as e:
logger.error(f"启动HFC 兴趣评估失败: {e}", exc_info=True)
async def sbhf_focus_into_absent_or_chat(self, subflow_id: Any):
async def sbhf_focus_into_normal(self, subflow_id: Any):
"""
接收来自 HeartFChatting 的请求,将特定子心流的状态转换为 CHAT
接收来自 HeartFChatting 的请求,将特定子心流的状态转换为 NORMAL
通常在连续多次 "no_reply" 后被调用。
对于私聊和群聊,都转换为 CHAT
对于私聊和群聊,都转换为 NORMAL
Args:
subflow_id: 需要转换状态的子心流 ID。
@@ -263,15 +234,15 @@ class SubHeartflowManager:
async with self._lock:
subflow = self.subheartflows.get(subflow_id)
if not subflow:
logger.warning(f"[状态转换请求] 尝试转换不存在的子心流 {subflow_id}CHAT")
logger.warning(f"[状态转换请求] 尝试转换不存在的子心流 {subflow_id}NORMAL")
return
stream_name = chat_manager.get_stream_name(subflow_id) or subflow_id
current_state = subflow.chat_state.chat_status
if current_state == ChatState.FOCUSED:
target_state = ChatState.CHAT
log_reason = "转为CHAT"
target_state = ChatState.NORMAL
log_reason = "转为NORMAL"
logger.info(
f"[状态转换请求] 接收到请求,将 {stream_name} (当前: {current_state.value}) 尝试转换为 {target_state.value} ({log_reason})"
@@ -292,34 +263,10 @@ class SubHeartflowManager:
f"[状态转换请求] 转换 {stream_name}{target_state.value} 时出错: {e}", exc_info=True
)
elif current_state == ChatState.ABSENT:
logger.debug(f"[状态转换请求] {stream_name} 处于 ABSENT 状态,尝试转为 CHAT")
await subflow.change_chat_state(ChatState.CHAT)
logger.debug(f"[状态转换请求] {stream_name} 处于 ABSENT 状态,尝试转为 NORMAL")
await subflow.change_chat_state(ChatState.NORMAL)
else:
logger.debug(
f"[状态转换请求] {stream_name} 当前状态为 {current_state.value},无需转换"
)
def count_subflows_by_state(self, state: ChatState) -> int:
"""统计指定状态的子心流数量"""
count = 0
# 遍历所有子心流实例
for subheartflow in self.subheartflows.values():
# 检查子心流状态是否匹配
if subheartflow.chat_state.chat_status == state:
count += 1
return count
def count_subflows_by_state_nolock(self, state: ChatState) -> int:
"""
统计指定状态的子心流数量 (不上锁版本)。
警告:仅应在已持有 self._lock 的上下文中使用此方法。
"""
count = 0
for subheartflow in self.subheartflows.values():
if subheartflow.chat_state.chat_status == state:
count += 1
return count
logger.debug(f"[状态转换请求] {stream_name} 当前状态为 {current_state.value},无需转换")
async def delete_subflow(self, subheartflow_id: Any):
"""删除指定的子心流。"""
@@ -336,28 +283,14 @@ class SubHeartflowManager:
else:
logger.warning(f"尝试删除不存在的 SubHeartflow: {subheartflow_id}")
async def _handle_hfc_no_reply(self, subheartflow_id: Any):
"""处理来自 HeartFChatting 的连续无回复信号 (通过 partial 绑定 ID)"""
# 注意:这里不需要再获取锁,因为 sbhf_focus_into_absent_or_chat 内部会处理锁
logger.debug(f"[管理器 HFC 处理器] 接收到来自 {subheartflow_id} 的 HFC 无回复信号")
await self.sbhf_focus_into_absent_or_chat(subheartflow_id)
# --- 新增:处理私聊从 ABSENT 直接到 FOCUSED 的逻辑 --- #
async def sbhf_absent_private_into_focus(self):
"""检查 ABSENT 状态的私聊子心流是否有新活动,若有且未达 FOCUSED 上限,则直接转换为 FOCUSED。"""
"""检查 ABSENT 状态的私聊子心流是否有新活动,若有则直接转换为 FOCUSED。"""
log_prefix_task = "[私聊激活检查]"
transitioned_count = 0
checked_count = 0
# --- 检查是否允许 FOCUS 模式 --- #
if not global_config.chat.allow_focus_mode:
return
async with self._lock:
# --- 获取当前 FOCUSED 计数 (不上锁版本) --- #
current_focused_count = self.count_subflows_by_state_nolock(ChatState.FOCUSED)
# --- 筛选出所有 ABSENT 状态的私聊子心流 --- #
eligible_subflows = [
hf
@@ -372,7 +305,6 @@ class SubHeartflowManager:
# --- 遍历评估每个符合条件的私聊 --- #
for sub_hf in eligible_subflows:
flow_id = sub_hf.subheartflow_id
stream_name = chat_manager.get_stream_name(flow_id) or flow_id
log_prefix = f"[{stream_name}]({log_prefix_task})"
@@ -393,13 +325,12 @@ class SubHeartflowManager:
else:
logger.warning(f"{log_prefix} 无法获取主要观察者来检查活动状态。")
# --- 如果活跃且未达上限,则尝试转换 --- #
# --- 如果活跃,则尝试转换 --- #
if is_active:
await sub_hf.change_chat_state(ChatState.FOCUSED)
# 确认转换成功
if sub_hf.chat_state.chat_status == ChatState.FOCUSED:
transitioned_count += 1
current_focused_count += 1 # 更新计数器以供本轮后续检查
logger.info(f"{log_prefix} 成功进入 FOCUSED 状态。")
else:
logger.warning(

View File

@@ -6,7 +6,7 @@ from .global_logger import logger
from . import prompt_template
from .lpmmconfig import global_config, INVALID_ENTITY
from .llm_client import LLMClient
from .utils.json_fix import fix_broken_generated_json
from .utils.json_fix import new_fix_broken_generated_json
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
@@ -24,7 +24,7 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
if "]" in request_result:
request_result = request_result[: request_result.rindex("]") + 1]
entity_extract_result = json.loads(fix_broken_generated_json(request_result))
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
entity_extract_result = [
entity
@@ -53,7 +53,7 @@ def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -
if "]" in request_result:
request_result = request_result[: request_result.rindex("]") + 1]
entity_extract_result = json.loads(fix_broken_generated_json(request_result))
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
for triple in entity_extract_result:
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:

View File

@@ -121,5 +121,5 @@ class QAManager:
found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n"
return found_knowledge
else:
logger.info("LPMM知识库并未初始化使用旧版数据库进行检索")
logger.info("LPMM知识库并未初始化可能是从未导入过知识...")
return None

View File

@@ -1,4 +1,5 @@
import json
from json_repair import repair_json
def _find_unclosed(json_str):
@@ -74,3 +75,24 @@ def fix_broken_generated_json(json_str: str) -> str:
json_str += closing_map[open_char]
return json_str
def new_fix_broken_generated_json(json_str: str) -> str:
"""
使用 json-repair 库修复格式错误的 JSON 字符串。
如果原始 json_str 字符串可以被 json.loads() 成功加载,则直接返回而不进行任何修改。
参数:
json_str (str): 需要修复的格式错误的 JSON 字符串。
返回:
str: 修复后的 JSON 字符串。
"""
try:
# 尝试加载 JSON 以查看其是否有效
json.loads(json_str)
return json_str # 如果有效则按原样返回
except json.JSONDecodeError:
# 如果无效,则尝试修复它
return repair_json(json_str)

View File

@@ -11,7 +11,6 @@ import jieba
import networkx as nx
import numpy as np
from collections import Counter
from ...common.database.database import memory_db as db
from ...chat.models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器

View File

@@ -7,7 +7,7 @@ from src.chat.message_receive.chat_stream import chat_manager
from src.chat.message_receive.message import MessageRecv
from src.experimental.only_message_process import MessageProcessor
from src.experimental.PFC.pfc_manager import PFCManager
from src.chat.focus_chat.heartflow_processor import HeartFCProcessor
from src.chat.focus_chat.heartflow_message_revceiver import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.config.config import global_config
@@ -23,7 +23,7 @@ class ChatBot:
self.bot = None # bot 实例引用
self._started = False
self.mood_manager = mood_manager # 获取情绪管理器单例
self.heartflow_processor = HeartFCProcessor() # 新增
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
# 创建初始化PFC管理器的任务会在_ensure_started时执行
self.only_process_chat = MessageProcessor()
@@ -111,11 +111,11 @@ class ChatBot:
# 禁止PFC进入普通的心流消息处理逻辑
else:
logger.trace("进入普通心流私聊处理")
await self.heartflow_processor.process_message(message_data)
await self.heartflow_message_receiver.process_message(message_data)
# 群聊默认进入心流消息处理逻辑
else:
logger.trace(f"检测到群聊消息群ID: {group_info.group_id}")
await self.heartflow_processor.process_message(message_data)
await self.heartflow_message_receiver.process_message(message_data)
if template_group_name:
async with global_prompt_manager.async_message_scope(template_group_name):

View File

@@ -27,7 +27,7 @@ logger = get_logger("normal_chat")
class NormalChat:
def __init__(self, chat_stream: ChatStream, interest_dict: dict = {}):
def __init__(self, chat_stream: ChatStream, interest_dict: dict = None):
"""初始化 NormalChat 实例。只进行同步操作。"""
# Basic info from chat_stream (sync)
@@ -39,10 +39,8 @@ class NormalChat:
# Interest dict
self.interest_dict = interest_dict
# --- Initialize attributes (defaults) ---
self.is_group_chat: bool = False
self.chat_target_info: Optional[dict] = None
# --- End Initialization ---
# Other sync initializations
self.gpt = NormalChatGenerator()
@@ -52,9 +50,6 @@ class NormalChat:
self._chat_task: Optional[asyncio.Task] = None
self._initialized = False # Track initialization status
# logger.info(f"[{self.stream_name}] NormalChat 实例 __init__ 完成 (同步部分)。")
# Avoid logging here as stream_name might not be final
async def initialize(self):
"""异步初始化,获取聊天类型和目标信息。"""
if self._initialized:
@@ -464,10 +459,11 @@ class NormalChat:
await self.initialize() # Ensure initialized before starting tasks
if self._chat_task is None or self._chat_task.done():
logger.info(f"[{self.stream_name}] 开始后台处理初始兴趣消息和轮询任务...")
logger.info(f"[{self.stream_name}] 开始回顾消息...")
# Process initial messages first
await self._process_initial_interest_messages()
# Then start polling task
logger.info(f"[{self.stream_name}] 开始处理兴趣消息...")
polling_task = asyncio.create_task(self._reply_interested_message())
polling_task.add_done_callback(lambda t: self._handle_task_completion(t))
self._chat_task = polling_task

View File

@@ -49,7 +49,7 @@ class ClassicalWillingManager(BaseWillingManager):
# 检查群组权限(如果是群聊)
if (
willing_info.group_info
and willing_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups
and willing_info.group_info.group_id in global_config.normal_chat.talk_frequency_down_groups
):
reply_probability = reply_probability / global_config.normal_chat.down_frequency_rate

View File

@@ -180,7 +180,7 @@ class MxpWillingManager(BaseWillingManager):
if w_info.is_emoji:
probability *= global_config.normal_chat.emoji_response_penalty
if w_info.group_info and w_info.group_info.group_id in global_config.chat_target.talk_frequency_down_groups:
if w_info.group_info and w_info.group_info.group_id in global_config.normal_chat.talk_frequency_down_groups:
probability /= global_config.normal_chat.down_frequency_rate
self.temporary_willing = current_willing

View File

@@ -9,7 +9,7 @@ import asyncio
import numpy as np
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
from src.individuality.individuality import Individuality
from src.individuality.individuality import individuality
import matplotlib
@@ -257,7 +257,6 @@ class PersonInfoManager:
current_name_set = set(self.person_name_list.values())
while current_try < max_retries:
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=1)
bot_name = individuality.personality.bot_nickname

View File

@@ -127,7 +127,7 @@ class InfoCatcher:
Messages.select()
.where((Messages.chat_id == chat_id_val) & (Messages.message_id < message_id_val))
.order_by(Messages.time.desc())
.limit(global_config.chat.observation_context_size * 3)
.limit(global_config.focus_chat.observation_context_size * 3)
)
return list(messages_before_query)

View File

@@ -6,16 +6,13 @@ from collections import Counter
import jieba
import numpy as np
from maim_message import UserInfo
from pymongo.errors import PyMongoError
from src.common.logger import get_module_logger
from src.manager.mood_manager import mood_manager
from ..message_receive.message import MessageRecv
from ..models.utils_model import LLMRequest
from .typo_generator import ChineseTypoGenerator
from ...common.database.database import db
from ...config.config import global_config
from ...common.database.database_model import Messages
from ...common.message_repository import find_messages, count_messages
logger = get_module_logger("chat_utils")
@@ -112,11 +109,7 @@ async def get_embedding(text, request_type="embedding"):
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
filter_query = {"chat_id": chat_stream_id}
sort_order = [("time", -1)]
recent_messages = find_messages(
message_filter=filter_query,
sort=sort_order,
limit=limit
)
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
if not recent_messages:
return []
@@ -141,23 +134,21 @@ def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> li
# 获取当前群聊记录内发言的人
filter_query = {"chat_id": chat_stream_id}
sort_order = [("time", -1)]
recent_messages = find_messages(
message_filter=filter_query,
sort=sort_order,
limit=limit
)
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
if not recent_messages:
return []
who_chat_in_group = []
for msg_db_data in recent_messages:
user_info = UserInfo.from_dict({
"platform": msg_db_data["user_platform"],
"user_id": msg_db_data["user_id"],
"user_nickname": msg_db_data["user_nickname"],
"user_cardname": msg_db_data.get("user_cardname", "")
})
user_info = UserInfo.from_dict(
{
"platform": msg_db_data["user_platform"],
"user_id": msg_db_data["user_id"],
"user_nickname": msg_db_data["user_nickname"],
"user_cardname": msg_db_data.get("user_cardname", ""),
}
)
if (
(user_info.platform, user_info.user_id) != sender
and user_info.user_id != global_config.bot.qq_account
@@ -324,7 +315,7 @@ def process_llm_response(text: str) -> list[str]:
else:
protected_text = text
kaomoji_mapping = {}
# 提取被 () 或 [] 包裹且包含中文的内容
# 提取被 () 或 [] 包裹且包含中文的内容
pattern = re.compile(r"[(\[](?=.*[一-鿿]).*?[)\]]")
# _extracted_contents = pattern.findall(text)
_extracted_contents = pattern.findall(protected_text) # 在保护后的文本上查找
@@ -579,14 +570,13 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
# 使用message_repository中的count_messages和find_messages函数
# 构建查询条件
filter_query = {"chat_id": stream_id, "time": {"$gt": start_time, "$lte": end_time}}
try:
# 先获取消息数量
count = count_messages(filter_query)
# 获取消息内容计算总长度
messages = find_messages(message_filter=filter_query)
total_length = sum(len(msg.get("processed_plain_text", "")) for msg in messages)

View File

@@ -1,312 +0,0 @@
import os
import sys
import requests
from dotenv import load_dotenv
import hashlib
from datetime import datetime
from tqdm import tqdm
from rich.console import Console
from rich.table import Table
from rich.traceback import install
install(extra_lines=3)
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 现在可以导入src模块
from common.database.database import db # noqa E402
# 加载根目录下的env.edv文件
env_path = os.path.join(root_path, ".env")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
class KnowledgeLibrary:
def __init__(self):
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
self.console = Console()
def _ensure_dirs(self):
"""确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True)
@staticmethod
def read_file(file_path: str) -> str:
"""读取文件内容"""
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
@staticmethod
def split_content(content: str, max_length: int = 512) -> list:
"""将内容分割成适当大小的块,按空行分割
Args:
content: 要分割的文本内容
max_length: 每个块的最大长度
Returns:
list: 分割后的文本块列表
"""
# 按空行分割内容
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
chunks = []
for para in paragraphs:
para_length = len(para)
# 如果段落长度小于等于最大长度,直接添加
if para_length <= max_length:
chunks.append(para)
else:
# 如果段落超过最大长度,则按最大长度切分
for i in range(0, para_length, max_length):
chunks.append(para[i : i + max_length])
return chunks
def get_embedding(self, text: str) -> list:
"""获取文本的embedding向量"""
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {"model": "BAAI/bge-m3", "input": text, "encoding_format": "float"}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
print(f"获取embedding失败: {response.text}")
return None
return response.json()["data"][0]["embedding"]
def process_files(self, knowledge_length: int = 512):
"""处理raw_info目录下的所有txt文件"""
txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith(".txt")]
if not txt_files:
self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir))
self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]")
return
total_stats = {"processed_files": 0, "total_chunks": 0, "failed_files": [], "skipped_files": []}
self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]")
for filename in tqdm(txt_files, desc="处理文件进度"):
file_path = os.path.join(self.raw_info_dir, filename)
result = self.process_single_file(file_path, knowledge_length)
self._update_stats(total_stats, result, filename)
self._display_processing_results(total_stats)
def process_single_file(self, file_path: str, knowledge_length: int = 512):
"""处理单个文件"""
result = {"status": "success", "chunks_processed": 0, "error": None}
try:
current_hash = self.calculate_file_hash(file_path)
processed_record = db.processed_files.find_one({"file_path": file_path})
if processed_record:
if processed_record.get("hash") == current_hash:
if knowledge_length in processed_record.get("split_by", []):
result["status"] = "skipped"
return result
content = self.read_file(file_path)
chunks = self.split_content(content, knowledge_length)
for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False):
embedding = self.get_embedding(chunk)
if embedding:
knowledge = {
"content": chunk,
"embedding": embedding,
"source_file": file_path,
"split_length": knowledge_length,
"created_at": datetime.now(),
}
db.knowledges.insert_one(knowledge)
result["chunks_processed"] += 1
split_by = processed_record.get("split_by", []) if processed_record else []
if knowledge_length not in split_by:
split_by.append(knowledge_length)
db.knowledges.processed_files.update_one(
{"file_path": file_path},
{"$set": {"hash": current_hash, "last_processed": datetime.now(), "split_by": split_by}},
upsert=True,
)
except Exception as e:
result["status"] = "failed"
result["error"] = str(e)
return result
@staticmethod
def _update_stats(total_stats, result, filename):
"""更新总体统计信息"""
if result["status"] == "success":
total_stats["processed_files"] += 1
total_stats["total_chunks"] += result["chunks_processed"]
elif result["status"] == "failed":
total_stats["failed_files"].append((filename, result["error"]))
elif result["status"] == "skipped":
total_stats["skipped_files"].append(filename)
def _display_processing_results(self, stats):
"""显示处理结果统计"""
self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]")
table = Table(show_header=True, header_style="bold magenta")
table.add_column("统计项", style="dim")
table.add_column("数值")
table.add_row("成功处理文件数", str(stats["processed_files"]))
table.add_row("处理的知识块总数", str(stats["total_chunks"]))
table.add_row("跳过的文件数", str(len(stats["skipped_files"])))
table.add_row("失败的文件数", str(len(stats["failed_files"])))
self.console.print(table)
if stats["failed_files"]:
self.console.print("\n[bold red]处理失败的文件:[/bold red]")
for filename, error in stats["failed_files"]:
self.console.print(f"[red]- {filename}: {error}[/red]")
if stats["skipped_files"]:
self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]")
for filename in stats["skipped_files"]:
self.console.print(f"[yellow]- {filename}[/yellow]")
@staticmethod
def calculate_file_hash(file_path):
"""计算文件的MD5哈希值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def search_similar_segments(self, query: str, limit: int = 5) -> list:
"""搜索与查询文本相似的片段"""
query_embedding = self.get_embedding(query)
if not query_embedding:
return []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1, "file_path": 1}},
]
results = list(db.knowledges.aggregate(pipeline))
return results
# 创建单例实例
knowledge_library = KnowledgeLibrary()
if __name__ == "__main__":
console = Console()
console.print("[bold green]知识库处理工具[/bold green]")
while True:
console.print("\n请选择要执行的操作:")
console.print("[1] 麦麦开始学习")
console.print("[2] 麦麦全部忘光光(仅知识)")
console.print("[q] 退出程序")
choice = input("\n请输入选项: ").strip()
if choice.lower() == "q":
console.print("[yellow]程序退出[/yellow]")
sys.exit(0)
elif choice == "2":
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
if confirm == "y":
db.knowledges.delete_many({})
console.print("[green]已清空所有知识![/green]")
continue
elif choice == "1":
if not os.path.exists(knowledge_library.raw_info_dir):
console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]")
os.makedirs(knowledge_library.raw_info_dir, exist_ok=True)
# 询问分割长度
while True:
try:
length_input = input("请输入知识分割长度默认512输入q退出回车使用默认值: ").strip()
if length_input.lower() == "q":
break
if not length_input: # 如果直接回车,使用默认值
knowledge_length = 512
break
knowledge_length = int(length_input)
if knowledge_length <= 0:
print("分割长度必须大于0请重新输入")
continue
break
except ValueError:
print("请输入有效的数字")
continue
if length_input.lower() == "q":
continue
# 测试知识库功能
print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...")
knowledge_library.process_files(knowledge_length=knowledge_length)
else:
console.print("[red]无效的选项,请重新选择[/red]")
continue