This commit is contained in:
SengokuCola
2025-07-20 14:24:49 +08:00
34 changed files with 913 additions and 442 deletions

View File

@@ -179,8 +179,7 @@ class HeartFChatting:
await asyncio.sleep(10)
if self.loop_mode == ChatMode.NORMAL:
self.energy_value -= 0.3
if self.energy_value <= 0.3:
self.energy_value = 0.3
self.energy_value = max(self.energy_value, 0.3)
def print_cycle_info(self, cycle_timers):
# 记录循环信息和计时器结果
@@ -257,6 +256,7 @@ class HeartFChatting:
return f"{person_name}:{message_data.get('processed_plain_text')}"
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
# sourcery skip: hoist-statement-from-if, merge-comparisons, reintroduce-else
if not message_data:
message_data = {}
action_type = "no_action"
@@ -462,7 +462,7 @@ class HeartFChatting:
"兴趣"模式下,判断是否回复并生成内容。
"""
interested_rate = message_data.get("interest_value", 0.0) * self.willing_amplifier
interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
self.willing_manager.setup(message_data, self.chat_stream)

View File

@@ -106,10 +106,10 @@ class EmbeddingStore:
asyncio.get_running_loop()
# 如果在事件循环中,使用线程池执行
import concurrent.futures
def run_in_thread():
return asyncio.run(get_embedding(s))
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
result = future.result()
@@ -294,10 +294,10 @@ class EmbeddingStore:
"""
if self.faiss_index is None:
logger.debug("FaissIndex尚未构建,返回None")
return None
return []
if self.idx2hash is None:
logger.warning("idx2hash尚未构建,返回None")
return None
return []
# L2归一化
faiss.normalize_L2(np.array([query], dtype=np.float32))
@@ -318,15 +318,15 @@ class EmbeddingStore:
class EmbeddingManager:
def __init__(self):
self.paragraphs_embedding_store = EmbeddingStore(
local_storage['pg_namespace'],
local_storage["pg_namespace"], # type: ignore
EMBEDDING_DATA_DIR_STR,
)
self.entities_embedding_store = EmbeddingStore(
local_storage['pg_namespace'],
local_storage["pg_namespace"], # type: ignore
EMBEDDING_DATA_DIR_STR,
)
self.relation_embedding_store = EmbeddingStore(
local_storage['pg_namespace'],
local_storage["pg_namespace"], # type: ignore
EMBEDDING_DATA_DIR_STR,
)
self.stored_pg_hashes = set()

View File

