This commit is contained in:
SengokuCola
2025-07-17 00:57:07 +08:00
37 changed files with 644 additions and 406 deletions

View File

@@ -107,11 +107,12 @@ class ExpressionLearner:
last_active_time = expr.get("last_active_time", time.time())
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == type_str) &
(Expression.situation == situation) &
(Expression.style == style_val)
(Expression.chat_id == chat_id)
& (Expression.type == type_str)
& (Expression.situation == situation)
& (Expression.style == style_val)
)
if query.exists():
expr_obj = query.get()
@@ -125,7 +126,7 @@ class ExpressionLearner:
count=count,
last_active_time=last_active_time,
chat_id=chat_id,
type=type_str
type=type_str,
)
logger.info(f"已迁移 {expr_file} 到数据库")
except Exception as e:
@@ -149,24 +150,28 @@ class ExpressionLearner:
# 直接从数据库查询
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
for expr in style_query:
learnt_style_expressions.append({
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "style"
})
learnt_style_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "style",
}
)
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
for expr in grammar_query:
learnt_grammar_expressions.append({
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "grammar"
})
learnt_grammar_expressions.append(
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": chat_id,
"type": "grammar",
}
)
return learnt_style_expressions, learnt_grammar_expressions
def is_similar(self, s1: str, s2: str) -> bool:
@@ -213,14 +218,16 @@ class ExpressionLearner:
logger.error(f"全局衰减{type}表达方式失败: {e}")
continue
learnt_style: Optional[List[Tuple[str, str, str]]] = []
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
# 学习新的表达方式(这里会进行局部衰减)
for _ in range(3):
learnt_style: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="style", num=25)
learnt_style = await self.learn_and_store(type="style", num=25)
if not learnt_style:
return [], []
for _ in range(1):
learnt_grammar: Optional[List[Tuple[str, str, str]]] = await self.learn_and_store(type="grammar", num=10)
learnt_grammar = await self.learn_and_store(type="grammar", num=10)
if not learnt_grammar:
return [], []
@@ -321,10 +328,10 @@ class ExpressionLearner:
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == type) &
(Expression.situation == new_expr["situation"]) &
(Expression.style == new_expr["style"])
(Expression.chat_id == chat_id)
& (Expression.type == type)
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
)
if query.exists():
expr_obj = query.get()
@@ -342,13 +349,17 @@ class ExpressionLearner:
count=1,
last_active_time=current_time,
chat_id=chat_id,
type=type
type=type,
)
# 限制最大数量
exprs = list(Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == type)).order_by(Expression.count.asc()))
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == type))
.order_by(Expression.count.asc())
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[:len(exprs) - MAX_EXPRESSION_COUNT]:
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
return learnt_expressions

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass
import json
import os
import math
import asyncio
from typing import Dict, List, Tuple
import numpy as np
@@ -99,7 +100,30 @@ class EmbeddingStore:
self.idx2hash = None
def _get_embedding(self, s: str) -> List[float]:
return get_embedding(s)
"""获取字符串的嵌入向量,处理异步调用"""
try:
# 尝试获取当前事件循环
asyncio.get_running_loop()
# 如果在事件循环中,使用线程池执行
import concurrent.futures
def run_in_thread():
return asyncio.run(get_embedding(s))
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
result = future.result()
if result is None:
logger.error(f"获取嵌入失败: {s}")
return []
return result
except RuntimeError:
# 没有运行的事件循环,直接运行
result = asyncio.run(get_embedding(s))
if result is None:
logger.error(f"获取嵌入失败: {s}")
return []
return result
def get_test_file_path(self):
return EMBEDDING_TEST_FILE

View File

