typing
This commit is contained in:
@@ -112,7 +112,7 @@ class HeartFCMessageReceiver:
|
||||
|
||||
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id) # type: ignore
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||
|
||||
# 3. 日志记录
|
||||
|
||||
@@ -494,10 +494,10 @@ class MessageSending(MessageProcessBase):
|
||||
# ) -> "MessageSending":
|
||||
# """从思考状态消息创建发送状态消息"""
|
||||
# return cls(
|
||||
# message_id=thinking.message_info.message_id, # type: ignore
|
||||
# message_id=thinking.message_info.message_id,
|
||||
# chat_stream=thinking.chat_stream,
|
||||
# message_segment=message_segment,
|
||||
# bot_user_info=thinking.message_info.user_info, # type: ignore
|
||||
# bot_user_info=thinking.message_info.user_info,
|
||||
# reply=thinking.reply,
|
||||
# is_head=is_head,
|
||||
# is_emoji=is_emoji,
|
||||
|
||||
@@ -2348,7 +2348,7 @@ class AsyncStatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_model_classified_stat(stats) # type: ignore
|
||||
return StatisticOutputTask._format_model_classified_stat(stats)
|
||||
|
||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
||||
return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore
|
||||
|
||||
@@ -285,7 +285,7 @@ def random_remove_punctuation(text: str) -> str:
|
||||
continue
|
||||
elif char == ",":
|
||||
rand = random.random()
|
||||
if rand < 0.25: # 5%概率删除逗号
|
||||
if rand < 0.05: # 5%概率删除逗号
|
||||
continue
|
||||
elif rand < 0.25: # 20%概率把逗号变成空格
|
||||
result += " "
|
||||
@@ -628,7 +628,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
elif chat_stream.user_info: # It's a private chat
|
||||
is_group_chat = False
|
||||
user_info = chat_stream.user_info
|
||||
platform: str = chat_stream.platform # type: ignore
|
||||
platform: str = chat_stream.platform
|
||||
user_id: str = user_info.user_id # type: ignore
|
||||
|
||||
# Initialize target_info with basic info
|
||||
|
||||
@@ -94,7 +94,7 @@ class BaseWillingManager(ABC):
|
||||
|
||||
def setup(self, message: dict, chat: ChatStream):
|
||||
person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) # type: ignore
|
||||
self.ongoing_messages[message.get("message_id", "")] = WillingInfo( # type: ignore
|
||||
self.ongoing_messages[message.get("message_id", "")] = WillingInfo(
|
||||
message=message,
|
||||
chat=chat,
|
||||
person_info_manager=get_person_info_manager(),
|
||||
|
||||
@@ -65,7 +65,7 @@ class ChatStreams(BaseModel):
|
||||
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
|
||||
user_cardname = TextField(null=True)
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||
# 如果不使用带有数据库实例的 BaseModel,或者想覆盖它,
|
||||
# 请取消注释并在下面设置数据库实例:
|
||||
@@ -89,7 +89,7 @@ class LLMUsage(BaseModel):
|
||||
status = TextField()
|
||||
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
|
||||
# database = db
|
||||
table_name = "llm_usage"
|
||||
@@ -112,7 +112,7 @@ class Emoji(BaseModel):
|
||||
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
|
||||
last_used_time = FloatField(null=True) # 上次使用时间
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "emoji"
|
||||
|
||||
@@ -163,7 +163,7 @@ class Messages(BaseModel):
|
||||
is_picid = BooleanField(default=False)
|
||||
is_command = BooleanField(default=False)
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "messages"
|
||||
|
||||
@@ -187,7 +187,7 @@ class ActionRecords(BaseModel):
|
||||
chat_info_stream_id = TextField()
|
||||
chat_info_platform = TextField()
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "action_records"
|
||||
|
||||
@@ -207,7 +207,7 @@ class Images(BaseModel):
|
||||
type = TextField() # 图像类型,例如 "emoji"
|
||||
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
table_name = "images"
|
||||
|
||||
|
||||
@@ -221,7 +221,7 @@ class ImageDescriptions(BaseModel):
|
||||
description = TextField() # 图像的描述
|
||||
timestamp = FloatField() # 时间戳
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "image_descriptions"
|
||||
|
||||
@@ -237,7 +237,7 @@ class OnlineTime(BaseModel):
|
||||
start_timestamp = DateTimeField(default=datetime.datetime.now)
|
||||
end_timestamp = DateTimeField(index=True)
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "online_time"
|
||||
|
||||
@@ -264,7 +264,7 @@ class PersonInfo(BaseModel):
|
||||
last_know = FloatField(null=True) # 最后一次印象总结时间
|
||||
attitude = IntegerField(null=True, default=50) # 态度,0-100,从非常厌恶到十分喜欢
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "person_info"
|
||||
|
||||
@@ -277,7 +277,7 @@ class Memory(BaseModel):
|
||||
create_time = FloatField(null=True)
|
||||
last_view_time = FloatField(null=True)
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
table_name = "memory"
|
||||
|
||||
|
||||
@@ -290,7 +290,7 @@ class Knowledges(BaseModel):
|
||||
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
|
||||
# 可以添加其他元数据字段,如 source, create_time 等
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "knowledges"
|
||||
|
||||
@@ -307,7 +307,7 @@ class Expression(BaseModel):
|
||||
chat_id = TextField(index=True)
|
||||
type = TextField()
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
table_name = "expression"
|
||||
|
||||
|
||||
@@ -331,7 +331,7 @@ class ThinkingLog(BaseModel):
|
||||
# And: import datetime
|
||||
created_at = DateTimeField(default=datetime.datetime.now)
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
table_name = "thinking_logs"
|
||||
|
||||
|
||||
@@ -346,7 +346,7 @@ class GraphNodes(BaseModel):
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
table_name = "graph_nodes"
|
||||
|
||||
|
||||
@@ -362,7 +362,7 @@ class GraphEdges(BaseModel):
|
||||
created_time = FloatField() # 创建时间戳
|
||||
last_modified = FloatField() # 最后修改时间戳
|
||||
|
||||
class Meta: # type: ignore
|
||||
class Meta:
|
||||
table_name = "graph_edges"
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Union, Dict, Any
|
||||
from typing import Tuple, Union, Dict, Any, Callable
|
||||
import aiohttp
|
||||
from aiohttp.client import ClientResponse
|
||||
from src.common.logger import get_logger
|
||||
@@ -300,7 +300,7 @@ class LLMRequest:
|
||||
file_format: str = None,
|
||||
payload: dict = None,
|
||||
retry_policy: dict = None,
|
||||
response_handler: callable = None,
|
||||
response_handler: Callable = None,
|
||||
user_id: str = "system",
|
||||
request_type: str = None,
|
||||
):
|
||||
@@ -336,19 +336,17 @@ class LLMRequest:
|
||||
headers["Accept"] = "text/event-stream"
|
||||
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||
post_kwargs = {"headers": headers}
|
||||
#form-data数据上传方式不同
|
||||
# form-data数据上传方式不同
|
||||
if file_bytes:
|
||||
post_kwargs["data"] = request_content["payload"]
|
||||
else:
|
||||
post_kwargs["json"] = request_content["payload"]
|
||||
|
||||
async with session.post(
|
||||
request_content["api_url"], **post_kwargs
|
||||
) as response:
|
||||
async with session.post(request_content["api_url"], **post_kwargs) as response:
|
||||
handled_result = await self._handle_response(
|
||||
response, request_content, retry, response_handler, user_id, request_type, endpoint
|
||||
)
|
||||
return handled_result
|
||||
return handled_result
|
||||
|
||||
except Exception as e:
|
||||
handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
|
||||
@@ -366,11 +364,11 @@ class LLMRequest:
|
||||
response: ClientResponse,
|
||||
request_content: Dict[str, Any],
|
||||
retry_count: int,
|
||||
response_handler: callable,
|
||||
response_handler: Callable,
|
||||
user_id,
|
||||
request_type,
|
||||
endpoint,
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
):
|
||||
policy = request_content["policy"]
|
||||
stream_mode = request_content["stream_mode"]
|
||||
if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
|
||||
@@ -477,9 +475,7 @@ class LLMRequest:
|
||||
}
|
||||
return result
|
||||
|
||||
async def _handle_error_response(
|
||||
self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]
|
||||
) -> Union[Dict[str, any]]:
|
||||
async def _handle_error_response(self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]):
|
||||
if response.status in policy["retry_codes"]:
|
||||
wait_time = policy["base_wait"] * (2**retry_count)
|
||||
logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
|
||||
@@ -629,7 +625,9 @@ class LLMRequest:
|
||||
)
|
||||
# 安全地检查和记录请求详情
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
|
||||
logger.critical(
|
||||
f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
|
||||
)
|
||||
@@ -643,7 +641,9 @@ class LLMRequest:
|
||||
logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
|
||||
# 安全地检查和记录请求详情
|
||||
handled_payload = await _safely_record(request_content, payload)
|
||||
logger.critical(f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}")
|
||||
logger.critical(
|
||||
f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}"
|
||||
)
|
||||
raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
|
||||
|
||||
async def _transform_parameters(self, params: dict) -> dict:
|
||||
@@ -682,15 +682,14 @@ class LLMRequest:
|
||||
logger.warning(f"暂不支持的文件类型: {file_format}")
|
||||
|
||||
data.add_field(
|
||||
"file",io.BytesIO(file_bytes),
|
||||
"file",
|
||||
io.BytesIO(file_bytes),
|
||||
filename=f"file.{file_format}",
|
||||
content_type=f'{content_type}' # 根据实际文件类型设置
|
||||
)
|
||||
data.add_field(
|
||||
"model", self.model_name
|
||||
content_type=f"{content_type}", # 根据实际文件类型设置
|
||||
)
|
||||
data.add_field("model", self.model_name)
|
||||
return data
|
||||
|
||||
|
||||
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
|
||||
"""构建请求体"""
|
||||
# 复制一份参数,避免直接修改 self.params
|
||||
@@ -819,9 +818,11 @@ class LLMRequest:
|
||||
|
||||
async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple:
|
||||
"""根据输入的语音文件生成模型的异步响应"""
|
||||
response = await self._execute_request(endpoint="/audio/transcriptions",file_bytes=voice_bytes, file_format='wav')
|
||||
response = await self._execute_request(
|
||||
endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
|
||||
"""异步方式根据输入的提示生成模型的响应"""
|
||||
# 构建请求体,不硬编码max_tokens
|
||||
|
||||
@@ -134,7 +134,7 @@ class ChatMood:
|
||||
|
||||
self.mood_state = response
|
||||
|
||||
self.last_change_time = message_time # type: ignore
|
||||
self.last_change_time = message_time
|
||||
|
||||
async def regress_mood(self):
|
||||
message_time = time.time()
|
||||
|
||||
@@ -71,7 +71,7 @@ class RenamePersonTool(BaseTool):
|
||||
user_nickname=user_nickname, # type: ignore
|
||||
user_cardname=user_cardname, # type: ignore
|
||||
user_avatar=user_avatar, # type: ignore
|
||||
request=request_context, # type: ignore
|
||||
request=request_context,
|
||||
)
|
||||
|
||||
# 3. 处理结果
|
||||
|
||||
Reference in New Issue
Block a user