"""Group information storage backed by the Peewee ``GroupInfo`` model.

Summary of the ``GroupInfoManager`` API:

1. ``get_group_id``       - build the MD5-hashed unique ``group_id`` from a
                            platform name and a group number.
2. ``create_group_info``  - create a new group record (defaults are merged in).
3. ``update_one_field``   - update one field; the record is created when it
                            does not exist yet.
4. ``del_one_document``   - delete the record with the given ``group_id``.
5. ``get_value``          - read one field (stored value or default).
6. ``get_values``         - read several fields at once; fields without a
                            stored value fall back to their default.
7. ``add_member``         - add (or replace) a group member.
8. ``remove_member``      - remove a group member.
9. ``get_member_list``    - return the list of group members.
"""

import asyncio
import copy
import datetime
import hashlib
import json
import time
from typing import Dict, List, Optional, Union

from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import GroupInfo

logger = get_logger("group_info")

# Fields that are stored in the database as JSON text.
JSON_SERIALIZED_FIELDS = ["member_list", "topic"]

# Default value of every known field; always copy before handing a value out.
group_info_default = {
    "group_id": None,
    "group_name": None,
    "platform": "unknown",
    "group_impression": None,
    "member_list": [],
    "topic": [],
    "create_time": None,
    "last_active": None,
    "member_count": 0,
}


def _serialize_json_fields(payload: dict) -> dict:
    """Serialize the JSON-backed fields of ``payload`` in place and return it.

    Lists and dicts are dumped to JSON text, ``None`` is stored as ``"[]"``;
    any other value (for example an already serialized string) is left alone.
    """
    for key in JSON_SERIALIZED_FIELDS:
        if key not in payload:
            continue
        current = payload[key]
        if isinstance(current, (list, dict)):
            payload[key] = json.dumps(current, ensure_ascii=False)
        elif current is None:
            payload[key] = json.dumps([], ensure_ascii=False)
    return payload


def _decode_json_field(raw, field_name: str, group_id: str):
    """Decode the stored value of a JSON-backed field.

    Strings are parsed as JSON (invalid JSON is logged and yields ``[]``),
    ``None`` yields ``[]`` and any other value is returned unchanged.
    """
    if isinstance(raw, str):
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            logger.warning(f"字段 {field_name} for {group_id} 包含无效JSON: {raw}. 返回默认值.")
            return []
    if raw is None:
        return []
    return raw


def _build_create_data(group_id: str, data: Optional[dict]) -> dict:
    """Build the keyword arguments for ``GroupInfo.create``.

    Starts from the defaults of every field the model defines, overrides them
    with the model fields found in ``data``, forces ``group_id`` to the given
    argument and serializes the JSON-backed fields.
    """
    model_fields = GroupInfo._meta.fields.keys()  # type: ignore
    final_data = {
        key: copy.deepcopy(default_value)
        for key, default_value in group_info_default.items()
        if key in model_fields
    }
    if data:
        final_data.update({key: value for key, value in data.items() if key in model_fields})
    # The argument always wins over whatever ``data`` or the defaults carried.
    final_data["group_id"] = group_id
    return _serialize_json_fields(final_data)


