From cabaf74194072fac475420bdcdac904d5afc3db8 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 1 Nov 2025 17:06:40 +0800 Subject: [PATCH] =?UTF-8?q?style:=20ruff=E8=87=AA=E5=8A=A8=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E5=8C=96=E4=BF=AE=E5=A4=8D=20-=20=E4=BF=AE=E5=A4=8D18?= =?UTF-8?q?0=E4=B8=AA=E7=A9=BA=E7=99=BD=E8=A1=8C=E5=92=8C=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/database/api/crud.py | 149 ++++++++++---------- src/common/database/api/query.py | 179 ++++++++++++------------ src/person_info/person_info.py | 20 +-- src/person_info/relationship_fetcher.py | 4 +- 4 files changed, 173 insertions(+), 179 deletions(-) diff --git a/src/common/database/api/crud.py b/src/common/database/api/crud.py index ed6ab24c7..a1245d491 100644 --- a/src/common/database/api/crud.py +++ b/src/common/database/api/crud.py @@ -6,11 +6,9 @@ - 智能预加载:关联数据自动预加载 """ -from typing import Any, Optional, Type, TypeVar +from typing import Any, TypeVar -from sqlalchemy import and_, delete, func, select, update -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.inspection import inspect +from sqlalchemy import delete, func, select, update from src.common.database.core.models import Base from src.common.database.core.session import get_db_session @@ -19,7 +17,6 @@ from src.common.database.optimization import ( Priority, get_batch_scheduler, get_cache, - get_preloader, ) from src.common.logger import get_logger @@ -30,10 +27,10 @@ T = TypeVar("T", bound=Base) def _model_to_dict(instance: Base) -> dict[str, Any]: """将 SQLAlchemy 模型实例转换为字典 - + Args: instance: SQLAlchemy 模型实例 - + Returns: 字典表示,包含所有列的值 """ @@ -47,13 +44,13 @@ def _model_to_dict(instance: Base) -> dict[str, Any]: return result -def _dict_to_model(model_class: Type[T], data: dict[str, Any]) -> T: +def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T: """从字典创建 SQLAlchemy 模型实例 (detached状态) - + Args: model_class: SQLAlchemy 模型类 data: 字典数据 - + Returns: 模型实例 (detached, 所有字段已加载) """ @@ -66,13 +63,13 @@ def _dict_to_model(model_class: Type[T], data: dict[str, Any]) -> T: class CRUDBase: """基础CRUD操作类 - + 提供通用的增删改查操作,自动集成缓存和批处理 """ - def __init__(self, model: Type[T]): + def __init__(self, model: type[T]): """初始化CRUD操作 - + Args: model: SQLAlchemy模型类 """ @@ -83,18 +80,18 @@ class CRUDBase: self, id: int, use_cache: bool = True, - ) -> Optional[T]: + ) -> T | None: """根据ID获取单条记录 - + Args: id: 记录ID use_cache: 是否使用缓存 - + Returns: 模型实例或None """ cache_key = f"{self.model_name}:id:{id}" - + # 尝试从缓存获取 (缓存的是字典) if use_cache: cache = await get_cache() @@ -103,13 +100,13 @@ class CRUDBase: logger.debug(f"缓存命中: {cache_key}") # 从字典恢复对象 return _dict_to_model(self.model, cached_dict) - + # 从数据库查询 async with get_db_session() as session: stmt = select(self.model).where(self.model.id == id) result = await session.execute(stmt) instance = result.scalar_one_or_none() - + if instance is not None: # 预加载所有字段 for column in self.model.__table__.columns: @@ -117,31 +114,31 @@ class CRUDBase: getattr(instance, column.name) except Exception: pass - + # 转换为字典并写入缓存 if use_cache: instance_dict = _model_to_dict(instance) cache = await get_cache() await cache.set(cache_key, instance_dict) - + return instance async def get_by( self, use_cache: bool = True, **filters: Any, - ) -> Optional[T]: + ) -> T | None: """根据条件获取单条记录 - + Args: use_cache: 是否使用缓存 **filters: 过滤条件 - + Returns: 模型实例或None """ - cache_key = f"{self.model_name}:filter:{str(sorted(filters.items()))}" - + cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}" + # 尝试从缓存获取 (缓存的是字典) if use_cache: cache = await get_cache() @@ -150,17 +147,17 @@ class CRUDBase: logger.debug(f"缓存命中: {cache_key}") # 从字典恢复对象 return _dict_to_model(self.model, cached_dict) - + # 从数据库查询 async with get_db_session() as session: stmt = select(self.model) for key, value in filters.items(): if hasattr(self.model, key): stmt = stmt.where(getattr(self.model, key) == value) - + result = await session.execute(stmt) instance = result.scalar_one_or_none() - + if instance is not None: # 触发所有列的加载,避免 detached 后的延迟加载问题 # 遍历所有列属性以确保它们被加载到内存中 @@ -169,13 +166,13 @@ class CRUDBase: getattr(instance, column.name) except Exception: pass # 忽略访问错误 - + # 转换为字典并写入缓存 if use_cache: instance_dict = _model_to_dict(instance) cache = await get_cache() await cache.set(cache_key, instance_dict) - + return instance async def get_multi( @@ -186,18 +183,18 @@ class CRUDBase: **filters: Any, ) -> list[T]: """获取多条记录 - + Args: skip: 跳过的记录数 limit: 返回的最大记录数 use_cache: 是否使用缓存 **filters: 过滤条件 - + Returns: 模型实例列表 """ - cache_key = f"{self.model_name}:multi:{skip}:{limit}:{str(sorted(filters.items()))}" - + cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}" + # 尝试从缓存获取 (缓存的是字典列表) if use_cache: cache = await get_cache() @@ -206,11 +203,11 @@ class CRUDBase: logger.debug(f"缓存命中: {cache_key}") # 从字典列表恢复对象列表 return [_dict_to_model(self.model, d) for d in cached_dicts] - + # 从数据库查询 async with get_db_session() as session: stmt = select(self.model) - + # 应用过滤条件 for key, value in filters.items(): if hasattr(self.model, key): @@ -218,13 +215,13 @@ class CRUDBase: stmt = stmt.where(getattr(self.model, key).in_(value)) else: stmt = stmt.where(getattr(self.model, key) == value) - + # 应用分页 stmt = stmt.offset(skip).limit(limit) - + result = await session.execute(stmt) instances = list(result.scalars().all()) - + # 触发所有实例的列加载,避免 detached 后的延迟加载问题 for instance in instances: for column in self.model.__table__.columns: @@ -232,13 +229,13 @@ class CRUDBase: getattr(instance, column.name) except Exception: pass # 忽略访问错误 - + # 转换为字典列表并写入缓存 if use_cache: instances_dicts = [_model_to_dict(inst) for inst in instances] cache = await get_cache() await cache.set(cache_key, instances_dicts) - + return instances async def create( @@ -247,11 +244,11 @@ class CRUDBase: use_batch: bool = False, ) -> T: """创建新记录 - + Args: obj_in: 创建数据 use_batch: 是否使用批处理 - + Returns: 创建的模型实例 """ @@ -266,7 +263,7 @@ class CRUDBase: ) future = await scheduler.add_operation(operation) await future - + # 批处理返回成功,创建实例 instance = self.model(**obj_in) return instance @@ -284,14 +281,14 @@ class CRUDBase: id: int, obj_in: dict[str, Any], use_batch: bool = False, - ) -> Optional[T]: + ) -> T | None: """更新记录 - + Args: id: 记录ID obj_in: 更新数据 use_batch: 是否使用批处理 - + Returns: 更新后的模型实例或None """ @@ -299,7 +296,7 @@ class CRUDBase: instance = await self.get(id, use_cache=False) if instance is None: return None - + if use_batch: # 使用批处理 scheduler = await get_batch_scheduler() @@ -312,7 +309,7 @@ class CRUDBase: ) future = await scheduler.add_operation(operation) await future - + # 更新实例属性 for key, value in obj_in.items(): if hasattr(instance, key): @@ -324,7 +321,7 @@ class CRUDBase: stmt = select(self.model).where(self.model.id == id) result = await session.execute(stmt) db_instance = result.scalar_one_or_none() - + if db_instance: for key, value in obj_in.items(): if hasattr(db_instance, key): @@ -332,12 +329,12 @@ class CRUDBase: await session.flush() await session.refresh(db_instance) instance = db_instance - + # 清除缓存 cache_key = f"{self.model_name}:id:{id}" cache = await get_cache() await cache.delete(cache_key) - + return instance async def delete( @@ -346,11 +343,11 @@ class CRUDBase: use_batch: bool = False, ) -> bool: """删除记录 - + Args: id: 记录ID use_batch: 是否使用批处理 - + Returns: 是否成功删除 """ @@ -372,13 +369,13 @@ class CRUDBase: stmt = delete(self.model).where(self.model.id == id) result = await session.execute(stmt) success = result.rowcount > 0 - + # 清除缓存 if success: cache_key = f"{self.model_name}:id:{id}" cache = await get_cache() await cache.delete(cache_key) - + return success async def count( @@ -386,16 +383,16 @@ class CRUDBase: **filters: Any, ) -> int: """统计记录数 - + Args: **filters: 过滤条件 - + Returns: 记录数量 """ async with get_db_session() as session: stmt = select(func.count(self.model.id)) - + # 应用过滤条件 for key, value in filters.items(): if hasattr(self.model, key): @@ -403,7 +400,7 @@ class CRUDBase: stmt = stmt.where(getattr(self.model, key).in_(value)) else: stmt = stmt.where(getattr(self.model, key) == value) - + result = await session.execute(stmt) return result.scalar() @@ -412,10 +409,10 @@ class CRUDBase: **filters: Any, ) -> bool: """检查记录是否存在 - + Args: **filters: 过滤条件 - + Returns: 是否存在 """ @@ -424,15 +421,15 @@ class CRUDBase: async def get_or_create( self, - defaults: Optional[dict[str, Any]] = None, + defaults: dict[str, Any] | None = None, **filters: Any, ) -> tuple[T, bool]: """获取或创建记录 - + Args: defaults: 创建时的默认值 **filters: 查找条件 - + Returns: (实例, 是否新创建) """ @@ -440,12 +437,12 @@ class CRUDBase: instance = await self.get_by(use_cache=False, **filters) if instance is not None: return instance, False - + # 创建新记录 create_data = {**filters} if defaults: create_data.update(defaults) - + instance = await self.create(create_data) return instance, True @@ -454,10 +451,10 @@ class CRUDBase: objs_in: list[dict[str, Any]], ) -> list[T]: """批量创建记录 - + Args: objs_in: 创建数据列表 - + Returns: 创建的模型实例列表 """ @@ -465,10 +462,10 @@ class CRUDBase: instances = [self.model(**obj_data) for obj_data in objs_in] session.add_all(instances) await session.flush() - + for instance in instances: await session.refresh(instance) - + return instances async def bulk_update( @@ -476,10 +473,10 @@ class CRUDBase: updates: list[tuple[int, dict[str, Any]]], ) -> int: """批量更新记录 - + Args: updates: (id, update_data)元组列表 - + Returns: 更新的记录数 """ @@ -493,10 +490,10 @@ class CRUDBase: ) result = await session.execute(stmt) count += result.rowcount - + # 清除缓存 cache_key = f"{self.model_name}:id:{id}" cache = await get_cache() await cache.delete(cache_key) - + return count diff --git a/src/common/database/api/query.py b/src/common/database/api/query.py index b34587ba7..38d740d51 100644 --- a/src/common/database/api/query.py +++ b/src/common/database/api/query.py @@ -7,19 +7,16 @@ - 关联查询 """ -from typing import Any, Generic, Optional, Sequence, Type, TypeVar +from typing import Any, Generic, TypeVar from sqlalchemy import and_, asc, desc, func, or_, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.engine import Row - -from src.common.database.core.models import Base -from src.common.database.core.session import get_db_session -from src.common.database.optimization import get_cache, get_preloader -from src.common.logger import get_logger # 导入 CRUD 辅助函数以避免重复定义 from src.common.database.api.crud import _dict_to_model, _model_to_dict +from src.common.database.core.models import Base +from src.common.database.core.session import get_db_session +from src.common.database.optimization import get_cache +from src.common.logger import get_logger logger = get_logger("database.query") @@ -28,13 +25,13 @@ T = TypeVar("T", bound="Base") class QueryBuilder(Generic[T]): """查询构建器 - + 支持链式调用,构建复杂查询 """ - def __init__(self, model: Type[T]): + def __init__(self, model: type[T]): """初始化查询构建器 - + Args: model: SQLAlchemy模型类 """ @@ -46,7 +43,7 @@ class QueryBuilder(Generic[T]): def filter(self, **conditions: Any) -> "QueryBuilder": """添加过滤条件 - + 支持的操作符: - 直接相等: field=value - 大于: field__gt=value @@ -58,10 +55,10 @@ class QueryBuilder(Generic[T]): - 不包含: field__nin=[values] - 模糊匹配: field__like='%pattern%' - 为空: field__isnull=True - + Args: **conditions: 过滤条件 - + Returns: self,支持链式调用 """ @@ -71,13 +68,13 @@ class QueryBuilder(Generic[T]): field_name, operator = key.rsplit("__", 1) else: field_name, operator = key, "eq" - + if not hasattr(self.model, field_name): logger.warning(f"模型 {self.model_name} 没有字段 {field_name}") continue - + field = getattr(self.model, field_name) - + # 应用操作符 if operator == "eq": self._stmt = self._stmt.where(field == value) @@ -104,17 +101,17 @@ class QueryBuilder(Generic[T]): self._stmt = self._stmt.where(field.isnot(None)) else: logger.warning(f"未知操作符: {operator}") - + # 更新缓存键 - self._cache_key_parts.append(f"filter:{str(sorted(conditions.items()))}") + self._cache_key_parts.append(f"filter:{sorted(conditions.items())!s}") return self def filter_or(self, **conditions: Any) -> "QueryBuilder": """添加OR过滤条件 - + Args: **conditions: OR条件 - + Returns: self,支持链式调用 """ @@ -123,19 +120,19 @@ class QueryBuilder(Generic[T]): if hasattr(self.model, key): field = getattr(self.model, key) or_conditions.append(field == value) - + if or_conditions: self._stmt = self._stmt.where(or_(*or_conditions)) - self._cache_key_parts.append(f"or:{str(sorted(conditions.items()))}") - + self._cache_key_parts.append(f"or:{sorted(conditions.items())!s}") + return self def order_by(self, *fields: str) -> "QueryBuilder": """添加排序 - + Args: *fields: 排序字段,'-'前缀表示降序 - + Returns: self,支持链式调用 """ @@ -147,16 +144,16 @@ class QueryBuilder(Generic[T]): else: if hasattr(self.model, field_name): self._stmt = self._stmt.order_by(asc(getattr(self.model, field_name))) - + self._cache_key_parts.append(f"order:{','.join(fields)}") return self def limit(self, limit: int) -> "QueryBuilder": """限制结果数量 - + Args: limit: 最大数量 - + Returns: self,支持链式调用 """ @@ -166,10 +163,10 @@ class QueryBuilder(Generic[T]): def offset(self, offset: int) -> "QueryBuilder": """跳过指定数量 - + Args: offset: 跳过数量 - + Returns: self,支持链式调用 """ @@ -179,7 +176,7 @@ class QueryBuilder(Generic[T]): def no_cache(self) -> "QueryBuilder": """禁用缓存 - + Returns: self,支持链式调用 """ @@ -188,12 +185,12 @@ class QueryBuilder(Generic[T]): async def all(self) -> list[T]: """获取所有结果 - + Returns: 模型实例列表 """ cache_key = ":".join(self._cache_key_parts) + ":all" - + # 尝试从缓存获取 (缓存的是字典列表) if self._use_cache: cache = await get_cache() @@ -202,12 +199,12 @@ class QueryBuilder(Generic[T]): logger.debug(f"缓存命中: {cache_key}") # 从字典列表恢复对象列表 return [_dict_to_model(self.model, d) for d in cached_dicts] - + # 从数据库查询 async with get_db_session() as session: result = await session.execute(self._stmt) instances = list(result.scalars().all()) - + # 预加载所有列以避免detached对象的lazy loading问题 for instance in instances: for column in self.model.__table__.columns: @@ -215,23 +212,23 @@ class QueryBuilder(Generic[T]): getattr(instance, column.name) except Exception: pass - + # 转换为字典列表并写入缓存 if self._use_cache: instances_dicts = [_model_to_dict(inst) for inst in instances] cache = await get_cache() await cache.set(cache_key, instances_dicts) - + return instances - async def first(self) -> Optional[T]: + async def first(self) -> T | None: """获取第一个结果 - + Returns: 模型实例或None """ cache_key = ":".join(self._cache_key_parts) + ":first" - + # 尝试从缓存获取 (缓存的是字典) if self._use_cache: cache = await get_cache() @@ -240,12 +237,12 @@ class QueryBuilder(Generic[T]): logger.debug(f"缓存命中: {cache_key}") # 从字典恢复对象 return _dict_to_model(self.model, cached_dict) - + # 从数据库查询 async with get_db_session() as session: result = await session.execute(self._stmt) instance = result.scalars().first() - + # 预加载所有列以避免detached对象的lazy loading问题 if instance is not None: for column in self.model.__table__.columns: @@ -253,23 +250,23 @@ class QueryBuilder(Generic[T]): getattr(instance, column.name) except Exception: pass - + # 转换为字典并写入缓存 if instance is not None and self._use_cache: instance_dict = _model_to_dict(instance) cache = await get_cache() await cache.set(cache_key, instance_dict) - + return instance async def count(self) -> int: """统计数量 - + Returns: 记录数量 """ cache_key = ":".join(self._cache_key_parts) + ":count" - + # 尝试从缓存获取 if self._use_cache: cache = await get_cache() @@ -277,25 +274,25 @@ class QueryBuilder(Generic[T]): if cached is not None: logger.debug(f"缓存命中: {cache_key}") return cached - + # 构建count查询 count_stmt = select(func.count()).select_from(self._stmt.subquery()) - + # 从数据库查询 async with get_db_session() as session: result = await session.execute(count_stmt) count = result.scalar() or 0 - + # 写入缓存 if self._use_cache: cache = await get_cache() await cache.set(cache_key, count) - + return count async def exists(self) -> bool: """检查是否存在 - + Returns: 是否存在记录 """ @@ -308,38 +305,38 @@ class QueryBuilder(Generic[T]): page_size: int = 20, ) -> tuple[list[T], int]: """分页查询 - + Args: page: 页码(从1开始) page_size: 每页数量 - + Returns: (结果列表, 总数量) """ # 计算偏移量 offset = (page - 1) * page_size - + # 获取总数 total = await self.count() - + # 获取当前页数据 self._stmt = self._stmt.offset(offset).limit(page_size) self._cache_key_parts.append(f"page:{page}:{page_size}") - + items = await self.all() - + return items, total class AggregateQuery: """聚合查询 - + 提供聚合操作如sum、avg、max、min等 """ - def __init__(self, model: Type[T]): + def __init__(self, model: type[T]): """初始化聚合查询 - + Args: model: SQLAlchemy模型类 """ @@ -349,10 +346,10 @@ class AggregateQuery: def filter(self, **conditions: Any) -> "AggregateQuery": """添加过滤条件 - + Args: **conditions: 过滤条件 - + Returns: self,支持链式调用 """ @@ -364,85 +361,85 @@ class AggregateQuery: async def sum(self, field: str) -> float: """求和 - + Args: field: 字段名 - + Returns: 总和 """ if not hasattr(self.model, field): raise ValueError(f"字段 {field} 不存在") - + async with get_db_session() as session: stmt = select(func.sum(getattr(self.model, field))) - + if self._conditions: stmt = stmt.where(and_(*self._conditions)) - + result = await session.execute(stmt) return result.scalar() or 0 async def avg(self, field: str) -> float: """求平均值 - + Args: field: 字段名 - + Returns: 平均值 """ if not hasattr(self.model, field): raise ValueError(f"字段 {field} 不存在") - + async with get_db_session() as session: stmt = select(func.avg(getattr(self.model, field))) - + if self._conditions: stmt = stmt.where(and_(*self._conditions)) - + result = await session.execute(stmt) return result.scalar() or 0 async def max(self, field: str) -> Any: """求最大值 - + Args: field: 字段名 - + Returns: 最大值 """ if not hasattr(self.model, field): raise ValueError(f"字段 {field} 不存在") - + async with get_db_session() as session: stmt = select(func.max(getattr(self.model, field))) - + if self._conditions: stmt = stmt.where(and_(*self._conditions)) - + result = await session.execute(stmt) return result.scalar() async def min(self, field: str) -> Any: """求最小值 - + Args: field: 字段名 - + Returns: 最小值 """ if not hasattr(self.model, field): raise ValueError(f"字段 {field} 不存在") - + async with get_db_session() as session: stmt = select(func.min(getattr(self.model, field))) - + if self._conditions: stmt = stmt.where(and_(*self._conditions)) - + result = await session.execute(stmt) return result.scalar() @@ -451,31 +448,31 @@ class AggregateQuery: *fields: str, ) -> list[tuple[Any, ...]]: """分组统计 - + Args: *fields: 分组字段 - + Returns: [(分组值1, 分组值2, ..., 数量), ...] """ if not fields: raise ValueError("至少需要一个分组字段") - + group_columns = [] for field_name in fields: if hasattr(self.model, field_name): group_columns.append(getattr(self.model, field_name)) - + if not group_columns: return [] - + async with get_db_session() as session: stmt = select(*group_columns, func.count(self.model.id)) - + if self._conditions: stmt = stmt.where(and_(*self._conditions)) - + stmt = stmt.group_by(*group_columns) - + result = await session.execute(stmt) return [tuple(row) for row in result.all()] diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 0c656f56a..539fff829 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -9,9 +9,9 @@ import orjson from json_repair import repair_json from sqlalchemy import select +from src.common.database.api.crud import CRUDBase from src.common.database.compatibility import get_db_session from src.common.database.core.models import PersonInfo -from src.common.database.api.crud import CRUDBase from src.common.database.utils.decorators import cached from src.common.logger import get_logger from src.config.config import global_config, model_config @@ -307,18 +307,18 @@ class PersonInfoManager: crud = CRUDBase(PersonInfo) record = await crud.get_by(person_id=p_id) query_time = time.time() - + if record: # 更新记录 await crud.update(record.id, {f_name: val_to_set}) save_time = time.time() total_time = save_time - start_time - + if total_time > 0.5: logger.warning( f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}" ) - + # 使缓存失效 from src.common.database.optimization.cache_manager import get_cache from src.common.database.utils.decorators import generate_cache_key @@ -327,7 +327,7 @@ class PersonInfoManager: await cache.delete(generate_cache_key("person_value", p_id, f_name)) await cache.delete(generate_cache_key("person_values", p_id)) await cache.delete(generate_cache_key("person_has_field", p_id, f_name)) - + return True, False else: total_time = time.time() - start_time @@ -339,7 +339,7 @@ class PersonInfoManager: logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}") raise - found, needs_creation = await _db_update_async(person_id, field_name, processed_value) + _found, needs_creation = await _db_update_async(person_id, field_name, processed_value) if needs_creation: logger.info(f"{person_id} 不存在,将新建。") @@ -538,7 +538,7 @@ class PersonInfoManager: record = await crud.get_by(person_id=p_id) if record: await crud.delete(record.id) - + # 注意: 删除操作很少发生,缓存会在TTL过期后自动清除 # 无法从person_id反向得到platform和user_id,因此无法精确清除缓存 # 删除后的查询仍会返回正确结果(None/False) @@ -658,7 +658,7 @@ class PersonInfoManager: try: value = getattr(record, f_name, None) if value is not None and way(value): - person_id_value = getattr(record, 'person_id', None) + person_id_value = getattr(record, "person_id", None) if person_id_value: found_results[person_id_value] = value except Exception as e: @@ -690,7 +690,7 @@ class PersonInfoManager: """原子性的获取或创建操作""" # 使用CRUD进行获取或创建 crud = CRUDBase(PersonInfo) - + # 首先尝试获取现有记录 record = await crud.get_by(person_id=p_id) if record: @@ -736,7 +736,7 @@ class PersonInfoManager: model_fields = [column.name for column in PersonInfo.__table__.columns] filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} - record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data) + _record, was_created = await _db_get_or_create_async(person_id, filtered_initial_data) if was_created: logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录。") diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 8942322db..82db6911f 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -186,7 +186,7 @@ class RelationshipFetcher: # 查询用户关系数据 user_id = str(await person_info_manager.get_value(person_id, "user_id")) platform = str(await person_info_manager.get_value(person_id, "platform")) - + # 使用优化后的API(带缓存) relationship = await get_user_relationship( platform=platform, @@ -261,7 +261,7 @@ class RelationshipFetcher: # 使用优化后的API(带缓存) # 从stream_id解析platform,或使用默认值 platform = stream_id.split("_")[0] if "_" in stream_id else "unknown" - + stream, _ = await get_or_create_chat_stream( stream_id=stream_id, platform=platform,