Files
Mofox-Core/src/person_info/group_info.py
2025-08-11 22:53:00 +08:00

558 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import hashlib
import datetime
import asyncio
import json
from typing import Dict, Union, Optional, List
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import GroupInfo
# GroupInfoManager method overview:
#  1. get_group_id           - unique group_id = MD5 hex digest of "<platform>_<group_number>"
#  2. is_group_known         - whether a record exists for (platform, group_number)
#  3. create_group_info      - create a record; unspecified fields come from group_info_default
#  4. update_one_field       - set a single field, creating the record if it does not exist
#  5. del_one_document       - delete the record with the given group_id
#  6. get_value              - read one field (stored value, or the default when missing)
#  7. get_values             - read several fields; missing values fall back to defaults,
#                              and {} is returned only when group_id is empty
#  8. add_member             - add or replace a member (matched by user_id)
#  9. remove_member          - remove a member by user_id
# 10. get_member_list        - return the member list
# 11. get_or_create_group    - resolve the group_id, creating the record if needed
# 12. get_group_info_by_name - look up a group's basic info by its name
logger = get_logger("group_info")

# Fields stored in the database as JSON text (lists/dicts are written with json.dumps).
JSON_SERIALIZED_FIELDS = ["member_list", "topic"]

# Default value for each group field. Used to fill new records and as the
# fallback when a field has no stored value.
group_info_default = {
    "group_id": None,
    "group_name": None,
    "platform": "unknown",
    "group_impression": None,
    "member_list": [],
    "topic": [],
    "create_time": None,
    "last_active": None,
    "member_count": 0,
}
class GroupInfoManager:
def __init__(self):
self.group_name_list = {}
try:
db.connect(reuse_if_open=True)
# 设置连接池参数
if hasattr(db, "execute_sql"):
# 设置SQLite优化参数
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
db.create_tables([GroupInfo], safe=True)
except Exception as e:
logger.error(f"数据库连接或 GroupInfo 表创建失败: {e}")
# 初始化时读取所有group_name
try:
for record in GroupInfo.select(GroupInfo.group_id, GroupInfo.group_name).where(
GroupInfo.group_name.is_null(False)
):
if record.group_name:
self.group_name_list[record.group_id] = record.group_name
logger.debug(f"已加载 {len(self.group_name_list)} 个群组名称 (Peewee)")
except Exception as e:
logger.error(f"从 Peewee 加载 group_name_list 失败: {e}")
@staticmethod
def get_group_id(platform: str, group_number: Union[int, str]) -> str:
"""获取群组唯一id"""
# 添加空值检查,防止 platform 为 None 时出错
if platform is None:
platform = "unknown"
elif "-" in platform:
platform = platform.split("-")[1]
components = [platform, str(group_number)]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
async def is_group_known(self, platform: str, group_number: int):
"""判断是否知道某个群组"""
group_id = self.get_group_id(platform, group_number)
def _db_check_known_sync(g_id: str):
return GroupInfo.get_or_none(GroupInfo.group_id == g_id) is not None
try:
return await asyncio.to_thread(_db_check_known_sync, group_id)
except Exception as e:
logger.error(f"检查群组 {group_id} 是否已知时出错 (Peewee): {e}")
return False
@staticmethod
async def create_group_info(group_id: str, data: Optional[dict] = None):
"""创建一个群组信息项"""
if not group_id:
logger.debug("创建失败group_id不存在")
return
_group_info_default = copy.deepcopy(group_info_default)
model_fields = GroupInfo._meta.fields.keys() # type: ignore
final_data = {"group_id": group_id}
# Start with defaults for all model fields
for key, default_value in _group_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
if key in model_fields:
final_data[key] = value
# Ensure group_id is correctly set from the argument
final_data["group_id"] = group_id
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
def _db_create_sync(g_data: dict):
try:
GroupInfo.create(**g_data)
return True
except Exception as e:
logger.error(f"创建 GroupInfo 记录 {g_data.get('group_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_create_sync, final_data)
async def _safe_create_group_info(self, group_id: str, data: Optional[dict] = None):
"""安全地创建群组信息,处理竞态条件"""
if not group_id:
logger.debug("创建失败group_id不存在")
return
_group_info_default = copy.deepcopy(group_info_default)
model_fields = GroupInfo._meta.fields.keys() # type: ignore
final_data = {"group_id": group_id}
# Start with defaults for all model fields
for key, default_value in _group_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
if key in model_fields:
final_data[key] = value
# Ensure group_id is correctly set from the argument
final_data["group_id"] = group_id
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
def _db_safe_create_sync(g_data: dict):
try:
# 首先检查是否已存在
existing = GroupInfo.get_or_none(GroupInfo.group_id == g_data["group_id"])
if existing:
logger.debug(f"群组 {g_data['group_id']} 已存在,跳过创建")
return True
# 尝试创建
GroupInfo.create(**g_data)
return True
except Exception as e:
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建群组 {g_data.get('group_id')},跳过错误")
return True # 其他协程已创建,视为成功
else:
logger.error(f"创建 GroupInfo 记录 {g_data.get('group_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_safe_create_sync, final_data)
async def update_one_field(self, group_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全"""
if field_name not in GroupInfo._meta.fields: # type: ignore
logger.debug(f"更新'{field_name}'失败,未在 GroupInfo Peewee 模型中定义的字段。")
return
processed_value = value
if field_name in JSON_SERIALIZED_FIELDS:
if isinstance(value, (list, dict)):
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
elif value is None: # Store None as "[]" for JSON list fields
processed_value = json.dumps([], ensure_ascii=False, indent=None)
def _db_update_sync(g_id: str, f_name: str, val_to_set):
import time
start_time = time.time()
try:
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
query_time = time.time()
if record:
setattr(record, f_name, val_to_set)
record.save()
save_time = time.time()
total_time = save_time - start_time
if total_time > 0.5: # 如果超过500ms就记录日志
logger.warning(
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) group_id={g_id}, field={f_name}"
)
return True, False # Found and updated, no creation needed
else:
total_time = time.time() - start_time
if total_time > 0.5:
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 group_id={g_id}, field={f_name}")
return False, True # Not found, needs creation
except Exception as e:
total_time = time.time() - start_time
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
raise
found, needs_creation = await asyncio.to_thread(_db_update_sync, group_id, field_name, processed_value)
if needs_creation:
logger.info(f"{group_id} 不存在,将新建。")
creation_data = data if data is not None else {}
# Ensure platform and group_number are present for context if available from 'data'
# but primarily, set the field that triggered the update.
# The create_group_info will handle defaults and serialization.
creation_data[field_name] = value # Pass original value to create_group_info
# Ensure platform and group_number are in creation_data if available,
# otherwise create_group_info will use defaults.
if data and "platform" in data:
creation_data["platform"] = data["platform"]
if data and "group_number" in data:
creation_data["group_number"] = data["group_number"]
# 使用安全的创建方法,处理竞态条件
await self._safe_create_group_info(group_id, creation_data)
@staticmethod
async def del_one_document(group_id: str):
"""删除指定 group_id 的文档"""
if not group_id:
logger.debug("删除失败group_id 不能为空")
return
def _db_delete_sync(g_id: str):
try:
query = GroupInfo.delete().where(GroupInfo.group_id == g_id)
deleted_count = query.execute()
return deleted_count
except Exception as e:
logger.error(f"删除 GroupInfo {g_id} 失败 (Peewee): {e}")
return 0
deleted_count = await asyncio.to_thread(_db_delete_sync, group_id)
if deleted_count > 0:
logger.debug(f"删除成功group_id={group_id} (Peewee)")
else:
logger.debug(f"删除失败:未找到 group_id={group_id} 或删除未影响行 (Peewee)")
@staticmethod
async def get_value(group_id: str, field_name: str):
"""获取指定群组指定字段的值"""
default_value_for_field = group_info_default.get(field_name)
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
def _db_get_value_sync(g_id: str, f_name: str):
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
if record:
val = getattr(record, f_name, None)
if f_name in JSON_SERIALIZED_FIELDS:
if isinstance(val, str):
try:
return json.loads(val)
except json.JSONDecodeError:
logger.warning(f"字段 {f_name} for {g_id} 包含无效JSON: {val}. 返回默认值.")
return [] # Default for JSON fields on error
elif val is None: # Field exists in DB but is None
return [] # Default for JSON fields
# If val is already a list/dict (e.g. if somehow set without serialization)
return val # Should ideally not happen if update_one_field is always used
return val
return None # Record not found
try:
value_from_db = await asyncio.to_thread(_db_get_value_sync, group_id, field_name)
if value_from_db is not None:
return value_from_db
if field_name in group_info_default:
return default_value_for_field
logger.warning(f"字段 {field_name} 在 group_info_default 中未定义,且在数据库中未找到。")
return None # Ultimate fallback
except Exception as e:
logger.error(f"获取字段 {field_name} for {group_id} 时出错 (Peewee): {e}")
# Fallback to default in case of any error during DB access
return default_value_for_field if field_name in group_info_default else None
@staticmethod
async def get_values(group_id: str, field_names: list) -> dict:
"""获取指定group_id文档的多个字段值若不存在该字段则返回该字段的全局默认值"""
if not group_id:
logger.debug("get_values获取失败group_id不能为空")
return {}
result = {}
def _db_get_record_sync(g_id: str):
return GroupInfo.get_or_none(GroupInfo.group_id == g_id)
record = await asyncio.to_thread(_db_get_record_sync, group_id)
for field_name in field_names:
if field_name not in GroupInfo._meta.fields: # type: ignore
if field_name in group_info_default:
result[field_name] = copy.deepcopy(group_info_default[field_name])
logger.debug(f"字段'{field_name}'不在Peewee模型中使用默认配置值。")
else:
logger.debug(f"get_values查询失败字段'{field_name}'未在Peewee模型和默认配置中定义。")
result[field_name] = None
continue
if record:
value = getattr(record, field_name)
if value is not None:
result[field_name] = value
else:
result[field_name] = copy.deepcopy(group_info_default.get(field_name))
else:
result[field_name] = copy.deepcopy(group_info_default.get(field_name))
return result
async def add_member(self, group_id: str, member_info: dict):
"""添加群成员(使用 last_active_time不使用 join_time"""
if not group_id or not member_info:
logger.debug("添加成员失败group_id或member_info不能为空")
return
# 规范化成员字段
normalized_member = dict(member_info)
normalized_member.pop("join_time", None)
if "last_active_time" not in normalized_member:
normalized_member["last_active_time"] = datetime.datetime.now().timestamp()
member_id = normalized_member.get("user_id")
if not member_id:
logger.debug("添加成员失败:缺少 user_id")
return
# 获取当前成员列表
current_members = await self.get_value(group_id, "member_list")
if not isinstance(current_members, list):
current_members = []
# 移除已存在的同 user_id 成员
current_members = [m for m in current_members if m.get("user_id") != member_id]
# 添加新成员
current_members.append(normalized_member)
# 更新成员列表和成员数量
await self.update_one_field(group_id, "member_list", current_members)
await self.update_one_field(group_id, "member_count", len(current_members))
await self.update_one_field(group_id, "last_active", datetime.datetime.now().timestamp())
logger.info(f"群组 {group_id} 添加/更新成员 {normalized_member.get('nickname', member_id)} 成功")
async def remove_member(self, group_id: str, user_id: str):
"""移除群成员"""
if not group_id or not user_id:
logger.debug("移除成员失败group_id或user_id不能为空")
return
# 获取当前成员列表
current_members = await self.get_value(group_id, "member_list")
if not isinstance(current_members, list):
logger.debug(f"群组 {group_id} 成员列表为空或格式错误")
return
# 移除指定成员
original_count = len(current_members)
current_members = [m for m in current_members if m.get("user_id") != user_id]
new_count = len(current_members)
if new_count < original_count:
# 更新成员列表和成员数量
await self.update_one_field(group_id, "member_list", current_members)
await self.update_one_field(group_id, "member_count", new_count)
await self.update_one_field(group_id, "last_active", datetime.datetime.now().timestamp())
logger.info(f"群组 {group_id} 移除成员 {user_id} 成功")
else:
logger.debug(f"群组 {group_id} 中未找到成员 {user_id}")
async def get_member_list(self, group_id: str) -> List[dict]:
"""获取群成员列表"""
if not group_id:
logger.debug("获取成员列表失败group_id不能为空")
return []
members = await self.get_value(group_id, "member_list")
if isinstance(members, list):
return members
return []
async def get_or_create_group(
self, platform: str, group_number: int, group_name: str = None
) -> str:
"""
根据 platform 和 group_number 获取 group_id。
如果对应的群组不存在,则使用提供的信息创建新群组。
使用try-except处理竞态条件避免重复创建错误。
"""
group_id = self.get_group_id(platform, group_number)
def _db_get_or_create_sync(g_id: str, init_data: dict):
"""原子性的获取或创建操作"""
# 首先尝试获取现有记录
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
if record:
return record, False # 记录存在,未创建
# 记录不存在,尝试创建
try:
GroupInfo.create(**init_data)
return GroupInfo.get(GroupInfo.group_id == g_id), True # 创建成功
except Exception as e:
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建群组 {g_id},获取现有记录")
record = GroupInfo.get_or_none(GroupInfo.group_id == g_id)
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常
raise e
initial_data = {
"group_id": group_id,
"platform": platform,
"group_number": str(group_number),
"group_name": group_name,
"create_time": datetime.datetime.now().timestamp(),
"last_active": datetime.datetime.now().timestamp(),
"member_count": 0,
"member_list": [],
"group_info": {},
}
# 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
if isinstance(initial_data[key], (list, dict)):
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
elif initial_data[key] is None:
initial_data[key] = json.dumps([], ensure_ascii=False)
model_fields = GroupInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, group_id, filtered_initial_data)
if was_created:
logger.info(f"群组 {platform}:{group_number} (group_id: {group_id}) 不存在,将创建新记录 (Peewee)。")
logger.info(f"已为 {group_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
else:
logger.debug(f"群组 {platform}:{group_number} (group_id: {group_id}) 已存在,返回现有记录。")
return group_id
async def get_group_info_by_name(self, group_name: str) -> dict | None:
"""根据 group_name 查找群组并返回基本信息 (如果找到)"""
if not group_name:
logger.debug("get_group_info_by_name 获取失败group_name 不能为空")
return None
found_group_id = None
for gid, name_in_cache in self.group_name_list.items():
if name_in_cache == group_name:
found_group_id = gid
break
if not found_group_id:
def _db_find_by_name_sync(g_name_to_find: str):
return GroupInfo.get_or_none(GroupInfo.group_name == g_name_to_find)
record = await asyncio.to_thread(_db_find_by_name_sync, group_name)
if record:
found_group_id = record.group_id
if (
found_group_id not in self.group_name_list
or self.group_name_list[found_group_id] != group_name
):
self.group_name_list[found_group_id] = group_name
else:
logger.debug(f"数据库中也未找到名为 '{group_name}' 的群组 (Peewee)")
return None
if found_group_id:
required_fields = [
"group_id",
"platform",
"group_number",
"group_name",
"group_impression",
"short_impression",
"member_count",
"create_time",
"last_active",
]
valid_fields_to_get = [
f
for f in required_fields
if f in GroupInfo._meta.fields or f in group_info_default # type: ignore
]
group_data = await self.get_values(found_group_id, valid_fields_to_get)
if group_data:
final_result = {key: group_data.get(key) for key in required_fields}
return final_result
else:
logger.warning(f"找到了 group_id '{found_group_id}' 但 get_values 返回空 (Peewee)")
return None
logger.error(f"逻辑错误:未能为 '{group_name}' 确定 group_id (Peewee)")
return None
# Lazily-created process-wide GroupInfoManager instance.
group_info_manager = None


def get_group_info_manager():
    """Return the shared GroupInfoManager, constructing it on first call."""
    global group_info_manager
    if group_info_manager is not None:
        return group_info_manager
    group_info_manager = GroupInfoManager()
    return group_info_manager