@@ -1,3 +1,4 @@
import asyncio
import json
import time
from typing import List, Union
@@ -7,8 +8,12 @@ from . import prompt_template
from .knowledge_lib import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json
def _extract_json_from_text(text: str) -> dict:
def _extract_json_from_text(text: str):
"""从文本中提取JSON数据的高容错方法"""
if text is None:
logger.error("输入文本为None")
return []
try:
fixed_json = repair_json(text)
if isinstance(fixed_json, str):
@@ -16,23 +21,66 @@ def _extract_json_from_text(text: str) -> dict:
else:
parsed_json = fixed_json
if isinstance(parsed_json, list) and parsed_json:
parsed_json = parsed_json[0]
if isinstance(parsed_json, dict):
# 如果是列表,直接返回
if isinstance(parsed_json, list):
return parsed_json
# 如果是字典且只有一个项目,可能包装了列表
if isinstance(parsed_json, dict):
# 如果字典只有一个键,并且值是列表,返回那个列表
if len(parsed_json) == 1:
value = list(parsed_json.values())[0]
if isinstance(value, list):
return value
return parsed_json
# 其他情况,尝试转换为列表
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
return []
except Exception as e:
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
return []
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
# 使用 asyncio.run 来运行异步方法
try:
# 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe(
llm_req.generate_response_async(entity_extract_context), loop
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run(
llm_req.generate_response_async(entity_extract_context)
)
# 添加调试日志
logger.debug(f"LLM返回的原始响应: {response}")
entity_extract_result = _extract_json_from_text(response)
# 尝试load JSON数据
json.loads(entity_extract_result)
# 检查返回的是否为有效的实体列表
if not isinstance(entity_extract_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表
if isinstance(entity_extract_result, dict):
# 尝试常见的键名
for key in ['entities', 'result', 'data', 'items']:
if key in entity_extract_result and isinstance(entity_extract_result[key], list):
entity_extract_result = entity_extract_result[key]
break
else:
# 如果找不到合适的列表,抛出异常
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
else:
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
# 过滤无效实体
entity_extract_result = [
entity
for entity in entity_extract_result
@@ -50,16 +98,47 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False)
)
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
# 使用 asyncio.run 来运行异步方法
try:
# 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe(
llm_req.generate_response_async(rdf_extract_context), loop
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run(
llm_req.generate_response_async(rdf_extract_context)
)
entity_extract_result = _extract_json_from_text(response)
# 尝试load JSON数据
json.loads(entity_extract_result)
for triple in entity_extract_result:
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
# 添加调试日志
logger.debug(f"RDF LLM返回的原始响应: {response}")
rdf_triple_result = _extract_json_from_text(response)
# 检查返回的是否为有效的三元组列表
if not isinstance(rdf_triple_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表
if isinstance(rdf_triple_result, dict):
# 尝试常见的键名
for key in ['triples', 'result', 'data', 'items']:
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
rdf_triple_result = rdf_triple_result[key]
break
else:
# 如果找不到合适的列表,抛出异常
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
else:
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
# 验证三元组格式
for triple in rdf_triple_result:
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
raise Exception("RDF提取结果格式错误")
return entity_extract_result
return rdf_triple_result
def info_extract_from_str(

View File

@@ -184,10 +184,10 @@ class KGManager:
progress.update(task, advance=1)
continue
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
assert isinstance(ent, EmbeddingStoreItem)
if ent is None:
progress.update(task, advance=1)
continue
assert isinstance(ent, EmbeddingStoreItem)
# 查询相似实体
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
@@ -265,7 +265,10 @@ class KGManager:
if node_hash not in existed_nodes:
if node_hash.startswith(local_storage['ent_namespace']):
# 新增实体节点
node = embedding_manager.entities_embedding_store.store[node_hash]
node = embedding_manager.entities_embedding_store.store.get(node_hash)
if node is None:
logger.warning(f"实体节点 {node_hash} 在嵌入库中不存在,跳过")
continue
assert isinstance(node, EmbeddingStoreItem)
node_item = self.graph[node_hash]
node_item["content"] = node.str
@@ -274,7 +277,10 @@ class KGManager:
self.graph.update_node(node_item)
elif node_hash.startswith(local_storage['pg_namespace']):
# 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
if node is None:
logger.warning(f"段落节点 {node_hash} 在嵌入库中不存在,跳过")
continue
assert isinstance(node, EmbeddingStoreItem)
content = node.str.replace("\n", " ")
node_item = self.graph[node_hash]

View File

@@ -11,12 +11,14 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统
"""
def build_entity_extract_context(paragraph: str) -> list[LLMMessage]:
messages = [
LLMMessage("system", entity_extract_system_prompt).to_dict(),
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
]
return messages
def build_entity_extract_context(paragraph: str) -> str:
"""构建实体提取的完整提示文本"""
return f"""{entity_extract_system_prompt}
段落:
```
{paragraph}
```"""
rdf_triple_extract_system_prompt = """你是一个性能优异的RDF资源描述框架由节点和边组成节点表示实体/资源、属性边则表示了实体和实体之间的关系以及实体和属性的关系。构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
@@ -36,12 +38,19 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF资源描
"""
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]:
messages = [
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
]
return messages
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> str:
"""构建RDF三元组提取的完整提示文本"""
return f"""{rdf_triple_extract_system_prompt}
段落:
```
{paragraph}
```
实体列表:
```
{entities}
```"""
qa_system_prompt = """

View File

@@ -9,51 +9,49 @@ from src.common.logger import get_logger
import traceback
from src.config.config import global_config
from src.common.database.database_model import Memory # Peewee Models导入
from src.common.database.database_model import Memory # Peewee Models导入
logger = get_logger(__name__)
class MemoryItem:
def __init__(self,memory_id:str,chat_id:str,memory_text:str,keywords:list[str]):
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
self.memory_id = memory_id
self.chat_id = chat_id
self.memory_text:str = memory_text
self.keywords:list[str] = keywords
self.create_time:float = time.time()
self.last_view_time:float = time.time()
self.memory_text: str = memory_text
self.keywords: list[str] = keywords
self.create_time: float = time.time()
self.last_view_time: float = time.time()
class MemoryManager:
def __init__(self):
# self.memory_items:list[MemoryItem] = []
pass
class InstantMemory:
def __init__(self,chat_id):
self.chat_id = chat_id
def __init__(self, chat_id):
self.chat_id = chat_id
self.last_view_time = time.time()
self.summary_model = LLMRequest(
model=global_config.model.memory,
temperature=0.5,
request_type="memory.summary",
)
async def if_need_build(self,text):
async def if_need_build(self, text):
prompt = f"""
请判断以下内容中是否有值得记忆的信息如果有请输出1否则输出0
{text}
请只输出1或0就好
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if "1" in response:
return True
else:
@@ -61,8 +59,8 @@ class InstantMemory:
except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False
async def build_memory(self,text):
async def build_memory(self, text):
prompt = f"""
以下内容中存在值得记忆的信息,请你从中总结出一段值得记忆的信息,并输出
{text}
@@ -73,7 +71,7 @@ class InstantMemory:
}}
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if not response:
@@ -81,53 +79,53 @@ class InstantMemory:
try:
repaired = repair_json(response)
result = json.loads(repaired)
memory_text = result.get('memory_text', '')
keywords = result.get('keywords', '')
memory_text = result.get("memory_text", "")
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
return {'memory_text': memory_text, 'keywords': keywords_list}
return {"memory_text": memory_text, "keywords": keywords_list}
except Exception as parse_e:
logger.error(f"解析记忆json失败{str(parse_e)} {traceback.format_exc()}")
return None
except Exception as e:
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
return None
async def create_and_store_memory(self,text):
async def create_and_store_memory(self, text):
if_need = await self.if_need_build(text)
if if_need:
logger.info(f"需要记忆:{text}")
memory = await self.build_memory(text)
if memory and memory.get('memory_text'):
memory = await self.build_memory(text)
if memory and memory.get("memory_text"):
memory_id = f"{self.chat_id}_{time.time()}"
memory_item = MemoryItem(
memory_id=memory_id,
chat_id=self.chat_id,
memory_text=memory['memory_text'],
keywords=memory.get('keywords', [])
memory_text=memory["memory_text"],
keywords=memory.get("keywords", []),
)
await self.store_memory(memory_item)
else:
logger.info(f"不需要记忆:{text}")
async def store_memory(self,memory_item:MemoryItem):
async def store_memory(self, memory_item: MemoryItem):
memory = Memory(
memory_id=memory_item.memory_id,
chat_id=memory_item.chat_id,
memory_text=memory_item.memory_text,
keywords=memory_item.keywords,
create_time=memory_item.create_time,
last_view_time=memory_item.last_view_time
last_view_time=memory_item.last_view_time,
)
memory.save()
async def get_memory(self,target:str):
async def get_memory(self, target: str):
from json_repair import repair_json
prompt = f"""
请根据以下发言内容,判断是否需要提取记忆
{target}
@@ -144,7 +142,7 @@ class InstantMemory:
请只输出json格式不要输出其他多余内容
"""
try:
response,_ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
if not response:
@@ -153,15 +151,15 @@ class InstantMemory:
repaired = repair_json(response)
result = json.loads(repaired)
# 解析keywords
keywords = result.get('keywords', '')
keywords = result.get("keywords", "")
if isinstance(keywords, str):
keywords_list = [k.strip() for k in keywords.split('/') if k.strip()]
keywords_list = [k.strip() for k in keywords.split("/") if k.strip()]
elif isinstance(keywords, list):
keywords_list = keywords
else:
keywords_list = []
# 解析time为时间段
time_str = result.get('time', '').strip()
time_str = result.get("time", "").strip()
start_time, end_time = self._parse_time_range(time_str)
logger.info(f"start_time: {start_time}, end_time: {end_time}")
# 检索包含关键词的记忆
@@ -170,16 +168,15 @@ class InstantMemory:
start_ts = start_time.timestamp()
end_ts = end_time.timestamp()
query = Memory.select().where(
(Memory.chat_id == self.chat_id) &
(Memory.create_time >= start_ts) &
(Memory.create_time < end_ts)
(Memory.chat_id == self.chat_id)
& (Memory.create_time >= start_ts) # type: ignore
& (Memory.create_time < end_ts) # type: ignore
)
else:
query = Memory.select().where(Memory.chat_id == self.chat_id)
for mem in query:
#对每条记忆
# 对每条记忆
mem_keywords = mem.keywords or []
parsed = ast.literal_eval(mem_keywords)
if isinstance(parsed, list):
@@ -212,6 +209,7 @@ class InstantMemory:
- 空字符串:返回(None, None)
"""
from datetime import datetime, timedelta
now = datetime.now()
if not time_str:
return 0, now
@@ -251,8 +249,8 @@ class InstantMemory:
if m:
months = int(m.group(1))
# 近似每月30天
start = (now - timedelta(days=months*30)).replace(hour=0, minute=0, second=0, microsecond=0)
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
# 其他无法解析
return 0, now
return 0, now

View File

@@ -30,7 +30,7 @@ class ChatMessageContext:
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
return self.message.message_info.template_info.template_name # type: ignore
return None
def get_last_message(self) -> "MessageRecv":

View File

@@ -107,9 +107,9 @@ class MessageRecv(Message):
self.is_picid = False
self.has_picid = False
self.is_mentioned = None
self.is_command = False
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = None # type: ignore
@@ -181,6 +181,7 @@ class MessageRecv(Message):
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageRecvS4U(MessageRecv):
def __init__(self, message_dict: dict[str, Any]):
@@ -199,7 +200,7 @@ class MessageRecvS4U(MessageRecv):
async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
@@ -254,7 +255,7 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "gift":
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1)
name, count = segment.data.split(":", 1) # type: ignore
self.gift_info = segment.data
self.gift_name = name.strip()
self.gift_count = int(count.strip())
@@ -267,13 +268,15 @@ class MessageRecvS4U(MessageRecv):
elif segment.type == "superchat":
self.is_superchat = True
self.superchat_info = segment.data
price,message_text = segment.data.split(":", 1)
price, message_text = segment.data.split(":", 1) # type: ignore
self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
self.processed_plain_text += (
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
)
return self.processed_plain_text
elif segment.type == "screen":
self.is_screen = True

View File

@@ -80,7 +80,7 @@ class ActionManager:
chat_stream: ChatStream,
log_prefix: str,
shutting_down: bool = False,
action_message: dict = None,
action_message: Optional[dict] = None,
) -> Optional[BaseAction]:
"""
创建动作处理器实例

