Replace json with orjson across the codebase

雅诺狐
2025-08-26 14:20:26 +08:00
committed by Windpicker-owo
parent 9f514d8799
commit ab3a36bfa7
44 changed files with 1163 additions and 1379 deletions

View File

@@ -1,6 +1,6 @@
import hashlib
import asyncio
import json
import orjson
import time
from json_repair import repair_json
@@ -483,9 +483,9 @@ class PersonInfoManager:
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
final_data[key] = orjson.dumps(final_data[key]).decode('utf-8')
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
final_data[key] = orjson.dumps([]).decode('utf-8')
# If it's already a string, assume it's valid JSON or a non-JSON string field
def _db_create_sync(p_data: dict):
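As a sanity check on the pattern used throughout these hunks: `orjson.dumps()` returns UTF-8 bytes and never escapes non-ASCII characters, so decoding the result gives text equivalent to `json.dumps(..., ensure_ascii=False)` (orjson's output is compact, with no spaces after separators). A minimal sketch; the helper name is illustrative and not part of the commit:

```python
import orjson

def to_json_text(value) -> str:
    """Roughly equivalent to json.dumps(value, ensure_ascii=False).

    orjson.dumps() returns UTF-8 bytes and leaves non-ASCII characters
    unescaped, so decoding yields readable text (compact form, without
    the spaces json.dumps inserts after ':' and ',').
    """
    return orjson.dumps(value).decode("utf-8")

print(to_json_text({"昵称": "小明", "points": []}))  # {"昵称":"小明","points":[]}
print(to_json_text([]))                               # []
```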
@@ -532,9 +532,9 @@ class PersonInfoManager:
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
final_data[key] = orjson.dumps(final_data[key]).decode('utf-8')
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
final_data[key] = orjson.dumps([]).decode('utf-8')
def _db_safe_create_sync(p_data: dict):
with get_db_session() as session:
@@ -571,9 +571,9 @@ class PersonInfoManager:
processed_value = value
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(value, (list, dict)):
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
processed_value = orjson.dumps(value).decode('utf-8')
elif value is None: # Store None as "[]" for JSON list fields
processed_value = json.dumps([], ensure_ascii=False, indent=None)
processed_value = orjson.dumps([]).decode('utf-8')
def _db_update_sync(p_id: str, f_name: str, val_to_set):
@@ -652,7 +652,7 @@ class PersonInfoManager:
try:
fixed_json = repair_json(text)
if isinstance(fixed_json, str):
parsed_json = json.loads(fixed_json)
parsed_json = orjson.loads(fixed_json)
else:
parsed_json = fixed_json
@@ -826,8 +826,8 @@ class PersonInfoManager:
if f_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
return orjson.loads(val)
except orjson.JSONDecodeError:
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
return [] # Default for JSON fields on error
elif val is None: # Field exists in DB but is None
@@ -863,8 +863,8 @@ class PersonInfoManager:
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
return orjson.loads(val)
except orjson.JSONDecodeError:
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
return []
elif val is None:
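The read side follows the mirror-image pattern. `orjson.loads()` accepts both `str` and `bytes`, and `orjson.JSONDecodeError` is a subclass of `ValueError` (like `json.JSONDecodeError`), so the narrowed `except` preserves the previous error handling. A hedged sketch of the pattern, with an illustrative helper name:

```python
import orjson

def load_json_list(raw, default=None):
    """Deserialize a JSON-encoded list field, falling back to a default.

    orjson.loads() accepts str, bytes, bytearray or memoryview;
    orjson.JSONDecodeError subclasses ValueError, so callers that used
    to catch json.JSONDecodeError (also a ValueError) keep working.
    """
    if raw is None:
        return [] if default is None else default
    try:
        return orjson.loads(raw)
    except orjson.JSONDecodeError:
        return [] if default is None else default

assert load_json_list('["a", "b"]') == ["a", "b"]
assert load_json_list("not json") == []
assert issubclass(orjson.JSONDecodeError, ValueError)
```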
@@ -1003,9 +1003,9 @@ class PersonInfoManager:
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
if isinstance(initial_data[key], (list, dict)):
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
initial_data[key] = orjson.dumps(initial_data[key]).decode('utf-8')
elif initial_data[key] is None:
initial_data[key] = json.dumps([], ensure_ascii=False)
initial_data[key] = orjson.dumps([]).decode('utf-8')
# 获取 SQLAlchemy 模型的所有字段名
model_fields = [column.name for column in PersonInfo.__table__.columns]
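One behavioral difference worth keeping in mind for a wholesale swap, taken from orjson's documented API rather than from this diff: `orjson.dumps()` has no `ensure_ascii`/`indent`/`separators` keywords, and non-string dict keys raise instead of being coerced:

```python
import orjson

# No ensure_ascii / indent / separators keywords: formatting is controlled
# through bit-flag options instead.
pretty = orjson.dumps({"a": 1}, option=orjson.OPT_INDENT_2).decode("utf-8")

# Non-string dict keys raise orjson.JSONEncodeError (a TypeError subclass),
# whereas json.dumps would silently coerce them to strings.
try:
    orjson.dumps({1: "x"})
except orjson.JSONEncodeError:
    compact = orjson.dumps({1: "x"}, option=orjson.OPT_NON_STR_KEYS)  # b'{"1":"x"}'
```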

View File

@@ -4,7 +4,7 @@ import random
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.chat.utils.chat_message_builder import build_readable_messages
import json
import orjson
from json_repair import repair_json
from datetime import datetime
from typing import List, Dict, Any
@@ -226,7 +226,7 @@ class RelationshipManager:
# 解析JSON并转换为元组列表
try:
points = repair_json(points)
points_data = json.loads(points)
points_data = orjson.loads(points)
# 只处理正确的格式,错误格式直接跳过
if not points_data or (isinstance(points_data, list) and len(points_data) == 0):
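A self-contained sketch of the repair-then-parse pattern these hunks use (the helper name and sample input are illustrative). The `isinstance` guard mirrors the earlier hunk in person_info, since `repair_json` may hand back either a repaired string or an already-parsed object depending on how it is called:

```python
import orjson
from json_repair import repair_json

def parse_llm_json(text: str):
    """Repair a possibly malformed LLM response, then parse it.

    Only call orjson.loads() when the repaired value is still a string;
    otherwise assume repair_json already produced a Python object.
    """
    fixed = repair_json(text)
    if isinstance(fixed, str):
        return orjson.loads(fixed)
    return fixed

# Illustrative input with a trailing comma that repair_json can fix.
points_data = parse_llm_json('[["喜欢猫", 5, "2025-01-01 12:00:00"],]')
```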
@@ -261,95 +261,127 @@ class RelationshipManager:
logger_str += f"({discarded_count} 条因重要性低被丢弃)\n"
logger.info(logger_str)
except Exception as e:
except orjson.JSONDecodeError:
logger.error(f"解析points JSON失败: {points}")
return
except (KeyError, TypeError) as e:
logger.error(f"处理points数据失败: {e}, points: {points}")
logger.error(traceback.format_exc())
return
person.points.extend(points_list)
current_points = await person_info_manager.get_value(person_id, "points") or []
if isinstance(current_points, str):
try:
current_points = orjson.loads(current_points)
except orjson.JSONDecodeError:
logger.error(f"解析points JSON失败: {current_points}")
current_points = []
elif not isinstance(current_points, list):
current_points = []
current_points.extend(points_list)
await person_info_manager.update_one_field(
person_id, "points", orjson.dumps(current_points).decode('utf-8')
)
# 将新记录添加到现有记录中
if isinstance(current_points, list):
# 只对新添加的points进行相似度检查和合并
for new_point in points_list:
similar_points = []
similar_indices = []
# 在现有points中查找相似的点
for i, existing_point in enumerate(current_points):
# 使用组合的相似度检查方法
if self.check_similarity(new_point[0], existing_point[0]):
similar_points.append(existing_point)
similar_indices.append(i)
if similar_points:
# 合并相似的点
all_points = [new_point] + similar_points
# 使用最新的时间
latest_time = max(p[2] for p in all_points)
# 合并权重
total_weight = sum(p[1] for p in all_points)
# 使用最长的描述
longest_desc = max(all_points, key=lambda x: len(x[0]))[0]
# 创建合并后的点
merged_point = (longest_desc, total_weight, latest_time)
# 从现有points中移除已合并的点
for idx in sorted(similar_indices, reverse=True):
current_points.pop(idx)
# 添加合并后的点
current_points.append(merged_point)
else:
# 如果没有相似的点,直接添加
current_points.append(new_point)
else:
current_points = points_list
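For reference, the merge rule above in isolation: a new point absorbs similar existing points, keeping the longest description, the summed weight, and the latest timestamp. A sketch with the similarity test injected as a callable, since the real code uses the manager's `check_similarity`:

```python
from typing import Callable, List, Tuple

Point = Tuple[str, float, str]  # (description, weight, "YYYY-MM-DD HH:MM:SS")

def merge_point(new_point: Point, existing: List[Point],
                is_similar: Callable[[str, str], bool]) -> List[Point]:
    """Merge new_point into any existing points that is_similar() matches.

    The merged point keeps the longest description, the summed weight and
    the latest timestamp; non-matching points are left untouched.
    """
    similar = [p for p in existing if is_similar(new_point[0], p[0])]
    if not similar:
        return existing + [new_point]
    rest = [p for p in existing if p not in similar]
    group = [new_point] + similar
    merged: Point = (
        max(group, key=lambda p: len(p[0]))[0],  # longest description
        sum(p[1] for p in group),                # combined weight
        max(p[2] for p in group),                # most recent time string
    )
    return rest + [merged]
```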
# 如果points超过20条,按权重随机选择多余的条目移动到forgotten_points
if len(person.points) > 20:
# 计算当前时间
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
# 计算每个点的最终权重(原始权重 * 时间权重)
weighted_points = []
for point in person.points:
time_weight = self.calculate_time_weight(point[2], current_time)
final_weight = point[1] * time_weight
weighted_points.append((point, final_weight))
# 计算总权重
total_weight = sum(w for _, w in weighted_points)
# 按权重随机选择要保留的点
remaining_points = []
# 对每个点进行随机选择
for point, weight in weighted_points:
# 计算保留概率(权重越高越可能保留)
keep_probability = weight / total_weight
if len(remaining_points) < 20:
# 如果还没达到20条,直接保留
remaining_points.append(point)
elif random.random() < keep_probability:
# 保留这个点,随机移除一个已保留的点
idx_to_remove = random.randrange(len(remaining_points))
remaining_points[idx_to_remove] = point
person.points = remaining_points
return person
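The retention step above keeps roughly 20 points, favoring entries whose time-decayed weight is a larger share of the total. A sketch with the weights precomputed, since the real code derives them via `calculate_time_weight`; the function name is illustrative:

```python
import random
from typing import List, Tuple

Point = Tuple[str, float, str]

def downsample_points(points: List[Point], weights: List[float],
                      keep: int = 20) -> List[Point]:
    """Keep roughly `keep` points, favoring higher (time-decayed) weights.

    Mirrors the loop above: the first `keep` points are kept outright;
    each later point replaces a randomly chosen kept point with
    probability proportional to its share of the total weight.
    """
    total = sum(weights) or 1.0
    kept: List[Point] = []
    for point, weight in zip(points, weights):
        if len(kept) < keep:
            kept.append(point)
        elif random.random() < weight / total:
            kept[random.randrange(len(kept))] = point
    return kept
```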
async def get_attitude_to_me(self, readable_messages, timestamp, person: Person):
alias_str = ", ".join(global_config.bot.alias_names)
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
# 解析当前态度值
current_attitude_score = person.attitude_to_me
total_confidence = person.attitude_to_me_confidence
prompt = await global_prompt_manager.format_prompt(
"attitude_to_me_prompt",
bot_name = global_config.bot.nickname,
alias_str = alias_str,
person_name = person.person_name,
nickname = person.nickname,
readable_messages = readable_messages,
current_time = current_time,
# 更新数据库
await person_info_manager.update_one_field(
person_id, "points", orjson.dumps(current_points).decode('utf-8')
)
attitude, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
know_since = await person_info_manager.get_value(person_id, "know_since") or 0
if know_since == 0:
await person_info_manager.update_one_field(person_id, "know_since", timestamp)
await person_info_manager.update_one_field(person_id, "last_know", timestamp)
logger.info(f"prompt: {prompt}")
logger.info(f"attitude: {attitude}")
logger.debug(f"{person_name} 的印象更新完成")
async def _update_impression(self, person_id, current_points, timestamp):
# 获取现有forgotten_points
person_info_manager = get_person_info_manager()
attitude = repair_json(attitude)
attitude_data = orjson.loads(attitude)
if not attitude_data or (isinstance(attitude_data, list) and len(attitude_data) == 0):
return ""
# 确保 attitude_data 是字典格式
if not isinstance(attitude_data, dict):
logger.warning(f"LLM返回了错误的JSON格式跳过解析: {type(attitude_data)}, 内容: {attitude_data}")
return ""
attitude_score = attitude_data["attitude"]
confidence = attitude_data["confidence"]
new_confidence = total_confidence + confidence
new_attitude_score = (current_attitude_score * total_confidence + attitude_score * confidence)/new_confidence
person.attitude_to_me = new_attitude_score
person.attitude_to_me_confidence = new_confidence
return person
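The attitude update is a confidence-weighted running average: the stored score and the new LLM estimate are blended in proportion to their confidences. A minimal sketch of that arithmetic, with illustrative names:

```python
from typing import Tuple

def update_confidence_weighted(current_score: float, current_confidence: float,
                               new_score: float, new_confidence: float) -> Tuple[float, float]:
    """Fold a new LLM estimate into the stored attitude_to_me value.

    An estimate backed by more confidence moves the stored score
    proportionally more; the confidences accumulate.
    """
    total = current_confidence + new_confidence
    merged = (current_score * current_confidence + new_score * new_confidence) / total
    return merged, total

# e.g. stored (60, confidence 4) plus a new reading (80, confidence 1):
assert update_confidence_weighted(60, 4, 80, 1) == (64.0, 5)
```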
async def get_neuroticism(self, readable_messages, timestamp, person: Person):
alias_str = ", ".join(global_config.bot.alias_names)
person_name = await person_info_manager.get_value(person_id, "person_name")
nickname = await person_info_manager.get_value(person_id, "nickname")
know_times: float = await person_info_manager.get_value(person_id, "know_times") or 0 # type: ignore
attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 根据熟悉度,调整印象和简短印象的最大长度
if know_times > 300:
max_impression_length = 2000
max_short_impression_length = 400
elif know_times > 100:
max_impression_length = 1000
max_short_impression_length = 250
elif know_times > 50:
max_impression_length = 500
max_short_impression_length = 150
elif know_times > 10:
max_impression_length = 200
max_short_impression_length = 60
else:
max_impression_length = 100
max_short_impression_length = 30
# 根据好感度,调整印象和简短印象的最大长度
attitude_multiplier = (abs(100 - attitude) / 100) + 1
max_impression_length = max_impression_length * attitude_multiplier
max_short_impression_length = max_short_impression_length * attitude_multiplier
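The length caps above scale with familiarity and attitude. A compact sketch of the same tiering (the function name is illustrative; thresholds and caps are copied from the code above):

```python
def impression_length_limits(know_times: float, attitude: float) -> tuple:
    """Reproduce the tiering above: more interactions raise both caps,
    and the multiplier grows as attitude moves further from 100."""
    tiers = [(300, 2000, 400), (100, 1000, 250), (50, 500, 150), (10, 200, 60)]
    max_len, max_short = 100, 30
    for threshold, long_cap, short_cap in tiers:
        if know_times > threshold:
            max_len, max_short = long_cap, short_cap
            break
    multiplier = (abs(100 - attitude) / 100) + 1
    return max_len * multiplier, max_short * multiplier

# know_times=120, attitude=50 -> multiplier 1.5 -> (1500.0, 375.0)
```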
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
if isinstance(forgotten_points, str):
try:
forgotten_points = orjson.loads(forgotten_points)
except orjson.JSONDecodeError:
logger.error(f"解析forgotten_points JSON失败: {forgotten_points}")
forgotten_points = []
elif not isinstance(forgotten_points, list):
forgotten_points = []
# 计算当前时间
current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
# 解析当前态度值
current_neuroticism_score = person.neuroticism
@@ -444,16 +476,48 @@ class RelationshipManager:
name_mapping[msg_person.person_name] = f"{person_name}"
continue
# 其他用户映射
if msg_person.person_name not in name_mapping and msg_person.person_name is not None:
if current_user > "Z":
current_user = "A"
user_count += 1
name_mapping[msg_person.person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}"
current_user = chr(ord(current_user) + 1)
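A standalone sketch of the aliasing rule used here: other speakers are mapped to 用户A..用户Z, then 用户A2..用户Z2 and so on, while the bot keeps its own name. The helper and its inputs are illustrative; the real code walks Person objects while building `name_mapping`:

```python
from typing import Dict, Iterable

def anonymize_names(names: Iterable[str], bot_name: str) -> Dict[str, str]:
    """Map speaker names to stable short aliases, skipping the bot."""
    mapping: Dict[str, str] = {}
    letter, round_no = "A", 1
    for name in names:
        if name is None or name in mapping:
            continue
        if name == bot_name:
            mapping[name] = bot_name
            continue
        if letter > "Z":              # ran past Z: start a new lettering round
            letter, round_no = "A", round_no + 1
        mapping[name] = f"用户{letter}{round_no if round_no > 1 else ''}"
        letter = chr(ord(letter) + 1)
    return mapping
```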
请严格按照json格式输出不要有其他多余内容
{{
"attitude": <0-100之间的整数>,
}}
"""
try:
relation_value_response, _ = await self.relationship_llm.generate_response_async(
prompt=relation_value_prompt
)
relation_value_json = orjson.loads(repair_json(relation_value_response))
readable_messages = build_readable_messages(
messages=user_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
# 从LLM获取新生成的值
new_attitude = int(relation_value_json.get("attitude", 50))
# 获取当前的关系值
old_attitude: float = await person_info_manager.get_value(person_id, "attitude") or 50 # type: ignore
# 更新熟悉度
if new_attitude > 25:
attitude = old_attitude + (new_attitude - 25) / 75
else:
attitude = old_attitude
# 更新好感度
if new_attitude > 50:
attitude += (new_attitude - 50) / 50
elif new_attitude < 50:
attitude -= (50 - new_attitude) / 50 * 1.5
await person_info_manager.update_one_field(person_id, "attitude", attitude)
logger.info(f"更新了与 {person_name} 的态度: {attitude}")
except (orjson.JSONDecodeError, ValueError, TypeError) as e:
logger.error(f"解析relation_value JSON失败或值无效: {e}, 响应: {relation_value_response}")
forgotten_points = []
info_list = []
await person_info_manager.update_one_field(
person_id, "info_list", orjson.dumps(info_list).decode('utf-8')
)
await person_info_manager.update_one_field(
person_id, "forgotten_points", orjson.dumps(forgotten_points).decode('utf-8')
)
for original_name, mapped_name in name_mapping.items():