复活吧!我的关系系统,新增正反馈机制
This commit is contained in:
@@ -1,12 +1,9 @@
|
|||||||
import asyncio
|
|
||||||
from typing import Optional
|
|
||||||
from src.common.logger import get_module_logger, LogConfig, RELATION_STYLE_CONFIG
|
from src.common.logger import get_module_logger, LogConfig, RELATION_STYLE_CONFIG
|
||||||
|
|
||||||
from ...common.database import db
|
|
||||||
from ..message.message_base import UserInfo
|
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
import math
|
import math
|
||||||
from bson.decimal128 import Decimal128
|
from bson.decimal128 import Decimal128
|
||||||
|
from .person_info import person_info_manager
|
||||||
|
import time
|
||||||
|
|
||||||
relationship_config = LogConfig(
|
relationship_config = LogConfig(
|
||||||
# 使用关系专用样式
|
# 使用关系专用样式
|
||||||
@@ -15,265 +12,61 @@ relationship_config = LogConfig(
|
|||||||
)
|
)
|
||||||
logger = get_module_logger("rel_manager", config=relationship_config)
|
logger = get_module_logger("rel_manager", config=relationship_config)
|
||||||
|
|
||||||
|
|
||||||
class Impression:
|
|
||||||
traits: str = None
|
|
||||||
called: str = None
|
|
||||||
know_time: float = None
|
|
||||||
|
|
||||||
relationship_value: float = None
|
|
||||||
|
|
||||||
|
|
||||||
class Relationship:
|
|
||||||
user_id: int = None
|
|
||||||
platform: str = None
|
|
||||||
gender: str = None
|
|
||||||
age: int = None
|
|
||||||
nickname: str = None
|
|
||||||
relationship_value: float = None
|
|
||||||
saved = False
|
|
||||||
|
|
||||||
def __init__(self, chat: ChatStream = None, data: dict = None):
|
|
||||||
self.user_id = chat.user_info.user_id if chat else data.get("user_id", 0)
|
|
||||||
self.platform = chat.platform if chat else data.get("platform", "")
|
|
||||||
self.nickname = chat.user_info.user_nickname if chat else data.get("nickname", "")
|
|
||||||
self.relationship_value = data.get("relationship_value", 0) if data else 0
|
|
||||||
self.age = data.get("age", 0) if data else 0
|
|
||||||
self.gender = data.get("gender", "") if data else ""
|
|
||||||
|
|
||||||
|
|
||||||
class RelationshipManager:
|
class RelationshipManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
self.positive_feedback_dict = {} # 正反馈系统
|
||||||
|
|
||||||
async def update_relationship(self, chat_stream: ChatStream, data: dict = None, **kwargs) -> Optional[Relationship]:
|
def positive_feedback_sys(self, person_id, value, label: str, stance: str):
|
||||||
"""更新或创建关系
|
"""正反馈系统"""
|
||||||
Args:
|
|
||||||
chat_stream: 聊天流对象
|
positive_list = [
|
||||||
data: 字典格式的数据(可选)
|
"开心",
|
||||||
**kwargs: 其他参数
|
"惊讶",
|
||||||
Returns:
|
"害羞",
|
||||||
Relationship: 关系对象
|
"困惑",
|
||||||
"""
|
]
|
||||||
# 确定user_id和platform
|
|
||||||
if chat_stream.user_info is not None:
|
negative_list = [
|
||||||
user_id = chat_stream.user_info.user_id
|
"愤怒",
|
||||||
platform = chat_stream.user_info.platform or "qq"
|
"悲伤",
|
||||||
|
"恐惧",
|
||||||
|
"厌恶",
|
||||||
|
]
|
||||||
|
|
||||||
|
if person_id not in self.positive_feedback_dict:
|
||||||
|
self.positive_feedback_dict[person_id] = 0
|
||||||
|
|
||||||
|
if label in positive_list and stance != "反对":
|
||||||
|
if 6 > self.positive_feedback_dict[person_id] >= 0:
|
||||||
|
self.positive_feedback_dict[person_id] += 1
|
||||||
|
elif self.positive_feedback_dict[person_id] < 0:
|
||||||
|
self.positive_feedback_dict[person_id] = 0
|
||||||
|
return value
|
||||||
|
elif label in negative_list and stance != "支持":
|
||||||
|
if -6 < self.positive_feedback_dict[person_id] <= 0:
|
||||||
|
self.positive_feedback_dict[person_id] -= 1
|
||||||
|
elif self.positive_feedback_dict[person_id] > 0:
|
||||||
|
self.positive_feedback_dict[person_id] = 0
|
||||||
|
return value
|
||||||
else:
|
else:
|
||||||
platform = platform or "qq"
|
return value
|
||||||
|
|
||||||
if user_id is None:
|
gain_coefficient = [1.1, 1.2, 1.4, 1.7, 1.9, 2.0]
|
||||||
raise ValueError("必须提供user_id或user_info")
|
value *= gain_coefficient[abs(self.positive_feedback_dict[person_id])-1]
|
||||||
|
logger.info(f"触发增益,当前增益系数:{gain_coefficient[abs(self.positive_feedback_dict[person_id])-1]}")
|
||||||
|
|
||||||
# 使用(user_id, platform)作为键
|
return value
|
||||||
key = (user_id, platform)
|
|
||||||
|
|
||||||
# 检查是否在内存中已存在
|
|
||||||
relationship = self.relationships.get(key)
|
|
||||||
if relationship:
|
|
||||||
# 如果存在,更新现有对象
|
|
||||||
if isinstance(data, dict):
|
|
||||||
for k, value in data.items():
|
|
||||||
if hasattr(relationship, k) and value is not None:
|
|
||||||
setattr(relationship, k, value)
|
|
||||||
else:
|
|
||||||
# 如果不存在,创建新对象
|
|
||||||
if chat_stream.user_info is not None:
|
|
||||||
relationship = Relationship(chat=chat_stream, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError("必须提供user_id或user_info")
|
|
||||||
self.relationships[key] = relationship
|
|
||||||
|
|
||||||
# 保存到数据库
|
|
||||||
await self.storage_relationship(relationship)
|
|
||||||
relationship.saved = True
|
|
||||||
|
|
||||||
return relationship
|
|
||||||
|
|
||||||
async def update_relationship_value(self, chat_stream: ChatStream, **kwargs) -> Optional[Relationship]:
|
|
||||||
"""更新关系值
|
|
||||||
Args:
|
|
||||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
|
||||||
platform: 平台(可选,如果提供user_info则不需要)
|
|
||||||
user_info: 用户信息对象(可选)
|
|
||||||
**kwargs: 其他参数
|
|
||||||
Returns:
|
|
||||||
Relationship: 关系对象
|
|
||||||
"""
|
|
||||||
# 确定user_id和platform
|
|
||||||
user_info = chat_stream.user_info
|
|
||||||
if user_info is not None:
|
|
||||||
user_id = user_info.user_id
|
|
||||||
platform = user_info.platform or "qq"
|
|
||||||
else:
|
|
||||||
platform = platform or "qq"
|
|
||||||
|
|
||||||
if user_id is None:
|
|
||||||
raise ValueError("必须提供user_id或user_info")
|
|
||||||
|
|
||||||
# 使用(user_id, platform)作为键
|
|
||||||
key = (user_id, platform)
|
|
||||||
|
|
||||||
# 检查是否在内存中已存在
|
|
||||||
relationship = self.relationships.get(key)
|
|
||||||
if relationship:
|
|
||||||
for k, value in kwargs.items():
|
|
||||||
if k == "relationship_value":
|
|
||||||
# 检查relationship.relationship_value是否为double类型
|
|
||||||
if not isinstance(relationship.relationship_value, float):
|
|
||||||
try:
|
|
||||||
# 处理 Decimal128 类型
|
|
||||||
if isinstance(relationship.relationship_value, Decimal128):
|
|
||||||
relationship.relationship_value = float(relationship.relationship_value.to_decimal())
|
|
||||||
else:
|
|
||||||
relationship.relationship_value = float(relationship.relationship_value)
|
|
||||||
logger.info(
|
|
||||||
f"[关系管理] 用户 {user_id}({platform}) 的关系值已转换为double类型: {relationship.relationship_value}"
|
|
||||||
) # noqa: E501
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
# 如果不能解析/强转则将relationship.relationship_value设置为double类型的0
|
|
||||||
relationship.relationship_value = 0.0
|
|
||||||
logger.warning(f"[关系管理] 用户 {user_id}({platform}) 的无法转换为double类型,已设置为0")
|
|
||||||
relationship.relationship_value += value
|
|
||||||
await self.storage_relationship(relationship)
|
|
||||||
relationship.saved = True
|
|
||||||
return relationship
|
|
||||||
else:
|
|
||||||
# 如果不存在且提供了user_info,则创建新的关系
|
|
||||||
if user_info is not None:
|
|
||||||
return await self.update_relationship(chat_stream=chat_stream, **kwargs)
|
|
||||||
logger.warning(f"[关系管理] 用户 {user_id}({platform}) 不存在,无法更新")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_relationship(self, chat_stream: ChatStream) -> Optional[Relationship]:
|
|
||||||
"""获取用户关系对象
|
|
||||||
Args:
|
|
||||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
|
||||||
platform: 平台(可选,如果提供user_info则不需要)
|
|
||||||
user_info: 用户信息对象(可选)
|
|
||||||
Returns:
|
|
||||||
Relationship: 关系对象
|
|
||||||
"""
|
|
||||||
# 确定user_id和platform
|
|
||||||
user_info = chat_stream.user_info
|
|
||||||
platform = chat_stream.user_info.platform or "qq"
|
|
||||||
if user_info is not None:
|
|
||||||
user_id = user_info.user_id
|
|
||||||
platform = user_info.platform or "qq"
|
|
||||||
else:
|
|
||||||
platform = platform or "qq"
|
|
||||||
|
|
||||||
if user_id is None:
|
|
||||||
raise ValueError("必须提供user_id或user_info")
|
|
||||||
|
|
||||||
key = (user_id, platform)
|
|
||||||
if key in self.relationships:
|
|
||||||
return self.relationships[key]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def load_relationship(self, data: dict) -> Relationship:
|
|
||||||
"""从数据库加载或创建新的关系对象"""
|
|
||||||
# 确保data中有platform字段,如果没有则默认为'qq'
|
|
||||||
if "platform" not in data:
|
|
||||||
data["platform"] = "qq"
|
|
||||||
|
|
||||||
rela = Relationship(data=data)
|
|
||||||
rela.saved = True
|
|
||||||
key = (rela.user_id, rela.platform)
|
|
||||||
self.relationships[key] = rela
|
|
||||||
return rela
|
|
||||||
|
|
||||||
async def load_all_relationships(self):
|
|
||||||
"""加载所有关系对象"""
|
|
||||||
all_relationships = db.relationships.find({})
|
|
||||||
for data in all_relationships:
|
|
||||||
await self.load_relationship(data)
|
|
||||||
|
|
||||||
async def _start_relationship_manager(self):
|
|
||||||
"""每5分钟自动保存一次关系数据"""
|
|
||||||
# 获取所有关系记录
|
|
||||||
all_relationships = db.relationships.find({})
|
|
||||||
# 依次加载每条记录
|
|
||||||
for data in all_relationships:
|
|
||||||
await self.load_relationship(data)
|
|
||||||
logger.debug(f"[关系管理] 已加载 {len(self.relationships)} 条关系记录")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
logger.debug("正在自动保存关系")
|
|
||||||
await asyncio.sleep(300) # 等待300秒(5分钟)
|
|
||||||
await self._save_all_relationships()
|
|
||||||
|
|
||||||
async def _save_all_relationships(self):
|
|
||||||
"""将所有关系数据保存到数据库"""
|
|
||||||
# 保存所有关系数据
|
|
||||||
for _, relationship in self.relationships.items():
|
|
||||||
if not relationship.saved:
|
|
||||||
relationship.saved = True
|
|
||||||
await self.storage_relationship(relationship)
|
|
||||||
|
|
||||||
async def storage_relationship(self, relationship: Relationship):
|
|
||||||
"""将关系记录存储到数据库中"""
|
|
||||||
user_id = relationship.user_id
|
|
||||||
platform = relationship.platform
|
|
||||||
nickname = relationship.nickname
|
|
||||||
relationship_value = relationship.relationship_value
|
|
||||||
gender = relationship.gender
|
|
||||||
age = relationship.age
|
|
||||||
saved = relationship.saved
|
|
||||||
|
|
||||||
db.relationships.update_one(
|
|
||||||
{"user_id": user_id, "platform": platform},
|
|
||||||
{
|
|
||||||
"$set": {
|
|
||||||
"platform": platform,
|
|
||||||
"nickname": nickname,
|
|
||||||
"relationship_value": relationship_value,
|
|
||||||
"gender": gender,
|
|
||||||
"age": age,
|
|
||||||
"saved": saved,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
upsert=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_name(self, user_id: int = None, platform: str = None, user_info: UserInfo = None) -> str:
|
|
||||||
"""获取用户昵称
|
|
||||||
Args:
|
|
||||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
|
||||||
platform: 平台(可选,如果提供user_info则不需要)
|
|
||||||
user_info: 用户信息对象(可选)
|
|
||||||
Returns:
|
|
||||||
str: 用户昵称
|
|
||||||
"""
|
|
||||||
# 确定user_id和platform
|
|
||||||
if user_info is not None:
|
|
||||||
user_id = user_info.user_id
|
|
||||||
platform = user_info.platform or "qq"
|
|
||||||
else:
|
|
||||||
platform = platform or "qq"
|
|
||||||
|
|
||||||
if user_id is None:
|
|
||||||
raise ValueError("必须提供user_id或user_info")
|
|
||||||
|
|
||||||
# 确保user_id是整数类型
|
|
||||||
user_id = int(user_id)
|
|
||||||
key = (user_id, platform)
|
|
||||||
if key in self.relationships:
|
|
||||||
return self.relationships[key].nickname
|
|
||||||
elif user_info is not None:
|
|
||||||
return user_info.user_nickname or user_info.user_cardname or "某人"
|
|
||||||
else:
|
|
||||||
return "某人"
|
|
||||||
|
|
||||||
async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> None:
|
async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> None:
|
||||||
"""计算变更关系值
|
"""计算并变更关系值
|
||||||
新的关系值变更计算方式:
|
新的关系值变更计算方式:
|
||||||
将关系值限定在-1000到1000
|
将关系值限定在-1000到1000
|
||||||
对于关系值的变更,期望:
|
对于关系值的变更,期望:
|
||||||
1.向两端逼近时会逐渐减缓
|
1.向两端逼近时会逐渐减缓
|
||||||
2.关系越差,改善越难,关系越好,恶化越容易
|
2.关系越差,改善越难,关系越好,恶化越容易
|
||||||
3.人维护关系的精力往往有限,所以当高关系值用户越多,对于中高关系值用户增长越慢
|
3.人维护关系的精力往往有限,所以当高关系值用户越多,对于中高关系值用户增长越慢
|
||||||
|
4.连续正面或负面情感会正反馈
|
||||||
"""
|
"""
|
||||||
stancedict = {
|
stancedict = {
|
||||||
"支持": 0,
|
"支持": 0,
|
||||||
@@ -283,19 +76,25 @@ class RelationshipManager:
|
|||||||
|
|
||||||
valuedict = {
|
valuedict = {
|
||||||
"开心": 1.5,
|
"开心": 1.5,
|
||||||
"愤怒": -3.5,
|
"愤怒": -2.0,
|
||||||
"悲伤": -1.5,
|
"悲伤": -0.5,
|
||||||
"惊讶": 0.6,
|
"惊讶": 0.6,
|
||||||
"害羞": 2.0,
|
"害羞": 2.0,
|
||||||
"平静": 0.3,
|
"平静": 0.3,
|
||||||
"恐惧": -2,
|
"恐惧": -1.5,
|
||||||
"厌恶": -2.5,
|
"厌恶": -1.0,
|
||||||
"困惑": 0.5,
|
"困惑": 0.5,
|
||||||
}
|
}
|
||||||
if self.get_relationship(chat_stream):
|
|
||||||
old_value = self.get_relationship(chat_stream).relationship_value
|
person_id = person_info_manager.get_person_id(chat_stream.user_info.platform, chat_stream.user_info.user_id)
|
||||||
else:
|
data = {
|
||||||
return
|
"platform" : chat_stream.user_info.platform,
|
||||||
|
"user_id" : chat_stream.user_info.user_id,
|
||||||
|
"nickname" : chat_stream.user_info.user_nickname,
|
||||||
|
"konw_time" : int(time.time())
|
||||||
|
}
|
||||||
|
old_value = await person_info_manager.get_value(person_id, "relationship_value")
|
||||||
|
old_value = self.ensure_float(old_value, person_id)
|
||||||
|
|
||||||
if old_value > 1000:
|
if old_value > 1000:
|
||||||
old_value = 1000
|
old_value = 1000
|
||||||
@@ -307,26 +106,26 @@ class RelationshipManager:
|
|||||||
if valuedict[label] >= 0 and stancedict[stance] != 2:
|
if valuedict[label] >= 0 and stancedict[stance] != 2:
|
||||||
value = value * math.cos(math.pi * old_value / 2000)
|
value = value * math.cos(math.pi * old_value / 2000)
|
||||||
if old_value > 500:
|
if old_value > 500:
|
||||||
high_value_count = 0
|
rdict = await person_info_manager.get_specific_value_list("relationship_value", lambda x: x > 700)
|
||||||
for _, relationship in self.relationships.items():
|
high_value_count = len(rdict)
|
||||||
if relationship.relationship_value >= 700:
|
if old_value > 700:
|
||||||
high_value_count += 1
|
|
||||||
if old_value >= 700:
|
|
||||||
value *= 3 / (high_value_count + 2) # 排除自己
|
value *= 3 / (high_value_count + 2) # 排除自己
|
||||||
else:
|
else:
|
||||||
value *= 3 / (high_value_count + 3)
|
value *= 3 / (high_value_count + 3)
|
||||||
elif valuedict[label] < 0 and stancedict[stance] != 0:
|
elif valuedict[label] < 0 and stancedict[stance] != 0:
|
||||||
value = value * math.exp(old_value / 1000)
|
value = value * math.exp(old_value / 2000)
|
||||||
else:
|
else:
|
||||||
value = 0
|
value = 0
|
||||||
elif old_value < 0:
|
elif old_value < 0:
|
||||||
if valuedict[label] >= 0 and stancedict[stance] != 2:
|
if valuedict[label] >= 0 and stancedict[stance] != 2:
|
||||||
value = value * math.exp(old_value / 1000)
|
value = value * math.exp(old_value / 2000)
|
||||||
elif valuedict[label] < 0 and stancedict[stance] != 0:
|
elif valuedict[label] < 0 and stancedict[stance] != 0:
|
||||||
value = value * math.cos(math.pi * old_value / 2000)
|
value = value * math.cos(math.pi * old_value / 2000)
|
||||||
else:
|
else:
|
||||||
value = 0
|
value = 0
|
||||||
|
|
||||||
|
value = self.positive_feedback_sys(person_id, value, label, stance)
|
||||||
|
|
||||||
level_num = self.calculate_level_num(old_value + value)
|
level_num = self.calculate_level_num(old_value + value)
|
||||||
relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
|
relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -336,10 +135,11 @@ class RelationshipManager:
|
|||||||
f"变更: {value:+.5f}"
|
f"变更: {value:+.5f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.update_relationship_value(chat_stream=chat_stream, relationship_value=value)
|
await person_info_manager.update_one_field(person_id, "relationship_value", old_value + value, data)
|
||||||
|
|
||||||
def build_relationship_info(self, person) -> str:
|
def build_relationship_info(self, person) -> str:
|
||||||
relationship_value = relationship_manager.get_relationship(person).relationship_value
|
person_id = person_info_manager.get_person_id(person.user_info.platform, person.user_info.user_id)
|
||||||
|
relationship_value = person_info_manager.get_value(person_id, "relationship_value")
|
||||||
level_num = self.calculate_level_num(relationship_value)
|
level_num = self.calculate_level_num(relationship_value)
|
||||||
relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
|
relationship_level = ["厌恶", "冷漠", "一般", "友好", "喜欢", "暧昧"]
|
||||||
relation_prompt2_list = [
|
relation_prompt2_list = [
|
||||||
@@ -379,5 +179,14 @@ class RelationshipManager:
|
|||||||
level_num = 5 if relationship_value > 1000 else 0
|
level_num = 5 if relationship_value > 1000 else 0
|
||||||
return level_num
|
return level_num
|
||||||
|
|
||||||
|
def ensure_float(elsf, value, person_id):
|
||||||
|
"""确保返回浮点数,转换失败返回0.0"""
|
||||||
|
if isinstance(value, float):
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
return float(value.to_decimal() if isinstance(value, Decimal128) else value)
|
||||||
|
except (ValueError, TypeError, AttributeError):
|
||||||
|
logger.warning(f"[关系管理] {person_id}值转换失败(原始值:{value}),已重置为0")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
relationship_manager = RelationshipManager()
|
relationship_manager = RelationshipManager()
|
||||||
|
|||||||
Reference in New Issue
Block a user