修复代码格式和文件名大小写问题
This commit is contained in:
@@ -73,11 +73,11 @@ class PersonInfoManager:
|
||||
|
||||
# # 初始化时读取所有person_name
|
||||
try:
|
||||
# 在这里获取会话
|
||||
# 在这里获取会话
|
||||
with get_db_session() as session:
|
||||
for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||
PersonInfo.person_name.is_not(None)
|
||||
)).fetchall():
|
||||
for record in session.execute(
|
||||
select(PersonInfo.person_id, PersonInfo.person_name).where(PersonInfo.person_name.is_not(None))
|
||||
).fetchall():
|
||||
if record.person_name:
|
||||
self.person_name_list[record.person_id] = record.person_name
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
|
||||
@@ -90,7 +90,7 @@ class PersonInfoManager:
|
||||
# 检查platform是否为None或空
|
||||
if platform is None:
|
||||
platform = "unknown"
|
||||
|
||||
|
||||
if "-" in platform:
|
||||
platform = platform.split("-")[1]
|
||||
|
||||
@@ -103,7 +103,7 @@ class PersonInfoManager:
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
|
||||
def _db_check_known_sync(p_id: str):
|
||||
# 在需要时获取会话
|
||||
# 在需要时获取会话
|
||||
with get_db_session() as session:
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
|
||||
|
||||
@@ -116,7 +116,7 @@ class PersonInfoManager:
|
||||
def get_person_id_by_person_name(self, person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
try:
|
||||
# 在需要时获取会话
|
||||
# 在需要时获取会话
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
|
||||
return record.person_id if record else ""
|
||||
@@ -155,9 +155,9 @@ class PersonInfoManager:
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = orjson.dumps(final_data[key]).decode('utf-8')
|
||||
final_data[key] = orjson.dumps(final_data[key]).decode("utf-8")
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = orjson.dumps([]).decode('utf-8')
|
||||
final_data[key] = orjson.dumps([]).decode("utf-8")
|
||||
# If it's already a string, assume it's valid JSON or a non-JSON string field
|
||||
|
||||
def _db_create_sync(p_data: dict):
|
||||
@@ -166,7 +166,7 @@ class PersonInfoManager:
|
||||
new_person = PersonInfo(**p_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
|
||||
@@ -204,14 +204,16 @@ class PersonInfoManager:
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = orjson.dumps(final_data[key]).decode('utf-8')
|
||||
final_data[key] = orjson.dumps(final_data[key]).decode("utf-8")
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = orjson.dumps([]).decode('utf-8')
|
||||
final_data[key] = orjson.dumps([]).decode("utf-8")
|
||||
|
||||
def _db_safe_create_sync(p_data: dict):
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar()
|
||||
existing = session.execute(
|
||||
select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])
|
||||
).scalar()
|
||||
if existing:
|
||||
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
|
||||
return True
|
||||
@@ -220,7 +222,7 @@ class PersonInfoManager:
|
||||
new_person = PersonInfo(**p_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
@@ -243,12 +245,11 @@ class PersonInfoManager:
|
||||
processed_value = value
|
||||
if field_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(value, (list, dict)):
|
||||
processed_value = orjson.dumps(value).decode('utf-8')
|
||||
processed_value = orjson.dumps(value).decode("utf-8")
|
||||
elif value is None: # Store None as "[]" for JSON list fields
|
||||
processed_value = orjson.dumps([]).decode('utf-8')
|
||||
processed_value = orjson.dumps([]).decode("utf-8")
|
||||
|
||||
def _db_update_sync(p_id: str, f_name: str, val_to_set):
|
||||
|
||||
start_time = time.time()
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
@@ -257,7 +258,7 @@ class PersonInfoManager:
|
||||
|
||||
if record:
|
||||
setattr(record, f_name, val_to_set)
|
||||
|
||||
|
||||
save_time = time.time()
|
||||
|
||||
total_time = save_time - start_time
|
||||
@@ -420,13 +421,15 @@ class PersonInfoManager:
|
||||
|
||||
def _db_check_name_exists_sync(name_to_check):
|
||||
with get_db_session() as session:
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None
|
||||
return (
|
||||
session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar()
|
||||
is not None
|
||||
)
|
||||
|
||||
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
|
||||
is_duplicate = True
|
||||
current_name_set.add(generated_nickname)
|
||||
|
||||
|
||||
if not is_duplicate:
|
||||
await self.update_one_field(person_id, "person_name", generated_nickname)
|
||||
await self.update_one_field(person_id, "name_reason", result.get("reason", "未提供理由"))
|
||||
@@ -607,7 +610,9 @@ class PersonInfoManager:
|
||||
if way(value):
|
||||
found_results[record.person_id] = value
|
||||
except Exception as e_query:
|
||||
logger.error(f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
|
||||
logger.error(
|
||||
f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True
|
||||
)
|
||||
return found_results
|
||||
|
||||
try:
|
||||
@@ -639,8 +644,10 @@ class PersonInfoManager:
|
||||
new_person = PersonInfo(**init_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar(), True # 创建成功
|
||||
|
||||
return session.execute(
|
||||
select(PersonInfo).where(PersonInfo.person_id == p_id)
|
||||
).scalar(), True # 创建成功
|
||||
except Exception as e:
|
||||
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
@@ -671,9 +678,9 @@ class PersonInfoManager:
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in initial_data:
|
||||
if isinstance(initial_data[key], (list, dict)):
|
||||
initial_data[key] = orjson.dumps(initial_data[key]).decode('utf-8')
|
||||
initial_data[key] = orjson.dumps(initial_data[key]).decode("utf-8")
|
||||
elif initial_data[key] is None:
|
||||
initial_data[key] = orjson.dumps([]).decode('utf-8')
|
||||
initial_data[key] = orjson.dumps([]).decode("utf-8")
|
||||
|
||||
# 获取 SQLAlchemy 模odel的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
@@ -732,11 +739,7 @@ class PersonInfoManager:
|
||||
]
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
valid_fields_to_get = [
|
||||
f
|
||||
for f in required_fields
|
||||
if f in model_fields or f in person_info_default
|
||||
]
|
||||
valid_fields_to_get = [f for f in required_fields if f in model_fields or f in person_info_default]
|
||||
|
||||
person_data = await self.get_values(found_person_id, valid_fields_to_get)
|
||||
|
||||
|
||||
@@ -349,13 +349,12 @@ class RelationshipBuilder:
|
||||
# 统筹各模块协作、对外提供服务接口
|
||||
# ================================
|
||||
|
||||
async def build_relation(self,immediate_build: str = "",max_build_threshold: int = MAX_MESSAGE_COUNT):
|
||||
async def build_relation(self, immediate_build: str = "", max_build_threshold: int = MAX_MESSAGE_COUNT):
|
||||
"""构建关系
|
||||
immediate_build: 立即构建关系,可选值为"all"或person_id
|
||||
"""
|
||||
self._cleanup_old_segments()
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
if latest_messages := get_raw_msg_by_timestamp_with_chat(
|
||||
self.chat_id,
|
||||
@@ -387,8 +386,10 @@ class RelationshipBuilder:
|
||||
for person_id, segments in self.person_engaged_cache.items():
|
||||
total_message_count = self._get_total_message_count(person_id)
|
||||
person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id
|
||||
|
||||
if total_message_count >= max_build_threshold or (total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all")):
|
||||
|
||||
if total_message_count >= max_build_threshold or (
|
||||
total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all")
|
||||
):
|
||||
users_to_build_relationship.append(person_id)
|
||||
logger.info(
|
||||
f"{self.log_prefix} 用户 {person_name} 满足关系构建条件,总消息数:{total_message_count},消息段数:{len(segments)}"
|
||||
@@ -409,7 +410,6 @@ class RelationshipBuilder:
|
||||
# 移除已处理的用户缓存
|
||||
del self.person_engaged_cache[person_id]
|
||||
self._save_cache()
|
||||
|
||||
|
||||
# ================================
|
||||
# 关系构建模块
|
||||
|
||||
@@ -88,7 +88,7 @@ class RelationshipManager:
|
||||
# 获取平台信息,优先使用chat_info_platform,如果为None则使用user_platform
|
||||
platform = msg.get("chat_info_platform") or msg.get("user_platform", "unknown")
|
||||
user_id = msg.get("user_id")
|
||||
|
||||
|
||||
await person_info_manager.get_or_create_person(
|
||||
platform=platform, # type: ignore
|
||||
user_id=user_id, # type: ignore
|
||||
@@ -237,9 +237,7 @@ class RelationshipManager:
|
||||
elif not isinstance(current_points, list):
|
||||
current_points = []
|
||||
current_points.extend(points_list)
|
||||
await person_info_manager.update_one_field(
|
||||
person_id, "points", orjson.dumps(current_points).decode('utf-8')
|
||||
)
|
||||
await person_info_manager.update_one_field(person_id, "points", orjson.dumps(current_points).decode("utf-8"))
|
||||
|
||||
# 将新记录添加到现有记录中
|
||||
if isinstance(current_points, list):
|
||||
@@ -285,9 +283,7 @@ class RelationshipManager:
|
||||
current_points = await self._update_impression(person_id, current_points, timestamp)
|
||||
|
||||
# 更新数据库
|
||||
await person_info_manager.update_one_field(
|
||||
person_id, "points", orjson.dumps(current_points).decode('utf-8')
|
||||
)
|
||||
await person_info_manager.update_one_field(person_id, "points", orjson.dumps(current_points).decode("utf-8"))
|
||||
|
||||
await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
|
||||
know_since = await person_info_manager.get_value(person_id, "know_since") or 0
|
||||
@@ -488,12 +484,10 @@ class RelationshipManager:
|
||||
|
||||
forgotten_points = []
|
||||
info_list = []
|
||||
await person_info_manager.update_one_field(
|
||||
person_id, "info_list", orjson.dumps(info_list).decode('utf-8')
|
||||
)
|
||||
await person_info_manager.update_one_field(person_id, "info_list", orjson.dumps(info_list).decode("utf-8"))
|
||||
|
||||
await person_info_manager.update_one_field(
|
||||
person_id, "forgotten_points", orjson.dumps(forgotten_points).decode('utf-8')
|
||||
person_id, "forgotten_points", orjson.dumps(forgotten_points).decode("utf-8")
|
||||
)
|
||||
|
||||
return current_points
|
||||
|
||||
Reference in New Issue
Block a user