@@ -30,20 +30,20 @@ def _get_kg_dir():
"""
安全地获取KG数据目录路径
"""
root_path = local_storage['root_path']
root_path: str = local_storage["root_path"]
if root_path is None:
# 如果 local_storage 中没有 root_path使用当前文件的相对路径作为备用
current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
logger.warning(f"local_storage 中未找到 root_path使用备用路径: {root_path}")
# 获取RAG数据目录
rag_data_dir = global_config["persistence"]["rag_data_dir"]
rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
if rag_data_dir is None:
kg_dir = os.path.join(root_path, "data/rag")
else:
kg_dir = os.path.join(root_path, rag_data_dir)
return str(kg_dir).replace("\\", "/")
@@ -65,9 +65,9 @@ class KGManager:
# 持久化相关 - 使用延迟初始化的路径
self.dir_path = get_kg_dir_str()
self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json"
self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json"
def save_to_file(self):
"""将KG数据保存到文件"""
@@ -91,11 +91,11 @@ class KGManager:
"""从文件加载KG数据"""
# 确保文件存在
if not os.path.exists(self.pg_hash_file_path):
raise Exception(f"KG段落hash文件{self.pg_hash_file_path}不存在")
raise FileNotFoundError(f"KG段落hash文件{self.pg_hash_file_path}不存在")
if not os.path.exists(self.ent_cnt_data_path):
raise Exception(f"KG实体计数文件{self.ent_cnt_data_path}不存在")
raise FileNotFoundError(f"KG实体计数文件{self.ent_cnt_data_path}不存在")
if not os.path.exists(self.graph_data_path):
raise Exception(f"KG图文件{self.graph_data_path}不存在")
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
# 加载段落hash
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
@@ -122,8 +122,8 @@ class KGManager:
# 避免自连接
continue
# 一个triple就是一条边同时构建双向联系
hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2])
hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1)
@@ -141,8 +141,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系"""
for idx in triple_list_data:
for triple in triple_list_data[idx]:
ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx)
ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod
@@ -157,8 +157,8 @@ class KGManager:
ent_hash_list = set()
for triple_list in triple_list_data.values():
for triple in triple_list:
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0]))
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2]))
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0]))
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list)
synonym_hash_set = set()
@@ -263,7 +263,7 @@ class KGManager:
for src_tgt in node_to_node.keys():
for node_hash in src_tgt:
if node_hash not in existed_nodes:
if node_hash.startswith(local_storage['ent_namespace']):
if node_hash.startswith(local_storage["ent_namespace"]):
# 新增实体节点
node = embedding_manager.entities_embedding_store.store.get(node_hash)
if node is None:
@@ -275,7 +275,7 @@ class KGManager:
node_item["type"] = "ent"
node_item["create_time"] = now_time
self.graph.update_node(node_item)
elif node_hash.startswith(local_storage['pg_namespace']):
elif node_hash.startswith(local_storage["pg_namespace"]):
# 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
if node is None:
@@ -359,7 +359,7 @@ class KGManager:
# 关系三元组
triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]:
ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent)
ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # 该实体需在KG中存在
if ent_hash not in ent_sim_scores: # 尚未记录的实体
ent_sim_scores[ent_hash] = []
@@ -437,7 +437,9 @@ class KGManager:
# 获取最终结果
# 从搜索结果中提取文段节点的结果
passage_node_res = [
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace'])
(node_key, score)
for node_key, score in ppr_res.items()
if node_key.startswith(local_storage["pg_namespace"])
]
del ppr_res

View File

@@ -33,6 +33,7 @@ RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
DATA_PATH = os.path.join(ROOT_PATH, "data")
def _initialize_knowledge_local_storage():
"""
初始化知识库相关的本地存储配置
@@ -41,55 +42,58 @@ def _initialize_knowledge_local_storage():
# 定义所有需要初始化的配置项
default_configs = {
# 路径配置
'root_path': ROOT_PATH,
'data_path': f"{ROOT_PATH}/data",
"root_path": ROOT_PATH,
"data_path": f"{ROOT_PATH}/data",
# 实体和命名空间配置
'lpmm_invalid_entity': INVALID_ENTITY,
'pg_namespace': PG_NAMESPACE,
'ent_namespace': ENT_NAMESPACE,
'rel_namespace': REL_NAMESPACE,
"lpmm_invalid_entity": INVALID_ENTITY,
"pg_namespace": PG_NAMESPACE,
"ent_namespace": ENT_NAMESPACE,
"rel_namespace": REL_NAMESPACE,
# RAG相关命名空间配置
'rag_graph_namespace': RAG_GRAPH_NAMESPACE,
'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE,
'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE
"rag_graph_namespace": RAG_GRAPH_NAMESPACE,
"rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
"rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
}
# 日志级别映射重要配置用info其他用debug
important_configs = {'root_path', 'data_path'}
important_configs = {"root_path", "data_path"}
# 批量设置配置项
initialized_count = 0
for key, default_value in default_configs.items():
if local_storage[key] is None:
local_storage[key] = default_value
# 根据重要性选择日志级别
if key in important_configs:
logger.info(f"设置{key}: {default_value}")
else:
logger.debug(f"设置{key}: {default_value}")
initialized_count += 1
if initialized_count > 0:
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
else:
logger.debug("知识库本地存储配置已存在,跳过初始化")
# 初始化本地存储路径
# sourcery skip: dict-comprehension
_initialize_knowledge_local_storage()
qa_manager = None
inspire_manager = None
# 检查LPMM知识库是否启用
if bot_global_config.lpmm_knowledge.enable:
logger.info("正在初始化Mai-LPMM")
logger.info("创建LLM客户端")
llm_client_list = dict()
llm_client_list = {}
for key in global_config["llm_providers"]:
llm_client_list[key] = LLMClient(
global_config["llm_providers"][key]["base_url"],
global_config["llm_providers"][key]["api_key"],
global_config["llm_providers"][key]["base_url"], # type: ignore
global_config["llm_providers"][key]["api_key"], # type: ignore
)
# 初始化Embedding库
@@ -98,7 +102,7 @@ if bot_global_config.lpmm_knowledge.enable:
try:
embed_manager.load_from_file()
except Exception as e:
logger.warning("此消息不会影响正常使用从文件加载Embedding库时{}".format(e))
logger.warning(f"此消息不会影响正常使用从文件加载Embedding库时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("Embedding库加载完成")
# 初始化KG
@@ -107,7 +111,7 @@ if bot_global_config.lpmm_knowledge.enable:
try:
kg_manager.load_from_file()
except Exception as e:
logger.warning("此消息不会影响正常使用从文件加载KG时{}".format(e))
logger.warning(f"此消息不会影响正常使用从文件加载KG时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("KG加载完成")
@@ -116,7 +120,7 @@ if bot_global_config.lpmm_knowledge.enable:
# 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes:
key = PG_NAMESPACE + "-" + pg_hash
key = f"{PG_NAMESPACE}-{pg_hash}"
if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}")
@@ -134,5 +138,3 @@ if bot_global_config.lpmm_knowledge.enable:
else:
logger.info("LPMM知识库已禁用跳过初始化")
# 创建空的占位符对象,避免导入错误
qa_manager = None
inspire_manager = None

View File

@@ -1,5 +1,3 @@
from .llm_client import LLMMessage
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体并以JSON列表的形式输出。
输出格式示例:
@@ -63,10 +61,10 @@ qa_system_prompt = """
"""
def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
messages = [
LLMMessage("system", qa_system_prompt).to_dict(),
LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(),
]
return messages
# def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
# knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
# messages = [
# LLMMessage("system", qa_system_prompt).to_dict(),
# LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息\n{knowledge}").to_dict(),
# ]
# return messages

View File

@@ -9,6 +9,7 @@ from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_voice import get_voice_text
from .chat_stream import ChatStream
install(extra_lines=3)
@@ -106,6 +107,7 @@ class MessageRecv(Message):
self.has_emoji = False
self.is_picid = False
self.has_picid = False
self.is_voice = False
self.is_mentioned = None
self.is_command = False
@@ -153,17 +155,27 @@ class MessageRecv(Message):
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
self.is_voice = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False
self.is_emoji = False
self.is_voice = True
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
@@ -212,10 +224,12 @@ class MessageRecvS4U(MessageRecv):
"""
try:
if segment.type == "text":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
return segment.data # type: ignore
elif segment.type == "image":
self.is_voice = False
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
@@ -233,12 +247,22 @@ class MessageRecvS4U(MessageRecv):
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.has_picid = False
self.is_picid = False
self.is_emoji = False
self.is_voice = True
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_voice = False
self.is_picid = False
self.is_emoji = False
if isinstance(segment.data, dict):
@@ -253,6 +277,7 @@ class MessageRecvS4U(MessageRecv):
"""
return ""
elif segment.type == "gift":
self.is_voice = False
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1) # type: ignore
@@ -343,6 +368,10 @@ class MessageProcessBase(Message):
if isinstance(seg.data, str):
return await get_image_manager().get_emoji_description(seg.data)
return "[表情,网卡了加载不出来]"
elif seg.type == "voice":
if isinstance(seg.data, str):
return await get_voice_text(seg.data)
return "[发了一段语音,网卡了加载不出来]"
elif seg.type == "at":
return f"[@{seg.data}]"
elif seg.type == "reply":
@@ -455,25 +484,25 @@ class MessageSending(MessageProcessBase):
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
@classmethod
def from_thinking(
cls,
thinking: MessageThinking,
message_segment: Seg,
is_head: bool = False,
is_emoji: bool = False,
) -> "MessageSending":
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id, # type: ignore
chat_stream=thinking.chat_stream,
message_segment=message_segment,
bot_user_info=thinking.message_info.user_info, # type: ignore
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji,
sender_info=None,
)
# @classmethod
# def from_thinking(
# cls,
# thinking: MessageThinking,
# message_segment: Seg,
# is_head: bool = False,
# is_emoji: bool = False,
# ) -> "MessageSending":
# """从思考状态消息创建发送状态消息"""
# return cls(
# message_id=thinking.message_info.message_id, # type: ignore
# chat_stream=thinking.chat_stream,
# message_segment=message_segment,
# bot_user_info=thinking.message_info.user_info, # type: ignore
# reply=thinking.reply,
# is_head=is_head,
# is_emoji=is_emoji,
# sender_info=None,
# )
def to_dict(self):
ret = super().to_dict()

