ruff fix但指定了--unsafe-fixes
This commit is contained in:
@@ -30,7 +30,7 @@ class CacheManager:
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
cls._instance = super(CacheManager, cls).__new__(cls)
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, default_ttl: int = 3600):
|
||||
@@ -70,7 +70,7 @@ class CacheManager:
|
||||
return None
|
||||
|
||||
# 确保embedding_result是一维数组或列表
|
||||
if isinstance(embedding_result, (list, tuple, np.ndarray)):
|
||||
if isinstance(embedding_result, list | tuple | np.ndarray):
|
||||
# 转换为numpy数组进行处理
|
||||
embedding_array = np.array(embedding_result)
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class InterestMatchResult(BaseDataModel):
|
||||
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
|
||||
matched_keywords: list[str] = field(default_factory=list)
|
||||
|
||||
def add_match(self, tag_name: str, score: float, keywords: list[str] = None):
|
||||
def add_match(self, tag_name: str, score: float, keywords: list[str] | None = None):
|
||||
"""添加匹配结果"""
|
||||
self.matched_tags.append(tag_name)
|
||||
self.match_scores[tag_name] = score
|
||||
|
||||
@@ -220,7 +220,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None):
|
||||
def update_message_info(self, interest_value: float | None = None, actions: list | None = None, should_reply: bool | None = None):
|
||||
"""
|
||||
更新消息信息
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class ConnectionPoolManager:
|
||||
|
||||
async def _cleanup_expired_connections_locked(self):
|
||||
"""清理过期连接(需要在锁内调用)"""
|
||||
current_time = time.time()
|
||||
time.time()
|
||||
expired_connections = []
|
||||
|
||||
for connection_info in list(self._connections):
|
||||
|
||||
@@ -61,7 +61,7 @@ class DatabaseBatchScheduler:
|
||||
|
||||
# 调度控制
|
||||
self._scheduler_task: asyncio.Task | None = None
|
||||
self._is_running = bool = False
|
||||
self._is_running = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# 统计信息
|
||||
@@ -189,7 +189,7 @@ class DatabaseBatchScheduler:
|
||||
queue.clear()
|
||||
|
||||
# 批量执行各队列的操作
|
||||
for queue_key, operations in queues_copy.items():
|
||||
for operations in queues_copy.values():
|
||||
if operations:
|
||||
await self._execute_operations(list(operations))
|
||||
|
||||
@@ -270,7 +270,7 @@ class DatabaseBatchScheduler:
|
||||
query = select(ops[0].model_class)
|
||||
for field_name, value in conditions.items():
|
||||
model_attr = getattr(ops[0].model_class, field_name)
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
if isinstance(value, list | tuple | set):
|
||||
query = query.where(model_attr.in_(value))
|
||||
else:
|
||||
query = query.where(model_attr == value)
|
||||
@@ -336,7 +336,7 @@ class DatabaseBatchScheduler:
|
||||
stmt = update(op.model_class)
|
||||
for field_name, value in op.conditions.items():
|
||||
model_attr = getattr(op.model_class, field_name)
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
if isinstance(value, list | tuple | set):
|
||||
stmt = stmt.where(model_attr.in_(value))
|
||||
else:
|
||||
stmt = stmt.where(model_attr == value)
|
||||
@@ -366,7 +366,7 @@ class DatabaseBatchScheduler:
|
||||
stmt = delete(op.model_class)
|
||||
for field_name, value in op.conditions.items():
|
||||
model_attr = getattr(op.model_class, field_name)
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
if isinstance(value, list | tuple | set):
|
||||
stmt = stmt.where(model_attr.in_(value))
|
||||
else:
|
||||
stmt = stmt.where(model_attr == value)
|
||||
@@ -398,7 +398,7 @@ class DatabaseBatchScheduler:
|
||||
if field_name not in merged[condition_key]:
|
||||
merged[condition_key][field_name] = []
|
||||
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
if isinstance(value, list | tuple | set):
|
||||
merged[condition_key][field_name].extend(value)
|
||||
else:
|
||||
merged[condition_key][field_name].append(value)
|
||||
|
||||
@@ -915,7 +915,7 @@ class ModuleColoredConsoleRenderer:
|
||||
for key, value in event_dict.items():
|
||||
if key not in ("timestamp", "level", "logger_name", "event") and key not in ("color", "alias"):
|
||||
# 确保值也转换为字符串
|
||||
if isinstance(value, (dict, list)):
|
||||
if isinstance(value, dict | list):
|
||||
try:
|
||||
value_str = orjson.dumps(value).decode("utf-8")
|
||||
except (TypeError, ValueError):
|
||||
@@ -1213,7 +1213,7 @@ def shutdown_logging():
|
||||
|
||||
# 关闭所有其他logger的handler
|
||||
logger_dict = logging.getLogger().manager.loggerDict
|
||||
for _name, logger_obj in logger_dict.items():
|
||||
for logger_obj in logger_dict.values():
|
||||
if isinstance(logger_obj, logging.Logger):
|
||||
for handler in logger_obj.handlers[:]:
|
||||
if hasattr(handler, "close"):
|
||||
|
||||
@@ -24,7 +24,7 @@ class ChromaDBImpl(VectorDBBase):
|
||||
if not cls._instance:
|
||||
with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = super(ChromaDBImpl, cls).__new__(cls)
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, path: str = "data/chroma_db", **kwargs: Any):
|
||||
|
||||
Reference in New Issue
Block a user