View File

@@ -252,7 +252,7 @@ def _build_readable_messages_internal(
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1,
show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
"""
内部辅助函数,构建可读消息字符串和原始消息详情列表。
@@ -615,7 +615,7 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
for action in actions:
action_time = action.get("time", current_time)
action_name = action.get("action_name", "未知动作")
if action_name == "no_action" or action_name == "no_reply":
if action_name in ["no_action", "no_reply"]:
continue
action_prompt_display = action.get("action_prompt_display", "无具体内容")
@@ -697,7 +697,7 @@ def build_readable_messages(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
message_id_list: List[Dict[str, Any]] = None,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。

View File

@@ -1211,7 +1211,7 @@ class StatisticOutputTask(AsyncTask):
f.write(html_template)
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: for-append-to-extend, list-comprehension, use-any
# sourcery skip: for-append-to-extend, list-comprehension, use-any, use-named-expression, use-next
"""生成Focus统计独立分页的HTML内容"""
# 为每个时间段准备Focus数据
@@ -1559,6 +1559,7 @@ class StatisticOutputTask(AsyncTask):
"""
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
# sourcery skip: use-named-expression, use-next
"""生成版本对比独立分页的HTML内容"""
# 为每个时间段准备版本对比数据
@@ -2306,13 +2307,13 @@ class AsyncStatisticOutputTask(AsyncTask):
# 复用 StatisticOutputTask 的所有方法
def _collect_all_statistics(self, now: datetime):
return StatisticOutputTask._collect_all_statistics(self, now)
return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
return StatisticOutputTask._statistic_console_output(self, stats, now)
return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore
def _generate_html_report(self, stats: dict[str, Any], now: datetime):
return StatisticOutputTask._generate_html_report(self, stats, now)
return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore
# 其他需要的方法也可以类似复用...
@staticmethod
@@ -2324,10 +2325,10 @@ class AsyncStatisticOutputTask(AsyncTask):
return StatisticOutputTask._collect_online_time_for_period(collect_period, now)
def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_message_count_for_period(self, collect_period)
return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore
def _collect_focus_statistics_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period)
return StatisticOutputTask._collect_focus_statistics_for_period(self, collect_period) # type: ignore
def _process_focus_file_data(
self,
@@ -2336,10 +2337,10 @@ class AsyncStatisticOutputTask(AsyncTask):
collect_period: List[Tuple[str, datetime]],
file_time: datetime,
):
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time)
return StatisticOutputTask._process_focus_file_data(self, cycles_data, stats, collect_period, file_time) # type: ignore
def _calculate_focus_averages(self, stats: Dict[str, Any]):
return StatisticOutputTask._calculate_focus_averages(self, stats)
return StatisticOutputTask._calculate_focus_averages(self, stats) # type: ignore
@staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str:
@@ -2347,31 +2348,31 @@ class AsyncStatisticOutputTask(AsyncTask):
@staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_model_classified_stat(stats)
return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_chat_stat(self, stats)
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
def _format_focus_stat(self, stats: Dict[str, Any]) -> str:
return StatisticOutputTask._format_focus_stat(self, stats)
return StatisticOutputTask._format_focus_stat(self, stats) # type: ignore
def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
return StatisticOutputTask._generate_chart_data(self, stat)
return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore
def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes)
return StatisticOutputTask._collect_interval_data(self, now, hours, interval_minutes) # type: ignore
def _generate_chart_tab(self, chart_data: dict) -> str:
return StatisticOutputTask._generate_chart_tab(self, chart_data)
return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore
def _get_chat_display_name_from_id(self, chat_id: str) -> str:
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id)
return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore
def _generate_focus_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_focus_tab(self, stat)
return StatisticOutputTask._generate_focus_tab(self, stat) # type: ignore
def _generate_versions_tab(self, stat: dict[str, Any]) -> str:
return StatisticOutputTask._generate_versions_tab(self, stat)
return StatisticOutputTask._generate_versions_tab(self, stat) # type: ignore
def _convert_defaultdict_to_dict(self, data):
return StatisticOutputTask._convert_defaultdict_to_dict(self, data)
return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore

View File

@@ -2,14 +2,13 @@ import importlib
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Optional
from typing import Dict, Optional, Any
from rich.traceback import install
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream, GroupInfo
from src.chat.message_receive.message import MessageRecv
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
install(extra_lines=3)
@@ -54,7 +53,7 @@ class WillingInfo:
interested_rate (float): 兴趣度
"""
message: MessageRecv
message: Dict[str, Any] # 原始消息数据
chat: ChatStream
person_info_manager: PersonInfoManager
chat_id: str