View File

@@ -262,4 +262,4 @@ class ActionManager:
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_class(action_name) # type: ignore
return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore

View File

@@ -0,0 +1,35 @@
import base64
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_voice")
async def get_voice_text(voice_base64: str) -> str:
"""获取音频文件描述"""
if not global_config.chat.enable_asr:
logger.warning("语音识别未启用,无法处理语音消息")
return "[语音]"
try:
# 解码base64音频数据
# 确保base64字符串只包含ASCII字符
if isinstance(voice_base64, str):
voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
voice_bytes = base64.b64decode(voice_base64)
_llm = LLMRequest(model=global_config.model.voice, request_type="voice")
text = await _llm.generate_response_for_voice(voice_bytes)
if text is None:
logger.warning("未能生成语音文本")
return "[语音(文本生成失败)]"
logger.debug(f"描述是{text}")
return f"[语音:{text}]"
except Exception as e:
logger.error(f"语音转文字失败: {str(e)}")
return "[语音]"

View File

@@ -21,6 +21,7 @@ class ClassicalWillingManager(BaseWillingManager):
self._decay_task = asyncio.create_task(self._decay_reply_willing())
async def get_reply_probability(self, message_id):
# sourcery skip: inline-immediately-returned-variable
willing_info = self.ongoing_messages[message_id]
chat_id = willing_info.chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)