class GroupInfoManager:
    """Async facade over the ``GroupInfo`` table plus an in-memory name cache."""

    def __init__(self):
        # Maps group_id -> group_name for the records that have a name.
        self.group_name_list = {}
        try:
            db.connect(reuse_if_open=True)
            if hasattr(db, "execute_sql"):
                # SQLite tuning parameters.
                db.execute_sql("PRAGMA cache_size = -64000")  # 64MB cache
                db.execute_sql("PRAGMA temp_store = memory")  # temp storage in memory
                db.execute_sql("PRAGMA mmap_size = 268435456")  # 256MB memory map
            db.create_tables([GroupInfo], safe=True)
        except Exception as e:
            logger.error(f"数据库连接或 GroupInfo 表创建失败: {e}")

        # Preload every known group name into the cache.
        try:
            named_groups = GroupInfo.select(GroupInfo.group_id, GroupInfo.group_name).where(
                GroupInfo.group_name.is_null(False)
            )
            for record in named_groups:
                if record.group_name:
                    self.group_name_list[record.group_id] = record.group_name
            logger.debug(f"已加载 {len(self.group_name_list)} 个群组名称 (Peewee)")
        except Exception as e:
            logger.error(f"从 Peewee 加载 group_name_list 失败: {e}")

    @staticmethod
    def get_group_id(platform: str, group_number: Union[int, str]) -> str:
        """Return the unique group id: the MD5 hex digest of ``"<platform>_<group_number>"``.

        A ``None`` platform is treated as ``"unknown"``; when the platform
        contains ``"-"`` only the segment after the first dash is used.
        """
        if platform is None:
            platform = "unknown"
        elif "-" in platform:
            platform = platform.split("-")[1]
        key = "_".join([platform, str(group_number)])
        return hashlib.md5(key.encode()).hexdigest()

    async def is_group_known(self, platform: str, group_number: Union[int, str]) -> bool:
        """Return whether a record exists for the group; ``False`` on database errors."""
        group_id = self.get_group_id(platform, group_number)

        def _db_check_known_sync(g_id: str):
            return GroupInfo.get_or_none(GroupInfo.group_id == g_id) is not None

        try:
            return await asyncio.to_thread(_db_check_known_sync, group_id)
        except Exception as e:
            logger.error(f"检查群组 {group_id} 是否已知时出错 (Peewee): {e}")
            return False

    @staticmethod
    async def create_group_info(group_id: str, data: Optional[dict] = None):
        """Create a group record from the defaults overridden by ``data``.

        Nothing is created when ``group_id`` is empty. Creation failures are
        logged, not raised.
        """
        if not group_id:
            logger.debug("创建失败,group_id不存在")
            return

        final_data = _build_create_data(group_id, data)

        def _db_create_sync(g_data: dict):
            try:
                GroupInfo.create(**g_data)
                return True
            except Exception as e:
                logger.error(f"创建 GroupInfo 记录 {g_data.get('group_id')} 失败 (Peewee): {e}")
                return False

        await asyncio.to_thread(_db_create_sync, final_data)

    async def _safe_create_group_info(self, group_id: str, data: Optional[dict] = None):
        """Create a group record while tolerating concurrent creation.

        An already existing record, or a unique-constraint failure raised by a
        concurrent insert, is treated as success. Other failures are logged.
        """
        if not group_id:
            logger.debug("创建失败,group_id不存在")
            return

        final_data = _build_create_data(group_id, data)

        def _db_safe_create_sync(g_data: dict):
            try:
                existing = GroupInfo.get_or_none(GroupInfo.group_id == g_data["group_id"])
                if existing:
                    logger.debug(f"群组 {g_data['group_id']} 已存在,跳过创建")
                    return True
                GroupInfo.create(**g_data)
                return True
            except Exception as e:
                if "UNIQUE constraint failed" in str(e):
                    # Another task inserted the record between the check and the create.
                    logger.debug(f"检测到并发创建群组 {g_data.get('group_id')},跳过错误")
                    return True
                logger.error(f"创建 GroupInfo 记录 {g_data.get('group_id')} 失败 (Peewee): {e}")
                return False

        await asyncio.to_thread(_db_safe_create_sync, final_data)

    async def update_one_field(self, group_id: str, field_name: str, value, data: Optional[Dict] = None):
        """Update a single field, creating the record when it does not exist.

        Args:
            group_id: id of the record to update.
            field_name: name of a field defined on the ``GroupInfo`` model;
                unknown names are logged and ignored.
            value: new value. Lists/dicts (and ``None``) of JSON-backed fields
                are stored as JSON text.
            data: optional extra field values used only when the record has to
                be created. The dict is never modified.

        Raises:
            Exception: whatever the database raised while reading or saving
                the existing record (logged and re-raised).
        """
        if field_name not in GroupInfo._meta.fields:  # type: ignore
            logger.debug(f"更新'{field_name}'失败,未在 GroupInfo Peewee 模型中定义的字段。")
            return

        processed_value = value
        if field_name in JSON_SERIALIZED_FIELDS:
            if isinstance(value, (list, dict)):
                processed_value = json.dumps(value, ensure_ascii=False, indent=None)
            elif value is None:
                # JSON list fields store ``None`` as "[]".
                processed_value = json.dumps([], ensure_ascii=False, indent=None)

        def _db_update_sync(g_id: str, f_name: str, val_to_set):
            """Return ``(found, needs_creation)`` after trying an in-place update."""
            start_time = time.perf_counter()
            try:
                record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
                query_time = time.perf_counter()
                if record:
                    setattr(record, f_name, val_to_set)
                    record.save()
                    save_time = time.perf_counter()
                    total_time = save_time - start_time
                    if total_time > 0.5:  # only report operations slower than 500ms
                        logger.warning(
                            f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) group_id={g_id}, field={f_name}"
                        )
                    return True, False
                total_time = time.perf_counter() - start_time
                if total_time > 0.5:
                    logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 group_id={g_id}, field={f_name}")
                return False, True
            except Exception as e:
                total_time = time.perf_counter() - start_time
                logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
                raise

        found, needs_creation = await asyncio.to_thread(_db_update_sync, group_id, field_name, processed_value)

        if found and field_name == "group_name":
            # Keep the name cache in sync so lookups by name do not go stale.
            if value:
                self.group_name_list[group_id] = value
            else:
                self.group_name_list.pop(group_id, None)

        if needs_creation:
            logger.info(f"{group_id} 不存在,将新建。")
            # Work on a copy: the caller's dict must not be modified.
            creation_data = dict(data) if data is not None else {}
            # Pass the original value; the creation helper handles serialization.
            creation_data[field_name] = value
            await self._safe_create_group_info(group_id, creation_data)

    @staticmethod
    async def del_one_document(group_id: str):
        """Delete the record with the given ``group_id``; failures are only logged."""
        if not group_id:
            logger.debug("删除失败:group_id 不能为空")
            return

        def _db_delete_sync(g_id: str):
            try:
                return GroupInfo.delete().where(GroupInfo.group_id == g_id).execute()
            except Exception as e:
                logger.error(f"删除 GroupInfo {g_id} 失败 (Peewee): {e}")
                return 0

        deleted_count = await asyncio.to_thread(_db_delete_sync, group_id)

        if deleted_count > 0:
            logger.debug(f"删除成功:group_id={group_id} (Peewee)")
        else:
            logger.debug(f"删除失败:未找到 group_id={group_id} 或删除未影响行 (Peewee)")

    @staticmethod
    async def get_value(group_id: str, field_name: str):
        """Return the value of one field of a group.

        JSON-backed fields are decoded. When the record is missing or the
        stored value is ``None`` a copy of the field's default is returned;
        fields without a configured default yield ``None``. Database errors
        are logged and also fall back to the default.
        """
        has_default = field_name in group_info_default
        # Copy the default so callers can never mutate the shared module-level value.
        default_value_for_field = copy.deepcopy(group_info_default.get(field_name))
        if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
            default_value_for_field = []

        def _db_get_value_sync(g_id: str, f_name: str):
            record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
            if not record:
                return None
            val = getattr(record, f_name, None)
            if f_name in JSON_SERIALIZED_FIELDS:
                return _decode_json_field(val, f_name, g_id)
            return val

        try:
            value_from_db = await asyncio.to_thread(_db_get_value_sync, group_id, field_name)
            if value_from_db is not None:
                return value_from_db
            if has_default:
                return default_value_for_field
            logger.warning(f"字段 {field_name} 在 group_info_default 中未定义,且在数据库中未找到。")
            return None
        except Exception as e:
            logger.error(f"获取字段 {field_name} for {group_id} 时出错 (Peewee): {e}")
            return default_value_for_field if has_default else None

    @staticmethod
    async def get_values(group_id: str, field_names: list) -> dict:
        """Return several field values of a group as ``{field_name: value}``.

        Fields that are missing from the model, from the record, or whose
        stored value is ``None`` fall back to a copy of their default (``None``
        when no default is configured). JSON-backed fields are decoded, exactly
        as ``get_value`` does. An empty ``group_id`` yields ``{}``.
        """
        if not group_id:
            logger.debug("get_values获取失败:group_id不能为空")
            return {}

        def _db_get_record_sync(g_id: str):
            return GroupInfo.get_or_none(GroupInfo.group_id == g_id)

        record = await asyncio.to_thread(_db_get_record_sync, group_id)

        result = {}
        for field_name in field_names:
            if field_name not in GroupInfo._meta.fields:  # type: ignore
                if field_name in group_info_default:
                    result[field_name] = copy.deepcopy(group_info_default[field_name])
                    logger.debug(f"字段'{field_name}'不在Peewee模型中,使用默认配置值。")
                else:
                    logger.debug(f"get_values查询失败:字段'{field_name}'未在Peewee模型和默认配置中定义。")
                    result[field_name] = None
                continue

            value = getattr(record, field_name) if record else None
            if record and field_name in JSON_SERIALIZED_FIELDS:
                # Stored as JSON text; hand back the decoded structure.
                value = _decode_json_field(value, field_name, group_id)
            if value is None:
                value = copy.deepcopy(group_info_default.get(field_name))
            result[field_name] = value

        return result

    async def add_member(self, group_id: str, member_info: dict):
        """Add a member, replacing any existing member with the same ``user_id``.

        ``join_time`` is dropped and ``last_active_time`` defaults to now. The
        member list, the member count and the group's ``last_active`` are
        updated.
        """
        if not group_id or not member_info:
            logger.debug("添加成员失败:group_id或member_info不能为空")
            return

        # Normalize the member fields without touching the caller's dict.
        normalized_member = dict(member_info)
        normalized_member.pop("join_time", None)
        if "last_active_time" not in normalized_member:
            normalized_member["last_active_time"] = datetime.datetime.now().timestamp()

        member_id = normalized_member.get("user_id")
        if not member_id:
            logger.debug("添加成员失败:缺少 user_id")
            return

        current_members = await self.get_value(group_id, "member_list")
        if not isinstance(current_members, list):
            current_members = []

        # Drop any previous entry of the same user, then append the new one.
        current_members = [m for m in current_members if m.get("user_id") != member_id]
        current_members.append(normalized_member)

        await self.update_one_field(group_id, "member_list", current_members)
        await self.update_one_field(group_id, "member_count", len(current_members))
        await self.update_one_field(group_id, "last_active", datetime.datetime.now().timestamp())

        logger.info(f"群组 {group_id} 添加/更新成员 {normalized_member.get('nickname', member_id)} 成功")

    async def remove_member(self, group_id: str, user_id: str):
        """Remove the member with ``user_id``; nothing is written when absent."""
        if not group_id or not user_id:
            logger.debug("移除成员失败:group_id或user_id不能为空")
            return

        current_members = await self.get_value(group_id, "member_list")
        if not isinstance(current_members, list):
            logger.debug(f"群组 {group_id} 成员列表为空或格式错误")
            return

        original_count = len(current_members)
        current_members = [m for m in current_members if m.get("user_id") != user_id]
        new_count = len(current_members)

        if new_count < original_count:
            await self.update_one_field(group_id, "member_list", current_members)
            await self.update_one_field(group_id, "member_count", new_count)
            await self.update_one_field(group_id, "last_active", datetime.datetime.now().timestamp())
            logger.info(f"群组 {group_id} 移除成员 {user_id} 成功")
        else:
            logger.debug(f"群组 {group_id} 中未找到成员 {user_id}")

    async def get_member_list(self, group_id: str) -> List[dict]:
        """Return the member list of a group, or ``[]`` when unavailable."""
        if not group_id:
            logger.debug("获取成员列表失败:group_id不能为空")
            return []

        members = await self.get_value(group_id, "member_list")
        return members if isinstance(members, list) else []

    async def get_or_create_group(
        self, platform: str, group_number: int, group_name: Optional[str] = None
    ) -> str:
        """Return the ``group_id`` for ``platform``/``group_number``, creating the record if needed.

        A unique-constraint failure caused by a concurrent insert is resolved
        by re-reading the record; any other creation error is re-raised.
        """
        group_id = self.get_group_id(platform, group_number)

        def _db_get_or_create_sync(g_id: str, init_data: dict):
            """Return ``(record, created)`` for ``g_id``."""
            record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
            if record:
                return record, False

            try:
                GroupInfo.create(**init_data)
                return GroupInfo.get(GroupInfo.group_id == g_id), True
            except Exception as e:
                if "UNIQUE constraint failed" in str(e):
                    # Lost a creation race: fetch the record the other task made.
                    logger.debug(f"检测到并发创建群组 {g_id},获取现有记录")
                    record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
                    if record:
                        return record, False
                raise

        now = datetime.datetime.now().timestamp()
        initial_data = _serialize_json_fields(
            {
                "group_id": group_id,
                "platform": platform,
                "group_number": str(group_number),
                "group_name": group_name,
                "create_time": now,
                "last_active": now,
                "member_count": 0,
                "member_list": [],
                "group_info": {},
            }
        )

        # Only pass non-null values of fields the model actually defines.
        model_fields = GroupInfo._meta.fields.keys()  # type: ignore
        filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}

        _record, was_created = await asyncio.to_thread(_db_get_or_create_sync, group_id, filtered_initial_data)

        if was_created:
            logger.info(f"群组 {platform}:{group_number} (group_id: {group_id}) 不存在,将创建新记录 (Peewee)。")
            logger.info(f"已为 {group_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
        else:
            logger.debug(f"群组 {platform}:{group_number} (group_id: {group_id}) 已存在,返回现有记录。")

        return group_id

    async def get_group_info_by_name(self, group_name: str) -> Optional[dict]:
        """Look a group up by name and return its basic information, or ``None``.

        The in-memory name cache is consulted first; on a miss the database is
        queried and the cache is updated with the match.
        """
        if not group_name:
            logger.debug("get_group_info_by_name 获取失败:group_name 不能为空")
            return None

        found_group_id = next(
            (gid for gid, name_in_cache in self.group_name_list.items() if name_in_cache == group_name),
            None,
        )

        if not found_group_id:

            def _db_find_by_name_sync(g_name_to_find: str):
                return GroupInfo.get_or_none(GroupInfo.group_name == g_name_to_find)

            record = await asyncio.to_thread(_db_find_by_name_sync, group_name)
            if not record:
                logger.debug(f"数据库中也未找到名为 '{group_name}' 的群组 (Peewee)")
                return None

            found_group_id = record.group_id
            if self.group_name_list.get(found_group_id) != group_name:
                self.group_name_list[found_group_id] = group_name

        if found_group_id:
            required_fields = [
                "group_id",
                "platform",
                "group_number",
                "group_name",
                "group_impression",
                "short_impression",
                "member_count",
                "create_time",
                "last_active",
            ]
            valid_fields_to_get = [
                f
                for f in required_fields
                if f in GroupInfo._meta.fields or f in group_info_default  # type: ignore
            ]

            group_data = await self.get_values(found_group_id, valid_fields_to_get)

            if group_data:
                # Fields that could not be fetched are reported as None.
                return {key: group_data.get(key) for key in required_fields}

            logger.warning(f"找到了 group_id '{found_group_id}' 但 get_values 返回空 (Peewee)")
            return None

        logger.error(f"逻辑错误:未能为 '{group_name}' 确定 group_id (Peewee)")
        return None


# Lazily created process-wide manager instance.
group_info_manager = None


def get_group_info_manager():
    """Return the shared ``GroupInfoManager``, creating it on first use."""
    global group_info_manager
    if group_info_manager is None:
        group_info_manager = GroupInfoManager()
    return group_info_manager