style: 格式化代码,修复不一致的空格和注释,更新ruff action
This commit is contained in:
5
.github/workflows/ruff.yml
vendored
5
.github/workflows/ruff.yml
vendored
@@ -12,7 +12,10 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
ref: ${{ github.head_ref || github.ref_name }}
|
ref: ${{ github.head_ref || github.ref_name }}
|
||||||
- uses: astral-sh/ruff-action@v3
|
- name: Install the latest version of ruff
|
||||||
|
uses: astral-sh/ruff-action@v3
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
- run: ruff check --fix
|
- run: ruff check --fix
|
||||||
- run: ruff format
|
- run: ruff format
|
||||||
- name: Commit changes
|
- name: Commit changes
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class ChattingObservation(Observation):
|
|||||||
msg_str = ""
|
msg_str = ""
|
||||||
for msg in mid_memory_by_id["messages"]:
|
for msg in mid_memory_by_id["messages"]:
|
||||||
msg_str += f"{msg['detailed_plain_text']}"
|
msg_str += f"{msg['detailed_plain_text']}"
|
||||||
time_diff = int((datetime.now().timestamp() - mid_memory_by_id["created_at"]) / 60)
|
# time_diff = int((datetime.now().timestamp() - mid_memory_by_id["created_at"]) / 60)
|
||||||
# mid_memory_str += f"距离现在{time_diff}分钟前:\n{msg_str}\n"
|
# mid_memory_str += f"距离现在{time_diff}分钟前:\n{msg_str}\n"
|
||||||
mid_memory_str += f"{msg_str}\n"
|
mid_memory_str += f"{msg_str}\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -210,8 +210,10 @@ class SubHeartflow:
|
|||||||
relation_prompt_all = (await global_prompt_manager.get_prompt_async("relationship_prompt")).format(
|
relation_prompt_all = (await global_prompt_manager.get_prompt_async("relationship_prompt")).format(
|
||||||
relation_prompt, sender_info.user_nickname
|
relation_prompt, sender_info.user_nickname
|
||||||
)
|
)
|
||||||
|
|
||||||
sender_name_sign = f"<{chat_stream.platform}:{sender_info.user_id}:{sender_info.user_nickname}:{sender_info.user_cardname}>"
|
sender_name_sign = (
|
||||||
|
f"<{chat_stream.platform}:{sender_info.user_id}:{sender_info.user_nickname}:{sender_info.user_cardname}>"
|
||||||
|
)
|
||||||
|
|
||||||
# prompt = ""
|
# prompt = ""
|
||||||
# # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
|
# # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
|
||||||
@@ -230,7 +232,7 @@ class SubHeartflow:
|
|||||||
# prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。"
|
# prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。"
|
||||||
|
|
||||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||||
|
|
||||||
prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_before")).format(
|
prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_before")).format(
|
||||||
extra_info_prompt,
|
extra_info_prompt,
|
||||||
# prompt_schedule,
|
# prompt_schedule,
|
||||||
@@ -244,7 +246,7 @@ class SubHeartflow:
|
|||||||
message_txt,
|
message_txt,
|
||||||
self.bot_name,
|
self.bot_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = await relationship_manager.convert_all_person_sign_to_person_name(prompt)
|
prompt = await relationship_manager.convert_all_person_sign_to_person_name(prompt)
|
||||||
prompt = parse_text_timestamps(prompt, mode="lite")
|
prompt = parse_text_timestamps(prompt, mode="lite")
|
||||||
|
|
||||||
@@ -294,7 +296,7 @@ class SubHeartflow:
|
|||||||
|
|
||||||
message_new_info = chat_talking_prompt
|
message_new_info = chat_talking_prompt
|
||||||
reply_info = reply_content
|
reply_info = reply_content
|
||||||
|
|
||||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||||
|
|
||||||
prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_after")).format(
|
prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_after")).format(
|
||||||
@@ -307,7 +309,7 @@ class SubHeartflow:
|
|||||||
reply_info,
|
reply_info,
|
||||||
mood_info,
|
mood_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = await relationship_manager.convert_all_person_sign_to_person_name(prompt)
|
prompt = await relationship_manager.convert_all_person_sign_to_person_name(prompt)
|
||||||
prompt = parse_text_timestamps(prompt, mode="lite")
|
prompt = parse_text_timestamps(prompt, mode="lite")
|
||||||
|
|
||||||
|
|||||||
@@ -150,9 +150,7 @@ class MessageRecv(Message):
|
|||||||
# if user_info.user_cardname != None
|
# if user_info.user_cardname != None
|
||||||
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
|
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
|
||||||
# )
|
# )
|
||||||
name = (
|
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
||||||
f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
|
||||||
)
|
|
||||||
return f"[{time}] {name}: {self.processed_plain_text}\n"
|
return f"[{time}] {name}: {self.processed_plain_text}\n"
|
||||||
|
|
||||||
|
|
||||||
@@ -251,9 +249,7 @@ class MessageProcessBase(Message):
|
|||||||
# if user_info.user_cardname != None
|
# if user_info.user_cardname != None
|
||||||
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
|
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
|
||||||
# )
|
# )
|
||||||
name = (
|
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
||||||
f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
|
|
||||||
)
|
|
||||||
return f"[{time}] {name}: {self.processed_plain_text}\n"
|
return f"[{time}] {name}: {self.processed_plain_text}\n"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -643,11 +643,11 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
|
|||||||
|
|
||||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||||
"""将时间戳转换为人类可读的时间格式
|
"""将时间戳转换为人类可读的时间格式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timestamp: 时间戳
|
timestamp: 时间戳
|
||||||
mode: 转换模式,"normal"为标准格式,"relative"为相对时间格式
|
mode: 转换模式,"normal"为标准格式,"relative"为相对时间格式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 格式化后的时间字符串
|
str: 格式化后的时间字符串
|
||||||
"""
|
"""
|
||||||
@@ -656,7 +656,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
|||||||
elif mode == "relative":
|
elif mode == "relative":
|
||||||
now = time.time()
|
now = time.time()
|
||||||
diff = now - timestamp
|
diff = now - timestamp
|
||||||
|
|
||||||
if diff < 20:
|
if diff < 20:
|
||||||
return "刚刚:"
|
return "刚刚:"
|
||||||
elif diff < 60:
|
elif diff < 60:
|
||||||
@@ -671,33 +671,34 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
|||||||
return f"{int(diff / 86400)}天前:\n"
|
return f"{int(diff / 86400)}天前:\n"
|
||||||
else:
|
else:
|
||||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
|
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) + ":"
|
||||||
|
|
||||||
|
|
||||||
def parse_text_timestamps(text: str, mode: str = "normal") -> str:
|
def parse_text_timestamps(text: str, mode: str = "normal") -> str:
|
||||||
"""解析文本中的时间戳并转换为可读时间格式
|
"""解析文本中的时间戳并转换为可读时间格式
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 包含时间戳的文本,时间戳应以[]包裹
|
text: 包含时间戳的文本,时间戳应以[]包裹
|
||||||
mode: 转换模式,传递给translate_timestamp_to_human_readable,"normal"或"relative"
|
mode: 转换模式,传递给translate_timestamp_to_human_readable,"normal"或"relative"
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 替换后的文本
|
str: 替换后的文本
|
||||||
|
|
||||||
转换规则:
|
转换规则:
|
||||||
- normal模式: 将文本中所有时间戳转换为可读格式
|
- normal模式: 将文本中所有时间戳转换为可读格式
|
||||||
- lite模式:
|
- lite模式:
|
||||||
- 第一个和最后一个时间戳必须转换
|
- 第一个和最后一个时间戳必须转换
|
||||||
- 以5秒为间隔划分时间段,每段最多转换一个时间戳
|
- 以5秒为间隔划分时间段,每段最多转换一个时间戳
|
||||||
- 不转换的时间戳替换为空字符串
|
- 不转换的时间戳替换为空字符串
|
||||||
"""
|
"""
|
||||||
# 匹配[数字]或[数字.数字]格式的时间戳
|
# 匹配[数字]或[数字.数字]格式的时间戳
|
||||||
pattern = r'\[(\d+(?:\.\d+)?)\]'
|
pattern = r"\[(\d+(?:\.\d+)?)\]"
|
||||||
|
|
||||||
# 找出所有匹配的时间戳
|
# 找出所有匹配的时间戳
|
||||||
matches = list(re.finditer(pattern, text))
|
matches = list(re.finditer(pattern, text))
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
# normal模式: 直接转换所有时间戳
|
# normal模式: 直接转换所有时间戳
|
||||||
if mode == "normal":
|
if mode == "normal":
|
||||||
result_text = text
|
result_text = text
|
||||||
@@ -711,63 +712,63 @@ def parse_text_timestamps(text: str, mode: str = "normal") -> str:
|
|||||||
else:
|
else:
|
||||||
# lite模式: 按5秒间隔划分并选择性转换
|
# lite模式: 按5秒间隔划分并选择性转换
|
||||||
result_text = text
|
result_text = text
|
||||||
|
|
||||||
# 提取所有时间戳及其位置
|
# 提取所有时间戳及其位置
|
||||||
timestamps = [(float(m.group(1)), m) for m in matches]
|
timestamps = [(float(m.group(1)), m) for m in matches]
|
||||||
timestamps.sort(key=lambda x: x[0]) # 按时间戳升序排序
|
timestamps.sort(key=lambda x: x[0]) # 按时间戳升序排序
|
||||||
|
|
||||||
if not timestamps:
|
if not timestamps:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
# 获取第一个和最后一个时间戳
|
# 获取第一个和最后一个时间戳
|
||||||
first_timestamp, first_match = timestamps[0]
|
first_timestamp, first_match = timestamps[0]
|
||||||
last_timestamp, last_match = timestamps[-1]
|
last_timestamp, last_match = timestamps[-1]
|
||||||
|
|
||||||
# 将时间范围划分成5秒间隔的时间段
|
# 将时间范围划分成5秒间隔的时间段
|
||||||
time_segments = {}
|
time_segments = {}
|
||||||
|
|
||||||
# 对所有时间戳按15秒间隔分组
|
# 对所有时间戳按15秒间隔分组
|
||||||
for ts, match in timestamps:
|
for ts, match in timestamps:
|
||||||
segment_key = int(ts // 15) # 将时间戳除以15取整,作为时间段的键
|
segment_key = int(ts // 15) # 将时间戳除以15取整,作为时间段的键
|
||||||
if segment_key not in time_segments:
|
if segment_key not in time_segments:
|
||||||
time_segments[segment_key] = []
|
time_segments[segment_key] = []
|
||||||
time_segments[segment_key].append((ts, match))
|
time_segments[segment_key].append((ts, match))
|
||||||
|
|
||||||
# 记录需要转换的时间戳
|
# 记录需要转换的时间戳
|
||||||
to_convert = []
|
to_convert = []
|
||||||
|
|
||||||
# 从每个时间段中选择一个时间戳进行转换
|
# 从每个时间段中选择一个时间戳进行转换
|
||||||
for segment, segment_timestamps in time_segments.items():
|
for _, segment_timestamps in time_segments.items():
|
||||||
# 选择这个时间段中的第一个时间戳
|
# 选择这个时间段中的第一个时间戳
|
||||||
to_convert.append(segment_timestamps[0])
|
to_convert.append(segment_timestamps[0])
|
||||||
|
|
||||||
# 确保第一个和最后一个时间戳在转换列表中
|
# 确保第一个和最后一个时间戳在转换列表中
|
||||||
first_in_list = False
|
first_in_list = False
|
||||||
last_in_list = False
|
last_in_list = False
|
||||||
|
|
||||||
for ts, match in to_convert:
|
for ts, _ in to_convert:
|
||||||
if ts == first_timestamp:
|
if ts == first_timestamp:
|
||||||
first_in_list = True
|
first_in_list = True
|
||||||
if ts == last_timestamp:
|
if ts == last_timestamp:
|
||||||
last_in_list = True
|
last_in_list = True
|
||||||
|
|
||||||
if not first_in_list:
|
if not first_in_list:
|
||||||
to_convert.append((first_timestamp, first_match))
|
to_convert.append((first_timestamp, first_match))
|
||||||
if not last_in_list:
|
if not last_in_list:
|
||||||
to_convert.append((last_timestamp, last_match))
|
to_convert.append((last_timestamp, last_match))
|
||||||
|
|
||||||
# 创建需要转换的时间戳集合,用于快速查找
|
# 创建需要转换的时间戳集合,用于快速查找
|
||||||
to_convert_set = {match.group(0) for _, match in to_convert}
|
to_convert_set = {match.group(0) for _, match in to_convert}
|
||||||
|
|
||||||
# 首先替换所有不需要转换的时间戳为空字符串
|
# 首先替换所有不需要转换的时间戳为空字符串
|
||||||
for ts, match in timestamps:
|
for _, match in timestamps:
|
||||||
if match.group(0) not in to_convert_set:
|
if match.group(0) not in to_convert_set:
|
||||||
pattern_instance = re.escape(match.group(0))
|
pattern_instance = re.escape(match.group(0))
|
||||||
result_text = re.sub(pattern_instance, "", result_text, count=1)
|
result_text = re.sub(pattern_instance, "", result_text, count=1)
|
||||||
|
|
||||||
# 按照时间戳原始顺序排序,避免替换时位置错误
|
# 按照时间戳原始顺序排序,避免替换时位置错误
|
||||||
to_convert.sort(key=lambda x: x[1].start())
|
to_convert.sort(key=lambda x: x[1].start())
|
||||||
|
|
||||||
# 执行替换
|
# 执行替换
|
||||||
# 由于替换会改变文本长度,从后向前替换
|
# 由于替换会改变文本长度,从后向前替换
|
||||||
to_convert.reverse()
|
to_convert.reverse()
|
||||||
@@ -775,5 +776,5 @@ def parse_text_timestamps(text: str, mode: str = "normal") -> str:
|
|||||||
readable_time = translate_timestamp_to_human_readable(ts, "relative")
|
readable_time = translate_timestamp_to_human_readable(ts, "relative")
|
||||||
pattern_instance = re.escape(match.group(0))
|
pattern_instance = re.escape(match.group(0))
|
||||||
result_text = re.sub(pattern_instance, readable_time, result_text, count=1)
|
result_text = re.sub(pattern_instance, readable_time, result_text, count=1)
|
||||||
|
|
||||||
return result_text
|
return result_text
|
||||||
|
|||||||
@@ -235,7 +235,6 @@ class ThinkFlowChat:
|
|||||||
do_reply = False
|
do_reply = False
|
||||||
if random() < reply_probability:
|
if random() < reply_probability:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
do_reply = True
|
do_reply = True
|
||||||
|
|
||||||
# 回复前处理
|
# 回复前处理
|
||||||
@@ -397,12 +396,11 @@ class ThinkFlowChat:
|
|||||||
|
|
||||||
# 回复后处理
|
# 回复后处理
|
||||||
await willing_manager.after_generate_reply_handle(message.message_info.message_id)
|
await willing_manager.after_generate_reply_handle(message.message_info.message_id)
|
||||||
|
|
||||||
# 处理认识关系
|
# 处理认识关系
|
||||||
try:
|
try:
|
||||||
is_known = await relationship_manager.is_known_some_one(
|
is_known = await relationship_manager.is_known_some_one(
|
||||||
message.message_info.platform,
|
message.message_info.platform, message.message_info.user_info.user_id
|
||||||
message.message_info.user_info.user_id
|
|
||||||
)
|
)
|
||||||
if not is_known:
|
if not is_known:
|
||||||
logger.info(f"首次认识用户: {message.message_info.user_info.user_nickname}")
|
logger.info(f"首次认识用户: {message.message_info.user_info.user_nickname}")
|
||||||
@@ -410,22 +408,23 @@ class ThinkFlowChat:
|
|||||||
message.message_info.platform,
|
message.message_info.platform,
|
||||||
message.message_info.user_info.user_id,
|
message.message_info.user_info.user_id,
|
||||||
message.message_info.user_info.user_nickname,
|
message.message_info.user_info.user_nickname,
|
||||||
message.message_info.user_info.user_cardname or message.message_info.user_info.user_nickname,
|
message.message_info.user_info.user_cardname
|
||||||
""
|
or message.message_info.user_info.user_nickname,
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"已认识用户: {message.message_info.user_info.user_nickname}")
|
logger.debug(f"已认识用户: {message.message_info.user_info.user_nickname}")
|
||||||
if not await relationship_manager.is_qved_name(
|
if not await relationship_manager.is_qved_name(
|
||||||
message.message_info.platform,
|
message.message_info.platform, message.message_info.user_info.user_id
|
||||||
message.message_info.user_info.user_id
|
|
||||||
):
|
):
|
||||||
logger.info(f"更新已认识但未取名的用户: {message.message_info.user_info.user_nickname}")
|
logger.info(f"更新已认识但未取名的用户: {message.message_info.user_info.user_nickname}")
|
||||||
await relationship_manager.first_knowing_some_one(
|
await relationship_manager.first_knowing_some_one(
|
||||||
message.message_info.platform,
|
message.message_info.platform,
|
||||||
message.message_info.user_info.user_id,
|
message.message_info.user_info.user_id,
|
||||||
message.message_info.user_info.user_nickname,
|
message.message_info.user_info.user_nickname,
|
||||||
message.message_info.user_info.user_cardname or message.message_info.user_info.user_nickname,
|
message.message_info.user_info.user_cardname
|
||||||
""
|
or message.message_info.user_info.user_nickname,
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理认识关系失败: {e}")
|
logger.error(f"处理认识关系失败: {e}")
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ class ResponseGenerator:
|
|||||||
# sender_name = f"({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}"
|
# sender_name = f"({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}"
|
||||||
# else:
|
# else:
|
||||||
# sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
# sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||||
|
|
||||||
sender_name = f"<{message.chat_stream.user_info.platform}:{message.chat_stream.user_info.user_id}:{message.chat_stream.user_info.user_nickname}:{message.chat_stream.user_info.user_cardname}>"
|
sender_name = f"<{message.chat_stream.user_info.platform}:{message.chat_stream.user_info.user_id}:{message.chat_stream.user_info.user_nickname}:{message.chat_stream.user_info.user_cardname}>"
|
||||||
|
|
||||||
# 构建prompt
|
# 构建prompt
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from src.heart_flow.heartflow import heartflow
|
|||||||
from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.plugins.person_info.relationship_manager import relationship_manager
|
from src.plugins.person_info.relationship_manager import relationship_manager
|
||||||
from src.plugins.chat.utils import parse_text_timestamps
|
from src.plugins.chat.utils import parse_text_timestamps
|
||||||
|
|
||||||
logger = get_module_logger("prompt")
|
logger = get_module_logger("prompt")
|
||||||
|
|
||||||
|
|
||||||
@@ -161,10 +162,10 @@ class PromptBuilder:
|
|||||||
prompt_ger=prompt_ger,
|
prompt_ger=prompt_ger,
|
||||||
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
|
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = await relationship_manager.convert_all_person_sign_to_person_name(prompt)
|
prompt = await relationship_manager.convert_all_person_sign_to_person_name(prompt)
|
||||||
prompt = parse_text_timestamps(prompt, mode="lite")
|
prompt = parse_text_timestamps(prompt, mode="lite")
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
async def _build_prompt_simple(
|
async def _build_prompt_simple(
|
||||||
|
|||||||
@@ -64,12 +64,9 @@ class PersonInfoManager:
|
|||||||
if "person_info" not in db.list_collection_names():
|
if "person_info" not in db.list_collection_names():
|
||||||
db.create_collection("person_info")
|
db.create_collection("person_info")
|
||||||
db.person_info.create_index("person_id", unique=True)
|
db.person_info.create_index("person_id", unique=True)
|
||||||
|
|
||||||
# 初始化时读取所有person_name
|
# 初始化时读取所有person_name
|
||||||
cursor = db.person_info.find(
|
cursor = db.person_info.find({"person_name": {"$exists": True}}, {"person_id": 1, "person_name": 1, "_id": 0})
|
||||||
{"person_name": {"$exists": True}},
|
|
||||||
{"person_id": 1, "person_name": 1, "_id": 0}
|
|
||||||
)
|
|
||||||
for doc in cursor:
|
for doc in cursor:
|
||||||
if doc.get("person_name"):
|
if doc.get("person_name"):
|
||||||
self.person_name_list[doc["person_id"]] = doc["person_name"]
|
self.person_name_list[doc["person_id"]] = doc["person_name"]
|
||||||
@@ -77,10 +74,10 @@ class PersonInfoManager:
|
|||||||
|
|
||||||
def get_person_id(self, platform: str, user_id: int):
|
def get_person_id(self, platform: str, user_id: int):
|
||||||
"""获取唯一id"""
|
"""获取唯一id"""
|
||||||
#如果platform中存在-,就截取-后面的部分
|
# 如果platform中存在-,就截取-后面的部分
|
||||||
if "-" in platform:
|
if "-" in platform:
|
||||||
platform = platform.split("-")[1]
|
platform = platform.split("-")[1]
|
||||||
|
|
||||||
components = [platform, str(user_id)]
|
components = [platform, str(user_id)]
|
||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
@@ -93,8 +90,7 @@ class PersonInfoManager:
|
|||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def create_person_info(self, person_id: str, data: dict = None):
|
async def create_person_info(self, person_id: str, data: dict = None):
|
||||||
"""创建一个项"""
|
"""创建一个项"""
|
||||||
if not person_id:
|
if not person_id:
|
||||||
@@ -125,7 +121,7 @@ class PersonInfoManager:
|
|||||||
Data[field_name] = value
|
Data[field_name] = value
|
||||||
logger.debug(f"更新时{person_id}不存在,已新建")
|
logger.debug(f"更新时{person_id}不存在,已新建")
|
||||||
await self.create_person_info(person_id, Data)
|
await self.create_person_info(person_id, Data)
|
||||||
|
|
||||||
async def has_one_field(self, person_id: str, field_name: str):
|
async def has_one_field(self, person_id: str, field_name: str):
|
||||||
"""判断是否存在某一个字段"""
|
"""判断是否存在某一个字段"""
|
||||||
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
|
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
|
||||||
@@ -133,36 +129,35 @@ class PersonInfoManager:
|
|||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _extract_json_from_text(self, text: str) -> dict:
|
def _extract_json_from_text(self, text: str) -> dict:
|
||||||
"""从文本中提取JSON数据的高容错方法"""
|
"""从文本中提取JSON数据的高容错方法"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# 尝试直接解析
|
# 尝试直接解析
|
||||||
return json.loads(text)
|
return json.loads(text)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
try:
|
try:
|
||||||
# 尝试找到JSON格式的部分
|
# 尝试找到JSON格式的部分
|
||||||
json_pattern = r'\{[^{}]*\}'
|
json_pattern = r"\{[^{}]*\}"
|
||||||
matches = re.findall(json_pattern, text)
|
matches = re.findall(json_pattern, text)
|
||||||
if matches:
|
if matches:
|
||||||
return json.loads(matches[0])
|
return json.loads(matches[0])
|
||||||
|
|
||||||
# 如果上面都失败了,尝试提取键值对
|
# 如果上面都失败了,尝试提取键值对
|
||||||
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
|
nickname_pattern = r'"nickname"[:\s]+"([^"]+)"'
|
||||||
reason_pattern = r'"reason"[:\s]+"([^"]+)"'
|
reason_pattern = r'"reason"[:\s]+"([^"]+)"'
|
||||||
|
|
||||||
nickname_match = re.search(nickname_pattern, text)
|
nickname_match = re.search(nickname_pattern, text)
|
||||||
reason_match = re.search(reason_pattern, text)
|
reason_match = re.search(reason_pattern, text)
|
||||||
|
|
||||||
if nickname_match:
|
if nickname_match:
|
||||||
return {
|
return {
|
||||||
"nickname": nickname_match.group(1),
|
"nickname": nickname_match.group(1),
|
||||||
"reason": reason_match.group(1) if reason_match else "未提供理由"
|
"reason": reason_match.group(1) if reason_match else "未提供理由",
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"JSON提取失败: {str(e)}")
|
logger.error(f"JSON提取失败: {str(e)}")
|
||||||
|
|
||||||
# 如果所有方法都失败了,返回空结果
|
# 如果所有方法都失败了,返回空结果
|
||||||
return {"nickname": "", "reason": ""}
|
return {"nickname": "", "reason": ""}
|
||||||
|
|
||||||
@@ -171,10 +166,10 @@ class PersonInfoManager:
|
|||||||
if not person_id:
|
if not person_id:
|
||||||
logger.debug("取名失败:person_id不能为空")
|
logger.debug("取名失败:person_id不能为空")
|
||||||
return
|
return
|
||||||
|
|
||||||
old_name = await self.get_value(person_id, "person_name")
|
old_name = await self.get_value(person_id, "person_name")
|
||||||
old_reason = await self.get_value(person_id, "name_reason")
|
old_reason = await self.get_value(person_id, "name_reason")
|
||||||
|
|
||||||
max_retries = 5 # 最大重试次数
|
max_retries = 5 # 最大重试次数
|
||||||
current_try = 0
|
current_try = 0
|
||||||
existing_names = ""
|
existing_names = ""
|
||||||
@@ -182,7 +177,7 @@ class PersonInfoManager:
|
|||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
|
prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
|
||||||
bot_name = individuality.personality.bot_nickname
|
bot_name = individuality.personality.bot_nickname
|
||||||
|
|
||||||
qv_name_prompt = f"你是{bot_name},你{prompt_personality}"
|
qv_name_prompt = f"你是{bot_name},你{prompt_personality}"
|
||||||
qv_name_prompt += f"现在你想给一个用户取一个昵称,用户是的qq昵称是{user_nickname},"
|
qv_name_prompt += f"现在你想给一个用户取一个昵称,用户是的qq昵称是{user_nickname},"
|
||||||
qv_name_prompt += f"用户的qq群昵称名是{user_cardname},"
|
qv_name_prompt += f"用户的qq群昵称名是{user_cardname},"
|
||||||
@@ -195,20 +190,20 @@ class PersonInfoManager:
|
|||||||
if existing_names:
|
if existing_names:
|
||||||
qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}。\n"
|
qv_name_prompt += f"\n请注意,以下名称已被使用,不要使用以下昵称:{existing_names}。\n"
|
||||||
qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:"
|
qv_name_prompt += "请用json给出你的想法,并给出理由,示例如下:"
|
||||||
qv_name_prompt += '''{
|
qv_name_prompt += """{
|
||||||
"nickname": "昵称",
|
"nickname": "昵称",
|
||||||
"reason": "理由"
|
"reason": "理由"
|
||||||
}'''
|
}"""
|
||||||
logger.debug(f"取名提示词:{qv_name_prompt}")
|
logger.debug(f"取名提示词:{qv_name_prompt}")
|
||||||
response = await self.qv_name_llm.generate_response(qv_name_prompt)
|
response = await self.qv_name_llm.generate_response(qv_name_prompt)
|
||||||
logger.debug(f"取名回复:{response}")
|
logger.debug(f"取名回复:{response}")
|
||||||
result = self._extract_json_from_text(response[0])
|
result = self._extract_json_from_text(response[0])
|
||||||
|
|
||||||
if not result["nickname"]:
|
if not result["nickname"]:
|
||||||
logger.error("生成的昵称为空,重试中...")
|
logger.error("生成的昵称为空,重试中...")
|
||||||
current_try += 1
|
current_try += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查生成的昵称是否已存在
|
# 检查生成的昵称是否已存在
|
||||||
if result["nickname"] not in self.person_name_list.values():
|
if result["nickname"] not in self.person_name_list.values():
|
||||||
# 更新数据库和内存中的列表
|
# 更新数据库和内存中的列表
|
||||||
@@ -216,16 +211,16 @@ class PersonInfoManager:
|
|||||||
# await self.update_one_field(person_id, "nickname", user_nickname)
|
# await self.update_one_field(person_id, "nickname", user_nickname)
|
||||||
# await self.update_one_field(person_id, "avatar", user_avatar)
|
# await self.update_one_field(person_id, "avatar", user_avatar)
|
||||||
await self.update_one_field(person_id, "name_reason", result["reason"])
|
await self.update_one_field(person_id, "name_reason", result["reason"])
|
||||||
|
|
||||||
self.person_name_list[person_id] = result["nickname"]
|
self.person_name_list[person_id] = result["nickname"]
|
||||||
logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}")
|
logger.debug(f"用户 {person_id} 的名称已更新为 {result['nickname']},原因:{result['reason']}")
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
existing_names += f"{result['nickname']}、"
|
existing_names += f"{result['nickname']}、"
|
||||||
|
|
||||||
logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...")
|
logger.debug(f"生成的昵称 {result['nickname']} 已存在,重试中...")
|
||||||
current_try += 1
|
current_try += 1
|
||||||
|
|
||||||
logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称")
|
logger.error(f"在{max_retries}次尝试后仍未能生成唯一昵称")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -76,13 +76,13 @@ class RelationshipManager:
|
|||||||
return mood_value * coefficient
|
return mood_value * coefficient
|
||||||
else:
|
else:
|
||||||
return mood_value / coefficient
|
return mood_value / coefficient
|
||||||
|
|
||||||
async def is_known_some_one(self, platform , user_id):
|
async def is_known_some_one(self, platform, user_id):
|
||||||
"""判断是否认识某人"""
|
"""判断是否认识某人"""
|
||||||
is_known = person_info_manager.is_person_known(platform, user_id)
|
is_known = person_info_manager.is_person_known(platform, user_id)
|
||||||
return is_known
|
return is_known
|
||||||
|
|
||||||
async def is_qved_name(self, platform , user_id):
|
async def is_qved_name(self, platform, user_id):
|
||||||
"""判断是否认识某人"""
|
"""判断是否认识某人"""
|
||||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||||
is_qved = await person_info_manager.has_one_field(person_id, "person_name")
|
is_qved = await person_info_manager.has_one_field(person_id, "person_name")
|
||||||
@@ -93,42 +93,41 @@ class RelationshipManager:
|
|||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def first_knowing_some_one(self, platform , user_id, user_nickname, user_cardname, user_avatar):
|
async def first_knowing_some_one(self, platform, user_id, user_nickname, user_cardname, user_avatar):
|
||||||
"""判断是否认识某人"""
|
"""判断是否认识某人"""
|
||||||
person_id = person_info_manager.get_person_id(platform,user_id)
|
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||||
await person_info_manager.update_one_field(person_id, "nickname", user_nickname)
|
await person_info_manager.update_one_field(person_id, "nickname", user_nickname)
|
||||||
# await person_info_manager.update_one_field(person_id, "user_cardname", user_cardname)
|
# await person_info_manager.update_one_field(person_id, "user_cardname", user_cardname)
|
||||||
# await person_info_manager.update_one_field(person_id, "user_avatar", user_avatar)
|
# await person_info_manager.update_one_field(person_id, "user_avatar", user_avatar)
|
||||||
await person_info_manager.qv_person_name(person_id, user_nickname, user_cardname, user_avatar)
|
await person_info_manager.qv_person_name(person_id, user_nickname, user_cardname, user_avatar)
|
||||||
|
|
||||||
async def convert_all_person_sign_to_person_name(self,input_text:str):
|
async def convert_all_person_sign_to_person_name(self, input_text: str):
|
||||||
"""将所有人的<platform:user_id:nickname:cardname>格式转换为person_name"""
|
"""将所有人的<platform:user_id:nickname:cardname>格式转换为person_name"""
|
||||||
try:
|
try:
|
||||||
# 使用正则表达式匹配<platform:user_id:nickname:cardname>格式
|
# 使用正则表达式匹配<platform:user_id:nickname:cardname>格式
|
||||||
all_person = person_info_manager.person_name_list
|
all_person = person_info_manager.person_name_list
|
||||||
|
|
||||||
pattern = r'<([^:]+):(\d+):([^:]+):([^>]+)>'
|
pattern = r"<([^:]+):(\d+):([^:]+):([^>]+)>"
|
||||||
matches = re.findall(pattern, input_text)
|
matches = re.findall(pattern, input_text)
|
||||||
|
|
||||||
# 遍历匹配结果,将<platform:user_id:nickname:cardname>替换为person_name
|
# 遍历匹配结果,将<platform:user_id:nickname:cardname>替换为person_name
|
||||||
result_text = input_text
|
result_text = input_text
|
||||||
for platform, user_id, nickname, cardname in matches:
|
for platform, user_id, nickname, cardname in matches:
|
||||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
person_id = person_info_manager.get_person_id(platform, user_id)
|
||||||
# 默认使用昵称作为人名
|
# 默认使用昵称作为人名
|
||||||
person_name = nickname.strip() if nickname.strip() else cardname.strip()
|
person_name = nickname.strip() if nickname.strip() else cardname.strip()
|
||||||
|
|
||||||
if person_id in all_person:
|
if person_id in all_person:
|
||||||
if all_person[person_id] != None:
|
if all_person[person_id] != None:
|
||||||
person_name = all_person[person_id]
|
person_name = all_person[person_id]
|
||||||
|
|
||||||
print(f"将<{platform}:{user_id}:{nickname}:{cardname}>替换为{person_name}")
|
print(f"将<{platform}:{user_id}:{nickname}:{cardname}>替换为{person_name}")
|
||||||
|
|
||||||
|
|
||||||
result_text = result_text.replace(f"<{platform}:{user_id}:{nickname}:{cardname}>", person_name)
|
result_text = result_text.replace(f"<{platform}:{user_id}:{nickname}:{cardname}>", person_name)
|
||||||
|
|
||||||
return result_text
|
return result_text
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return input_text
|
return input_text
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user