View File

@@ -25,6 +25,8 @@ import asyncio
import time
import math
from src.chat.message_receive.chat_stream import ChatStream
class MxpWillingManager(BaseWillingManager):
"""Mxp意愿管理器"""
@@ -76,7 +78,7 @@ class MxpWillingManager(BaseWillingManager):
self.chat_bot_message_time[w_info.chat_id].append(current_time)
if len(self.chat_bot_message_time[w_info.chat_id]) == int(self.fatigue_messages_triggered_num):
time_interval = 60 - (current_time - self.chat_bot_message_time[w_info.chat_id].pop(0))
self.chat_fatigue_punishment_list[w_info.chat_id].append([current_time, time_interval * 2])
self.chat_fatigue_punishment_list[w_info.chat_id].append((current_time, time_interval * 2))
async def after_generate_reply_handle(self, message_id: str):
"""回复后处理"""
@@ -87,12 +89,14 @@ class MxpWillingManager(BaseWillingManager):
# rel_level = self._get_relationship_level_num(rel_value)
# self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += rel_level * 0.05
now_chat_new_person = self.last_response_person.get(w_info.chat_id, [w_info.person_id, 0])
now_chat_new_person = self.last_response_person.get(w_info.chat_id, (w_info.person_id, 0))
if now_chat_new_person[0] == w_info.person_id:
if now_chat_new_person[1] < 3:
now_chat_new_person[1] += 1
tmp_list = list(now_chat_new_person)
tmp_list[1] += 1 # type: ignore
self.last_response_person[w_info.chat_id] = tuple(tmp_list) # type: ignore
else:
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0]
self.last_response_person[w_info.chat_id] = (w_info.person_id, 0)
async def not_reply_handle(self, message_id: str):
"""不回复处理"""
@@ -108,11 +112,12 @@ class MxpWillingManager(BaseWillingManager):
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * (
2 * self.last_response_person[w_info.chat_id][1] - 1
)
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0])
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ("", 0))
if now_chat_new_person[0] != w_info.person_id:
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0]
self.last_response_person[w_info.chat_id] = (w_info.person_id, 0)
async def get_reply_probability(self, message_id: str):
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
"""获取回复概率"""
async with self.lock:
w_info = self.ongoing_messages[message_id]
@@ -121,17 +126,16 @@ class MxpWillingManager(BaseWillingManager):
self.logger.debug(f"基础意愿值:{current_willing}")
if w_info.is_mentioned_bot:
current_willing_ = self.mention_willing_gain / (int(current_willing) + 1)
current_willing += current_willing_
willing_gain = self.mention_willing_gain / (int(current_willing) + 1)
current_willing += willing_gain
if self.is_debug:
self.logger.debug(f"提及增益:{current_willing_}")
self.logger.debug(f"提及增益:{willing_gain}")
if w_info.interested_rate > 0:
current_willing += math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain
willing_gain = math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain
current_willing += willing_gain
if self.is_debug:
self.logger.debug(
f"兴趣增益:{math.atan(w_info.interested_rate / 2) / math.pi * 2 * self.interest_willing_gain}"
)
self.logger.debug(f"兴趣增益:{willing_gain}")
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] = current_willing
@@ -152,8 +156,8 @@ class MxpWillingManager(BaseWillingManager):
self.logger.debug(f"疲劳衰减:{self.chat_fatigue_willing_attenuation.get(w_info.chat_id, 0)}")
chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id]
chat_person_ogoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id]
if len(chat_person_ogoing_messages) >= 2:
chat_person_ongoing_messages = [msg for msg in chat_ongoing_messages if msg.person_id == w_info.person_id]
if len(chat_person_ongoing_messages) >= 2:
current_willing = 0
if self.is_debug:
self.logger.debug("进行中消息惩罚归0")
@@ -191,34 +195,33 @@ class MxpWillingManager(BaseWillingManager):
basic_willing + (willing - basic_willing) * self.intention_decay_rate
)
def setup(self, message, chat, is_mentioned_bot, interested_rate):
super().setup(message, chat, is_mentioned_bot, interested_rate)
self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get(
chat.stream_id, self.basic_maximum_willing
)
self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {})
self.chat_person_reply_willing[chat.stream_id][
self.ongoing_messages[message.message_info.message_id].person_id
] = self.chat_person_reply_willing[chat.stream_id].get(
self.ongoing_messages[message.message_info.message_id].person_id, self.chat_reply_willing[chat.stream_id]
def setup(self, message: dict, chat_stream: ChatStream):
super().setup(message, chat_stream)
stream_id = chat_stream.stream_id
self.chat_reply_willing[stream_id] = self.chat_reply_willing.get(stream_id, self.basic_maximum_willing)
self.chat_person_reply_willing[stream_id] = self.chat_person_reply_willing.get(stream_id, {})
self.chat_person_reply_willing[stream_id][self.ongoing_messages[message.get("message_id", "")].person_id] = (
self.chat_person_reply_willing[stream_id].get(
self.ongoing_messages[message.get("message_id", "")].person_id,
self.chat_reply_willing[stream_id],
)
)
current_time = time.time()
if chat.stream_id not in self.chat_new_message_time:
self.chat_new_message_time[chat.stream_id] = []
self.chat_new_message_time[chat.stream_id].append(current_time)
if len(self.chat_new_message_time[chat.stream_id]) > self.number_of_message_storage:
self.chat_new_message_time[chat.stream_id].pop(0)
if stream_id not in self.chat_new_message_time:
self.chat_new_message_time[stream_id] = []
self.chat_new_message_time[stream_id].append(current_time)
if len(self.chat_new_message_time[stream_id]) > self.number_of_message_storage:
self.chat_new_message_time[stream_id].pop(0)
if chat.stream_id not in self.chat_fatigue_punishment_list:
self.chat_fatigue_punishment_list[chat.stream_id] = [
if stream_id not in self.chat_fatigue_punishment_list:
self.chat_fatigue_punishment_list[stream_id] = [
(
current_time,
self.number_of_message_storage * self.basic_maximum_willing / self.expected_replies_per_min * 60,
)
]
self.chat_fatigue_willing_attenuation[chat.stream_id] = (
self.chat_fatigue_willing_attenuation[stream_id] = (
-2 * self.basic_maximum_willing * self.fatigue_coefficient
)
@@ -227,12 +230,11 @@ class MxpWillingManager(BaseWillingManager):
"""意愿值转化为概率"""
willing = max(0, willing)
if willing < 2:
probability = math.atan(willing * 2) / math.pi * 2
return math.atan(willing * 2) / math.pi * 2
elif willing < 2.5:
probability = math.atan(willing * 4) / math.pi * 2
return math.atan(willing * 4) / math.pi * 2
else:
probability = 1
return probability
return 1
async def _chat_new_message_to_change_basic_willing(self):
"""聊天流新消息改变基础意愿"""
@@ -259,7 +261,7 @@ class MxpWillingManager(BaseWillingManager):
update_time = 20
elif len(message_times) == self.number_of_message_storage:
time_interval = current_time - message_times[0]
basic_willing = self._basic_willing_culculate(time_interval)
basic_willing = self._basic_willing_calculate(time_interval)
self.chat_reply_willing[chat_id] = basic_willing
update_time = 17 * basic_willing / self.basic_maximum_willing + 3
else:
@@ -268,7 +270,7 @@ class MxpWillingManager(BaseWillingManager):
if self.is_debug:
self.logger.debug(f"聊天流意愿值更新:{self.chat_reply_willing}")
def _basic_willing_culculate(self, t: float) -> float:
def _basic_willing_calculate(self, t: float) -> float:
"""基础意愿值计算"""
return math.tan(t * self.expected_replies_per_min * math.pi / 120 / self.number_of_message_storage) / 2

View File

@@ -104,7 +104,7 @@ class BaseWillingManager(ABC):
is_mentioned_bot=message.get("is_mentioned", False),
is_emoji=message.get("is_emoji", False),
is_picid=message.get("is_picid", False),
interested_rate=message.get("interest_value", 0),
interested_rate = message.get("interest_value") or 0.0,
)
def delete(self, message_id: str):