调整对应的调用 (Adjust the corresponding call sites)
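This commit migrates LLMRequest call sites from the old global_config.model.* handles (where temperature and max_tokens were fixed in the constructor) to model_config.model_task_config.* model sets, with sampling parameters passed on each call; generate_response_async also now returns a three-element metadata tuple. A minimal sketch of the pattern, inferred from the hunks below and not an authoritative API reference (the real signatures live in src/llm_models/utils_model.py):

    # Sketch only: illustrates the call-site migration applied throughout this commit.
    from src.config.config import model_config
    from src.llm_models.utils_model import LLMRequest

    # Before: LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
    # After: bind the request to a task's model set; pass sampling options per call.
    vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
    judge = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")

    async def demo(prompt: str, image_base64: str, image_format: str) -> str:
        # Image calls take temperature/max_tokens per request now.
        description, _ = await vlm.generate_response_for_image(
            prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
        )
        # Text calls now return (content, (reasoning_content, model_name, tool_calls)).
        verdict, (reasoning, model_name, tool_calls) = await judge.generate_response_async(
            prompt, temperature=0.7, max_tokens=600
        )
        return description
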
@@ -8,15 +8,15 @@ import traceback
import io
import re
import binascii
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
from src.llm_models.utils_model import LLMRequest

@@ -379,9 +379,9 @@ class EmojiManager:
self._scan_task = None
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
self.llm_emotion_judge = LLMRequest(
model=global_config.model.utils, max_tokens=600, request_type="emoji"
model_set=model_config.model_task_config.utils, request_type="emoji"
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
self.emoji_num = 0

@@ -492,6 +492,7 @@ class EmojiManager:
return None
def _levenshtein_distance(self, s1: str, s2: str) -> int:
# sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison
"""计算两个字符串的编辑距离
Args:

@@ -629,7 +630,7 @@ class EmojiManager:
if success:
# 注册成功则跳出循环
break
else:
# 注册失败则删除对应文件
file_path = os.path.join(EMOJI_DIR, filename)
os.remove(file_path)

@@ -694,6 +695,7 @@ class EmojiManager:
return []
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
# sourcery skip: use-next
"""从内存中的 emoji_objects 列表获取表情包
参数:

@@ -779,6 +781,7 @@ class EmojiManager:
return False
async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
# sourcery skip: use-getitem-for-re-match-groups
"""替换一个表情包
Args:

@@ -820,7 +823,7 @@ class EmojiManager:
)
# 调用大模型进行决策
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8)
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8, max_tokens=600)
logger.info(f"[决策] 结果: {decision}")
# 解析决策结果

@@ -828,9 +831,7 @@ class EmojiManager:
logger.info("[决策] 不删除任何表情包")
return False
# 尝试从决策中提取表情包编号
match = re.search(r"删除编号(\d+)", decision)
if match:
if match := re.search(r"删除编号(\d+)", decision):
emoji_index = int(match.group(1)) - 1 # 转换为0-based索引
# 检查索引是否有效

@@ -889,6 +890,7 @@ class EmojiManager:
existing_description = None
try:
from src.common.database.database_model import Images
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
if existing_image and existing_image.description:
existing_description = existing_image.description

@@ -902,15 +904,21 @@ class EmojiManager:
logger.info("[优化] 复用已有的详细描述,跳过VLM调用")
else:
logger.info("[VLM分析] 生成新的详细描述")
if image_format == "gif" or image_format == "GIF":
if image_format in ["gif", "GIF"]:
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
if not image_base64:
raise RuntimeError("GIF表情包转换失败")
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000
)
else:
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
prompt = (
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
)
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
)
# 审核表情包
if global_config.emoji.content_filtration:

@@ -922,7 +930,9 @@ class EmojiManager:
4. 不要出现5个以上文字
请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
'''
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
content, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
)
if content == "否":
return "", []

@@ -933,7 +943,9 @@ class EmojiManager:
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
"""
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7)
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
emotion_prompt, temperature=0.7, max_tokens=600
)
# 处理情感列表
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]

@@ -7,12 +7,12 @@ from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import model_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.database.database_model import Expression
MAX_EXPRESSION_COUNT = 300

@@ -80,11 +80,8 @@ def init_prompt() -> None:
class ExpressionLearner:
def __init__(self) -> None:
# TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest(
model=global_config.model.replyer_1,
temperature=0.3,
request_type="expressor.learner",
model_set=model_config.model_task_config.replyer_1, request_type="expressor.learner"
)
self.llm_model = None
self._ensure_expression_directories()

@@ -287,15 +284,18 @@ class ExpressionLearner:
获取指定chat_id的表达方式创建信息,按创建日期排序
"""
try:
expressions = (Expression.select()
expressions = (
Expression.select()
.where(Expression.chat_id == chat_id)
.order_by(Expression.create_date.desc())
.limit(limit))
.limit(limit)
)
result = []
for expr in expressions:
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
result.append({
result.append(
{
"situation": expr.situation,
"style": expr.style,
"type": expr.type,

@@ -304,7 +304,8 @@ class ExpressionLearner:
"create_date_formatted": format_create_date(create_date),
"last_active_time": expr.last_active_time,
"last_active_formatted": format_create_date(expr.last_active_time),
})
}
)
return result
except Exception as e:

@@ -527,7 +528,7 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的prompt: {prompt}")
try:
response, _ = await self.express_learn_model.generate_response_async(prompt)
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习{type_str}失败: {e}")
return None

@@ -1,16 +1,17 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Tuple, Optional, Any
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .expression_learner import get_expression_learner
from src.common.database.database_model import Expression
logger = get_logger("expression_selector")

@@ -75,10 +76,8 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis
class ExpressionSelector:
def __init__(self):
self.expression_learner = get_expression_learner()
# TODO: API-Adapter修改标记
self.llm_model = LLMRequest(
model=global_config.model.utils_small,
request_type="expression.selector",
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
)
@staticmethod

@@ -92,7 +91,6 @@ class ExpressionSelector:
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
import hashlib
if is_group:
components = [platform, str(id_str)]
else:

@@ -108,8 +106,7 @@ class ExpressionSelector:
for group in groups:
group_chat_ids = []
for stream_config_str in group:
chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str)
if chat_id_candidate:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids

@@ -118,6 +115,7 @@ class ExpressionSelector:
def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)

@@ -138,7 +136,8 @@ class ExpressionSelector:
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in style_query
}
for expr in style_query
]
grammar_exprs = [

@@ -150,7 +149,8 @@ class ExpressionSelector:
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in grammar_query
}
for expr in grammar_query
]
style_num = int(total_num * style_percentage)

@@ -174,22 +174,22 @@ class ExpressionSelector:
return
updates_by_key = {}
for expr in expressions_to_update:
source_id = expr.get("source_id")
expr_type = expr.get("type", "style")
situation = expr.get("situation")
style = expr.get("style")
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
situation: str = expr.get("situation") # type: ignore
style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, expr_type, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
for (chat_id, expr_type, situation, style), _expr in updates_by_key.items():
for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == expr_type) &
(Expression.situation == situation) &
(Expression.style == style)
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
if query.exists():
expr_obj = query.get()

@@ -264,7 +264,7 @@ class ExpressionSelector:
# 4. 调用LLM
try:
content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix} LLM返回结果: {content}")

@@ -5,25 +5,27 @@ import random
import time
import re
import json
from itertools import combinations
import jieba
import networkx as nx
import numpy as np
from itertools import combinations
from typing import List, Tuple, Coroutine, Any, Dict, Set
from collections import Counter
from ...llm_models.utils_model import LLMRequest
from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from src.common.logger import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from ..utils.chat_message_builder import (
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
build_readable_messages,
get_raw_msg_by_timestamp_with_chat,
) # 导入 build_readable_messages
from ..utils.utils import translate_timestamp_to_human_readable
from rich.traceback import install
from src.chat.utils.utils import translate_timestamp_to_human_readable
from ...config.config import global_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
install(extra_lines=3)

@@ -198,8 +200,7 @@ class Hippocampus:
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
# TODO: API-Adapter修改标记
self.model_summary = LLMRequest(global_config.model.memory, request_type="memory.builder")
self.model_summary = LLMRequest(model_set=model_config.model_task_config.memory, request_type="memory.builder")
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""

@@ -339,9 +340,7 @@ class Hippocampus:
else:
topic_num = 5 # 51+字符: 5个关键词 (其余长文本)
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
self.find_topic_llm(text, topic_num)
)
topics_response, _ = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)

@@ -359,7 +358,6 @@ class Hippocampus:
return keywords
async def get_memory_from_text(
self,
text: str,

@@ -1245,7 +1243,7 @@ class ParahippocampalGyrus:
# 2. 使用LLM提取关键主题
topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
topics_response, (reasoning_content, model_name) = await self.hippocampus.model_summary.generate_response_async(
topics_response, _ = await self.hippocampus.model_summary.generate_response_async(
self.hippocampus.find_topic_llm(input_text, topic_num)
)

@@ -1269,7 +1267,7 @@ class ParahippocampalGyrus:
logger.debug(f"过滤后话题: {filtered_topics}")
# 4. 创建所有话题的摘要生成任务
tasks = []
tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List[Dict[str, Any]] | None]]]]] = []
for topic in filtered_topics:
# 调用修改后的 topic_what,不再需要 time_info
topic_what_prompt = self.hippocampus.topic_what(input_text, topic)

@@ -1281,7 +1279,7 @@ class ParahippocampalGyrus:
continue
# 等待所有任务完成
compressed_memory = set()
compressed_memory: Set[Tuple[str, str]] = set()
similar_topics_dict = {}
for topic, task in tasks:

@@ -3,13 +3,16 @@ import time
|
||||
import re
|
||||
import json
|
||||
import ast
|
||||
from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
import traceback
|
||||
|
||||
from src.config.config import global_config
|
||||
from json_repair import repair_json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Memory # Peewee Models导入
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -35,8 +38,7 @@ class InstantMemory:
|
||||
self.chat_id = chat_id
|
||||
self.last_view_time = time.time()
|
||||
self.summary_model = LLMRequest(
|
||||
model=global_config.model.memory,
|
||||
temperature=0.5,
|
||||
model_set=model_config.model_task_config.memory,
|
||||
request_type="memory.summary",
|
||||
)
|
||||
|
||||
@@ -48,14 +50,11 @@ class InstantMemory:
|
||||
"""
|
||||
|
||||
try:
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
||||
print(prompt)
|
||||
print(response)
|
||||
|
||||
if "1" in response:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return "1" in response
|
||||
except Exception as e:
|
||||
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return False
|
||||
@@ -71,9 +70,9 @@ class InstantMemory:
|
||||
}}
|
||||
"""
|
||||
try:
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
print(prompt)
|
||||
print(response)
|
||||
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
||||
# print(prompt)
|
||||
# print(response)
|
||||
if not response:
|
||||
return None
|
||||
try:
|
||||
@@ -142,7 +141,7 @@ class InstantMemory:
|
||||
请只输出json格式,不要输出其他多余内容
|
||||
"""
|
||||
try:
|
||||
response, _ = await self.summary_model.generate_response_async(prompt)
|
||||
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
||||
print(prompt)
|
||||
print(response)
|
||||
if not response:
|
||||
@@ -177,7 +176,7 @@ class InstantMemory:
|
||||
|
||||
for mem in query:
|
||||
# 对每条记忆
|
||||
mem_keywords = mem.keywords or []
|
||||
mem_keywords = mem.keywords or ""
|
||||
parsed = ast.literal_eval(mem_keywords)
|
||||
if isinstance(parsed, list):
|
||||
mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
|
||||
@@ -201,6 +200,7 @@ class InstantMemory:
|
||||
return None
|
||||
|
||||
def _parse_time_range(self, time_str):
|
||||
# sourcery skip: extract-duplicate-method, use-contextlib-suppress
|
||||
"""
|
||||
支持解析如下格式:
|
||||
- 具体日期时间:YYYY-MM-DD HH:MM:SS
|
||||
@@ -208,8 +208,6 @@ class InstantMemory:
|
||||
- 相对时间:今天,昨天,前天,N天前,N个月前
|
||||
- 空字符串:返回(None, None)
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
now = datetime.now()
|
||||
if not time_str:
|
||||
return 0, now
|
||||
@@ -239,14 +237,12 @@ class InstantMemory:
|
||||
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
m = re.match(r"(\d+)天前", time_str)
|
||||
if m:
|
||||
if m := re.match(r"(\d+)天前", time_str):
|
||||
days = int(m.group(1))
|
||||
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end = start + timedelta(days=1)
|
||||
return start, end
|
||||
m = re.match(r"(\d+)个月前", time_str)
|
||||
if m:
|
||||
if m := re.match(r"(\d+)个月前", time_str):
|
||||
months = int(m.group(1))
|
||||
# 近似每月30天
|
||||
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from datetime import datetime
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from typing import List, Dict
|
||||
import difflib
|
||||
import json
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
|
||||
|
||||
logger = get_logger("memory_activator")
|
||||
@@ -61,11 +63,8 @@ def init_prompt():
|
||||
|
||||
class MemoryActivator:
|
||||
def __init__(self):
|
||||
# TODO: API-Adapter修改标记
|
||||
|
||||
self.key_words_model = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
temperature=0.5,
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="memory.activator",
|
||||
)
|
||||
|
||||
@@ -92,7 +91,9 @@ class MemoryActivator:
|
||||
|
||||
# logger.debug(f"prompt: {prompt}")
|
||||
|
||||
response, (reasoning_content, model_name) = await self.key_words_model.generate_response_async(prompt)
|
||||
response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
|
||||
prompt, temperature=0.5
|
||||
)
|
||||
|
||||
keywords = list(get_keywords_from_json(response))
|
||||
|
||||
|
||||
@@ -203,7 +203,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count = None
|
||||
self.gift_count: Optional[str] = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import Dict, Optional, Type
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
@@ -36,10 +36,7 @@ class ActionModifier:
|
||||
self.action_manager = action_manager
|
||||
|
||||
# 用于LLM判定的小模型
|
||||
self.llm_judge = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
request_type="action.judge",
|
||||
)
|
||||
self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge")
|
||||
|
||||
# 缓存相关属性
|
||||
self._llm_judge_cache = {} # 缓存LLM判定结果
|
||||
|
||||
@@ -7,7 +7,7 @@ from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
@@ -73,10 +73,7 @@ class ActionPlanner:
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMRequest(
|
||||
model=global_config.model.planner,
|
||||
request_type="planner", # 用于动作规划
|
||||
)
|
||||
self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner") # 用于动作规划
|
||||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
@@ -140,7 +137,7 @@ class ActionPlanner:
|
||||
# --- 调用 LLM (普通文本生成) ---
|
||||
llm_content = None
|
||||
try:
|
||||
llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
|
||||
|
||||
@@ -8,7 +8,8 @@ from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
@@ -106,31 +107,36 @@ class DefaultReplyer:
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "focus.replyer",
|
||||
):
|
||||
self.request_type = request_type
|
||||
|
||||
if model_configs:
|
||||
self.express_model_configs = model_configs
|
||||
if model_set_with_weight:
|
||||
# self.express_model_configs = model_configs
|
||||
self.model_set: List[Tuple[TaskConfig, float]] = model_set_with_weight
|
||||
else:
|
||||
# 当未提供配置时,使用默认配置并赋予默认权重
|
||||
|
||||
model_config_1 = global_config.model.replyer_1.copy()
|
||||
model_config_2 = global_config.model.replyer_2.copy()
|
||||
# model_config_1 = global_config.model.replyer_1.copy()
|
||||
# model_config_2 = global_config.model.replyer_2.copy()
|
||||
prob_first = global_config.chat.replyer_random_probability
|
||||
|
||||
model_config_1["weight"] = prob_first
|
||||
model_config_2["weight"] = 1.0 - prob_first
|
||||
# model_config_1["weight"] = prob_first
|
||||
# model_config_2["weight"] = 1.0 - prob_first
|
||||
|
||||
self.express_model_configs = [model_config_1, model_config_2]
|
||||
# self.express_model_configs = [model_config_1, model_config_2]
|
||||
self.model_set = [
|
||||
(model_config.model_task_config.replyer_1, prob_first),
|
||||
(model_config.model_task_config.replyer_2, 1.0 - prob_first),
|
||||
]
|
||||
|
||||
if not self.express_model_configs:
|
||||
logger.warning("未找到有效的模型配置,回复生成可能会失败。")
|
||||
# 提供一个最终的回退,以防止在空列表上调用 random.choice
|
||||
fallback_config = global_config.model.replyer_1.copy()
|
||||
fallback_config.setdefault("weight", 1.0)
|
||||
self.express_model_configs = [fallback_config]
|
||||
# if not self.express_model_configs:
|
||||
# logger.warning("未找到有效的模型配置,回复生成可能会失败。")
|
||||
# # 提供一个最终的回退,以防止在空列表上调用 random.choice
|
||||
# fallback_config = global_config.model.replyer_1.copy()
|
||||
# fallback_config.setdefault("weight", 1.0)
|
||||
# self.express_model_configs = [fallback_config]
|
||||
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
@@ -140,13 +146,14 @@ class DefaultReplyer:
|
||||
self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id)
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
def _select_weighted_model_config(self) -> Dict[str, Any]:
|
||||
def _select_weighted_models_config(self) -> Tuple[TaskConfig, float]:
|
||||
"""使用加权随机选择来挑选一个模型配置"""
|
||||
configs = self.express_model_configs
|
||||
configs = self.model_set
|
||||
# 提取权重,如果模型配置中没有'weight'键,则默认为1.0
|
||||
weights = [config.get("weight", 1.0) for config in configs]
|
||||
weights = [weight for _, weight in configs]
|
||||
|
||||
return random.choices(population=configs, weights=weights, k=1)[0]
|
||||
|
||||
@@ -188,12 +195,11 @@ class DefaultReplyer:
|
||||
|
||||
# 4. 调用 LLM 生成回复
|
||||
content = None
|
||||
# TODO: 复活这里
|
||||
# reasoning_content = None
|
||||
# model_name = "unknown_model"
|
||||
reasoning_content = None
|
||||
model_name = "unknown_model"
|
||||
|
||||
try:
|
||||
content = await self.llm_generate_content(prompt)
|
||||
content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
|
||||
except Exception as llm_e:
|
||||
@@ -236,15 +242,14 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
content = None
|
||||
# TODO: 复活这里
|
||||
# reasoning_content = None
|
||||
# model_name = "unknown_model"
|
||||
reasoning_content = None
|
||||
model_name = "unknown_model"
|
||||
if not prompt:
|
||||
logger.error("Prompt 构建失败,无法生成回复。")
|
||||
return False, None, None
|
||||
|
||||
try:
|
||||
content = await self.llm_generate_content(prompt)
|
||||
content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
|
||||
logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
|
||||
|
||||
except Exception as llm_e:
|
||||
@@ -843,7 +848,7 @@ class DefaultReplyer:
|
||||
raw_reply: str,
|
||||
reason: str,
|
||||
reply_to: str,
|
||||
) -> str:
|
||||
) -> str: # sourcery skip: remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
@@ -977,30 +982,23 @@ class DefaultReplyer:
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
async def llm_generate_content(self, prompt: str) -> str:
|
||||
async def llm_generate_content(self, prompt: str):
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
# 加权随机选择一个模型配置
|
||||
selected_model_config = self._select_weighted_model_config()
|
||||
model_display_name = selected_model_config.get('model_name') or selected_model_config.get('name', 'N/A')
|
||||
logger.info(
|
||||
f"使用模型生成回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
|
||||
)
|
||||
selected_model_config, weight = self._select_weighted_models_config()
|
||||
logger.info(f"使用模型集生成回复: {selected_model_config} (选中概率: {weight})")
|
||||
|
||||
express_model = LLMRequest(
|
||||
model=selected_model_config,
|
||||
request_type=self.request_type,
|
||||
)
|
||||
express_model = LLMRequest(model_set=selected_model_config, request_type=self.request_type)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
|
||||
# TODO: 这里的_应该做出替换
|
||||
content, _ = await express_model.generate_response_async(prompt)
|
||||
content, (reasoning_content, model_name, tool_calls) = await express_model.generate_response_async(prompt)
|
||||
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
return content
|
||||
return content, reasoning_content, model_name, tool_calls
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
|
||||
@@ -15,7 +16,7 @@ class ReplyerManager:
|
||||
self,
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
"""
|
||||
@@ -49,7 +50,7 @@ class ReplyerManager:
|
||||
# model_configs 只在此时(初始化时)生效
|
||||
replyer = DefaultReplyer(
|
||||
chat_stream=target_stream,
|
||||
model_configs=model_configs, # 可以是None,此时使用默认模型
|
||||
model_set_with_weight=model_set_with_weight, # 可以是None,此时使用默认模型
|
||||
request_type=request_type,
|
||||
)
|
||||
self._repliers[stream_id] = replyer
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import Optional, Tuple, Dict, List, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -109,13 +109,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
return is_mentioned, reply_probability
|
||||
|
||||
|
||||
async def get_embedding(text, request_type="embedding"):
|
||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||
"""获取文本的embedding向量"""
|
||||
# TODO: API-Adapter修改标记
|
||||
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
|
||||
# return llm.get_embedding_sync(text)
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||
try:
|
||||
embedding = await llm.get_embedding(text)
|
||||
embedding, _ = await llm.get_embedding(text)
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
embedding = None
|
||||
|
||||
@@ -14,7 +14,7 @@ from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import Images, ImageDescriptions
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -37,7 +37,7 @@ class ImageManager:
|
||||
self._ensure_image_dir()
|
||||
|
||||
self._initialized = True
|
||||
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
||||
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
@@ -107,6 +107,7 @@ class ImageManager:
|
||||
# 优先使用EmojiManager查询已注册表情包的描述
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
|
||||
if cached_emoji_description:
|
||||
@@ -116,8 +117,7 @@ class ImageManager:
|
||||
logger.debug(f"查询EmojiManager时出错: {e}")
|
||||
|
||||
# 查询ImageDescriptions表的缓存描述
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
@@ -130,10 +130,16 @@ class ImageManager:
|
||||
logger.warning("GIF转换失败,无法获取描述")
|
||||
return "[表情包(GIF处理失败)]"
|
||||
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(
|
||||
vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
|
||||
)
|
||||
else:
|
||||
vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format)
|
||||
vlm_prompt = (
|
||||
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
)
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(
|
||||
vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if detailed_description is None:
|
||||
logger.warning("VLM未能生成表情包详细描述")
|
||||
@@ -152,13 +158,16 @@ class ImageManager:
|
||||
"""
|
||||
|
||||
# 使用较低温度确保输出稳定
|
||||
emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji")
|
||||
emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt)
|
||||
emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
|
||||
emotion_result, _ = await emotion_llm.generate_response_async(
|
||||
emotion_prompt, temperature=0.3, max_tokens=50
|
||||
)
|
||||
|
||||
if emotion_result is None:
|
||||
logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
|
||||
# 降级处理:从详细描述中提取关键词
|
||||
import jieba
|
||||
|
||||
words = list(jieba.cut(detailed_description))
|
||||
emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
|
||||
|
||||
@@ -172,9 +181,7 @@ class ImageManager:
|
||||
|
||||
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||
|
||||
# 再次检查缓存,防止并发写入时重复生成
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
if cached_description := self._get_description_from_db(image_hash, "emoji"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
@@ -242,9 +249,7 @@ class ImageManager:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
|
||||
# 查询ImageDescriptions表的缓存描述
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
@@ -252,7 +257,9 @@ class ImageManager:
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
@@ -445,10 +452,7 @@ class ImageManager:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 检查图片是否已存在
|
||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
||||
|
||||
if existing_image:
|
||||
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
@@ -524,9 +528,7 @@ class ImageManager:
|
||||
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = Images.get_or_none(
|
||||
(Images.emoji_hash == image_hash) &
|
||||
(Images.description.is_null(False)) &
|
||||
(Images.description != "")
|
||||
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
|
||||
)
|
||||
if existing_with_description and existing_with_description.id != image.id:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
@@ -538,8 +540,7 @@ class ImageManager:
|
||||
return
|
||||
|
||||
# 检查ImageDescriptions表的缓存描述
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
@@ -554,15 +555,15 @@ class ImageManager:
|
||||
|
||||
# 获取VLM描述
|
||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
description = "无法生成描述"
|
||||
|
||||
# 再次检查缓存,防止并发写入时重复生成
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||
description = cached_description
|
||||
|
||||
@@ -606,7 +607,7 @@ def image_path_to_base64(image_path: str) -> str:
|
||||
raise FileNotFoundError(f"图片文件不存在: {image_path}")
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
image_data = f.read()
|
||||
if not image_data:
|
||||
raise IOError(f"读取图片文件失败: {image_path}")
|
||||
if image_data := f.read():
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
else:
|
||||
raise IOError(f"读取图片文件失败: {image_path}")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import base64
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -20,7 +20,7 @@ async def get_voice_text(voice_base64: str) -> str:
|
||||
if isinstance(voice_base64, str):
|
||||
voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
voice_bytes = base64.b64decode(voice_base64)
|
||||
_llm = LLMRequest(model=global_config.model.voice, request_type="voice")
|
||||
_llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="voice")
|
||||
text = await _llm.generate_response_for_voice(voice_bytes)
|
||||
if text is None:
|
||||
logger.warning("未能生成语音文本")
|
||||
|
||||
@@ -19,13 +19,13 @@ Mxp 模式:梦溪畔独家赞助
|
||||
下下策是询问一个菜鸟(@梦溪畔)
|
||||
"""
|
||||
|
||||
from .willing_manager import BaseWillingManager
|
||||
from typing import Dict
|
||||
import asyncio
|
||||
import time
|
||||
import math
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from .willing_manager import BaseWillingManager
|
||||
|
||||
|
||||
class MxpWillingManager(BaseWillingManager):
|
||||
|
||||
@@ -60,268 +60,6 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
MMC_VERSION = "0.10.0-snapshot.2"
|
||||
|
||||
|
||||
# def _get_config_version(toml: Dict) -> Version:
|
||||
# """提取配置文件的 SpecifierSet 版本数据
|
||||
# Args:
|
||||
# toml[dict]: 输入的配置文件字典
|
||||
# Returns:
|
||||
# Version
|
||||
# """
|
||||
|
||||
# if "inner" in toml and "version" in toml["inner"]:
|
||||
# config_version: str = toml["inner"]["version"]
|
||||
# else:
|
||||
# raise InvalidVersion("配置文件缺少版本信息,请检查配置文件。")
|
||||
|
||||
# try:
|
||||
# return version.parse(config_version)
|
||||
# except InvalidVersion as e:
|
||||
# logger.error(
|
||||
# "配置文件中 inner段 的 version 键是错误的版本描述\n"
|
||||
# f"请检查配置文件,当前 version 键: {config_version}\n"
|
||||
# f"错误信息: {e}"
|
||||
# )
|
||||
# raise e
|
||||
|
||||
|
||||
# def _request_conf(parent: Dict, config: ModuleConfig):
|
||||
# request_conf_config = parent.get("request_conf")
|
||||
# config.req_conf.max_retry = request_conf_config.get(
|
||||
# "max_retry", config.req_conf.max_retry
|
||||
# )
|
||||
# config.req_conf.timeout = request_conf_config.get(
|
||||
# "timeout", config.req_conf.timeout
|
||||
# )
|
||||
# config.req_conf.retry_interval = request_conf_config.get(
|
||||
# "retry_interval", config.req_conf.retry_interval
|
||||
# )
|
||||
# config.req_conf.default_temperature = request_conf_config.get(
|
||||
# "default_temperature", config.req_conf.default_temperature
|
||||
# )
|
||||
# config.req_conf.default_max_tokens = request_conf_config.get(
|
||||
# "default_max_tokens", config.req_conf.default_max_tokens
|
||||
# )
|
||||
|
||||
|
||||
# def _api_providers(parent: Dict, config: ModuleConfig):
|
||||
# api_providers_config = parent.get("api_providers")
|
||||
# for provider in api_providers_config:
|
||||
# name = provider.get("name", None)
|
||||
# base_url = provider.get("base_url", None)
|
||||
# api_key = provider.get("api_key", None)
|
||||
# api_keys = provider.get("api_keys", []) # 新增:支持多个API Key
|
||||
# client_type = provider.get("client_type", "openai")
|
||||
|
||||
# if name in config.api_providers: # 查重
|
||||
# logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
|
||||
# raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
|
||||
|
||||
# if name and base_url:
|
||||
# # 处理API Key配置:支持单个api_key或多个api_keys
|
||||
# if api_keys:
|
||||
# # 使用新格式:api_keys列表
|
||||
# logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key")
|
||||
# elif api_key:
|
||||
# # 向后兼容:使用单个api_key
|
||||
# api_keys = [api_key]
|
||||
# logger.debug(f"API提供商 '{name}' 使用单个API Key(向后兼容模式)")
|
||||
# else:
|
||||
# logger.warning(f"API提供商 '{name}' 没有配置API Key,某些功能可能不可用")
|
||||
|
||||
# config.api_providers[name] = APIProvider(
|
||||
# name=name,
|
||||
# base_url=base_url,
|
||||
# api_key=api_key, # 保留向后兼容
|
||||
# api_keys=api_keys, # 新格式
|
||||
# client_type=client_type,
|
||||
# )
|
||||
# else:
|
||||
# logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
|
||||
# raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
|
||||
|
||||
|
||||
# def _models(parent: Dict, config: ModuleConfig):
|
||||
# models_config = parent.get("models")
|
||||
# for model in models_config:
|
||||
# model_identifier = model.get("model_identifier", None)
|
||||
# name = model.get("name", model_identifier)
|
||||
# api_provider = model.get("api_provider", None)
|
||||
# price_in = model.get("price_in", 0.0)
|
||||
# price_out = model.get("price_out", 0.0)
|
||||
# force_stream_mode = model.get("force_stream_mode", False)
|
||||
# task_type = model.get("task_type", "")
|
||||
# capabilities = model.get("capabilities", [])
|
||||
|
||||
# if name in config.models: # 查重
|
||||
# logger.error(f"重复的模型名称: {name},请检查配置文件。")
|
||||
# raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
|
||||
|
||||
# if model_identifier and api_provider:
|
||||
# # 检查API提供商是否存在
|
||||
# if api_provider not in config.api_providers:
|
||||
# logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
|
||||
# raise ValueError(
|
||||
# f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
|
||||
# )
|
||||
# config.models[name] = ModelInfo(
|
||||
# name=name,
|
||||
# model_identifier=model_identifier,
|
||||
# api_provider=api_provider,
|
||||
# price_in=price_in,
|
||||
# price_out=price_out,
|
||||
# force_stream_mode=force_stream_mode,
|
||||
# task_type=task_type,
|
||||
# capabilities=capabilities,
|
||||
# )
|
||||
# else:
|
||||
# logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
|
||||
# raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
|
||||
|
||||
|
||||
# def _task_model_usage(parent: Dict, config: ModuleConfig):
|
||||
# model_usage_configs = parent.get("task_model_usage")
|
||||
# config.task_model_arg_map = {}
|
||||
# for task_name, item in model_usage_configs.items():
|
||||
# if task_name in config.task_model_arg_map:
|
||||
# logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
|
||||
# raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
|
||||
|
||||
# usage = []
|
||||
# if isinstance(item, Dict):
|
||||
# if "model" in item:
|
||||
# usage.append(
|
||||
# ModelUsageArgConfigItem(
|
||||
# name=item["model"],
|
||||
# temperature=item.get("temperature", None),
|
||||
# max_tokens=item.get("max_tokens", None),
|
||||
# max_retry=item.get("max_retry", None),
|
||||
# )
|
||||
# )
|
||||
# else:
|
||||
# logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
|
||||
# raise ValueError(
|
||||
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
|
||||
# )
|
||||
# elif isinstance(item, List):
|
||||
# for model in item:
|
||||
# if isinstance(model, Dict):
|
||||
# usage.append(
|
||||
# ModelUsageArgConfigItem(
|
||||
# name=model["model"],
|
||||
# temperature=model.get("temperature", None),
|
||||
# max_tokens=model.get("max_tokens", None),
|
||||
# max_retry=model.get("max_retry", None),
|
||||
# )
|
||||
# )
|
||||
# elif isinstance(model, str):
|
||||
# usage.append(
|
||||
# ModelUsageArgConfigItem(
|
||||
# name=model,
|
||||
# temperature=None,
|
||||
# max_tokens=None,
|
||||
# max_retry=None,
|
||||
# )
|
||||
# )
|
||||
# else:
|
||||
# logger.error(
|
||||
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
|
||||
# )
|
||||
# raise ValueError(
|
||||
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
|
||||
# )
|
||||
# elif isinstance(item, str):
|
||||
# usage.append(
|
||||
# ModelUsageArgConfigItem(
|
||||
# name=item,
|
||||
# temperature=None,
|
||||
# max_tokens=None,
|
||||
# max_retry=None,
|
||||
# )
|
||||
# )
|
||||
|
||||
# config.task_model_arg_map[task_name] = ModelUsageArgConfig(
|
||||
# name=task_name,
|
||||
# usage=usage,
|
||||
# )
|
||||
|
||||
|
||||
# def api_ada_load_config(config_path: str) -> ModuleConfig:
|
||||
# """从TOML配置文件加载配置"""
|
||||
# config = ModuleConfig()
|
||||
|
||||
# include_configs: Dict[str, Dict[str, Any]] = {
|
||||
# "request_conf": {
|
||||
# "func": _request_conf,
|
||||
# "support": ">=0.0.0",
|
||||
# "necessary": False,
|
||||
# },
|
||||
# "api_providers": {"func": _api_providers, "support": ">=0.0.0"},
|
||||
# "models": {"func": _models, "support": ">=0.0.0"},
|
||||
# "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
|
||||
# }
|
||||
|
||||
# if os.path.exists(config_path):
|
||||
# with open(config_path, "rb") as f:
|
||||
# try:
|
||||
# toml_dict = tomlkit.load(f)
|
||||
# except tomlkit.TOMLDecodeError as e:
|
||||
# logger.critical(
|
||||
# f"配置文件model_list.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}"
|
||||
# )
|
||||
# exit(1)
|
||||
|
||||
# # 获取配置文件版本
|
||||
# config.INNER_VERSION = _get_config_version(toml_dict)
|
||||
|
||||
# # 检查版本
|
||||
# if config.INNER_VERSION > Version(NEWEST_VER):
|
||||
# logger.warning(
|
||||
# f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
|
||||
# )
|
||||
|
||||
# # 解析配置文件
|
||||
# # 如果在配置中找到了需要的项,调用对应项的闭包函数处理
|
||||
# for key in include_configs:
|
||||
# if key in toml_dict:
|
||||
# group_specifier_set: SpecifierSet = SpecifierSet(
|
||||
# include_configs[key]["support"]
|
||||
# )
|
||||
|
||||
# # 检查配置文件版本是否在支持范围内
|
||||
# if config.INNER_VERSION in group_specifier_set:
|
||||
# # 如果版本在支持范围内,检查是否存在通知
|
||||
# if "notice" in include_configs[key]:
|
||||
# logger.warning(include_configs[key]["notice"])
|
||||
# # 调用闭包函数处理配置
|
||||
# (include_configs[key]["func"])(toml_dict, config)
|
||||
# else:
|
||||
# # 如果版本不在支持范围内,崩溃并提示用户
|
||||
# logger.error(
|
||||
# f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
|
||||
# f"当前程序仅支持以下版本范围: {group_specifier_set}"
|
||||
# )
|
||||
# raise InvalidVersion(
|
||||
# f"当前程序仅支持以下版本范围: {group_specifier_set}"
|
||||
# )
|
||||
|
||||
# # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理
|
||||
# elif (
|
||||
# "necessary" in include_configs[key]
|
||||
# and include_configs[key].get("necessary") is False
|
||||
# ):
|
||||
# # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
|
||||
# if key == "keywords_reaction":
|
||||
# pass
|
||||
# else:
|
||||
# # 如果用户根本没有需要的配置项,提示缺少配置
|
||||
# logger.error(f"配置文件中缺少必需的字段: '{key}'")
|
||||
# raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
|
||||
|
||||
# logger.info(f"成功加载配置文件: {config_path}")
|
||||
|
||||
# return config
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
# 获取key的注释(如果有)
|
||||
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
|
||||
@@ -626,6 +364,16 @@ class APIAdapterConfig(ConfigBase):
|
||||
"""API提供商列表"""
|
||||
|
||||
def __post_init__(self):
|
||||
# 检查API提供商名称是否重复
|
||||
provider_names = [provider.name for provider in self.api_providers]
|
||||
if len(provider_names) != len(set(provider_names)):
|
||||
raise ValueError("API提供商名称存在重复,请检查配置文件。")
|
||||
|
||||
# 检查模型名称是否重复
|
||||
model_names = [model.name for model in self.models]
|
||||
if len(model_names) != len(set(model_names)):
|
||||
raise ValueError("模型名称存在重复,请检查配置文件。")
|
||||
|
||||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self.models_dict = {model.name: model for model in self.models}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import hashlib
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from rich.traceback import install
|
||||
@@ -23,10 +23,7 @@ class Individuality:
|
||||
self.meta_info_file_path = "data/personality/meta.json"
|
||||
self.personality_data_file_path = "data/personality/personality_data.json"
|
||||
|
||||
self.model = LLMRequest(
|
||||
model=global_config.model.utils,
|
||||
request_type="individuality.compress",
|
||||
)
|
||||
self.model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="individuality.compress")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化个体特征"""
|
||||
@@ -35,7 +32,6 @@ class Individuality:
|
||||
personality_side = global_config.personality.personality_side
|
||||
identity = global_config.personality.identity
|
||||
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||
self.name = bot_nickname
|
||||
@@ -235,7 +231,7 @@ class Individuality:
|
||||
"personality": personality,
|
||||
"identity": identity,
|
||||
"bot_nickname": self.name,
|
||||
"last_updated": int(time.time())
|
||||
"last_updated": int(time.time()),
|
||||
}
|
||||
self._save_personality_data(personality_data)
|
||||
|
||||
@@ -269,7 +265,7 @@ class Individuality:
|
||||
2. 尽量简洁,不超过30字
|
||||
3. 直接输出压缩后的内容,不要解释"""
|
||||
|
||||
response, (_, _) = await self.model.generate_response_async(
|
||||
response, _ = await self.model.generate_response_async(
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
@@ -308,7 +304,7 @@ class Individuality:
|
||||
2. 尽量简洁,不超过30字
|
||||
3. 直接输出压缩后的内容,不要解释"""
|
||||
|
||||
response, (_, _) = await self.model.generate_response_async(
|
||||
response, _ = await self.model.generate_response_async(
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
import importlib
|
||||
from typing import Dict
|
||||
|
||||
from src.config.config import model_config
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .model_client import ModelRequestHandler, BaseClient
|
||||
|
||||
logger = get_logger("模型管理器")
|
||||
|
||||
class ModelManager:
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
import importlib
|
||||
from typing import Dict
|
||||
|
||||
from src.config.config import model_config
|
||||
from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .model_client import ModelRequestHandler, BaseClient
|
||||
|
||||
logger = get_logger("模型管理器")
|
||||
|
||||
class ModelManager:
|
||||
# TODO: 添加读写锁,防止异步刷新配置时发生数据竞争
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModuleConfig,
|
||||
):
|
||||
self.config: ModuleConfig = config
|
||||
"""配置信息"""
|
||||
|
||||
self.api_client_map: Dict[str, BaseClient] = {}
|
||||
"""API客户端映射表"""
|
||||
|
||||
self._request_handler_cache: Dict[str, ModelRequestHandler] = {}
|
||||
"""ModelRequestHandler缓存,避免重复创建"""
|
||||
|
||||
for provider_name, api_provider in self.config.api_providers.items():
|
||||
# 初始化API客户端
|
||||
try:
|
||||
# 根据配置动态加载实现
|
||||
client_module = importlib.import_module(
|
||||
f".model_client.{api_provider.client_type}_client", __package__
|
||||
)
|
||||
client_class = getattr(
|
||||
client_module, f"{api_provider.client_type.capitalize()}Client"
|
||||
)
|
||||
if not issubclass(client_class, BaseClient):
|
||||
raise TypeError(
|
||||
f"'{client_class.__name__}' is not a subclass of 'BaseClient'"
|
||||
)
|
||||
self.api_client_map[api_provider.name] = client_class(
|
||||
api_provider
|
||||
) # 实例化,放入api_client_map
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import client module: {e}")
|
||||
raise ImportError(
|
||||
f"Failed to import client module for '{provider_name}': {e}"
|
||||
) from e
|
||||
|
||||
def __getitem__(self, task_name: str) -> ModelRequestHandler:
|
||||
"""
|
||||
获取任务所需的模型客户端(封装)
|
||||
使用缓存机制避免重复创建ModelRequestHandler
|
||||
:param task_name: 任务名称
|
||||
:return: 模型客户端
|
||||
"""
|
||||
if task_name not in self.config.task_model_arg_map:
|
||||
raise KeyError(f"'{task_name}' not registered in ModelManager")
|
||||
|
||||
# 检查缓存中是否已存在
|
||||
if task_name in self._request_handler_cache:
|
||||
logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}")
|
||||
return self._request_handler_cache[task_name]
|
||||
|
||||
# 创建新的ModelRequestHandler并缓存
|
||||
logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}")
|
||||
handler = ModelRequestHandler(
|
||||
task_name=task_name,
|
||||
config=self.config,
|
||||
api_client_map=self.api_client_map,
|
||||
)
|
||||
self._request_handler_cache[task_name] = handler
|
||||
return handler
|
||||
|
||||
def __setitem__(self, task_name: str, value: ModelUsageArgConfig):
|
||||
"""
|
||||
注册任务的模型使用配置
|
||||
:param task_name: 任务名称
|
||||
:param value: 模型使用配置
|
||||
"""
|
||||
self.config.task_model_arg_map[task_name] = value
|
||||
|
||||
def __contains__(self, task_name: str):
|
||||
"""
|
||||
判断任务是否已注册
|
||||
:param task_name: 任务名称
|
||||
:return: 是否在模型列表中
|
||||
"""
|
||||
return task_name in self.config.task_model_arg_map
|
||||
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import ModelInfo
|
||||
from src.common.database.database_model import LLMUsage
|
||||
|
||||
logger = get_logger("模型使用统计")
|
||||
|
||||
|
||||
class ReqType(Enum):
|
||||
"""
|
||||
请求类型
|
||||
"""
|
||||
|
||||
CHAT = "chat" # 对话请求
|
||||
EMBEDDING = "embedding" # 嵌入请求
|
||||
|
||||
|
||||
class UsageCallStatus(Enum):
|
||||
"""
|
||||
任务调用状态
|
||||
"""
|
||||
|
||||
PROCESSING = "processing" # 处理中
|
||||
SUCCESS = "success" # 成功
|
||||
FAILURE = "failure" # 失败
|
||||
CANCELED = "canceled" # 取消
|
||||
|
||||
|
||||
class ModelUsageStatistic:
|
||||
"""
|
||||
模型使用统计类 - 使用SQLite+Peewee
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化统计类
|
||||
由于使用Peewee ORM,不需要传入数据库实例
|
||||
"""
|
||||
# 确保表已经创建
|
||||
try:
|
||||
from src.common.database.database import db
|
||||
|
||||
db.create_tables([LLMUsage], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLMUsage表失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _calculate_cost(prompt_tokens: int, completion_tokens: int, model_info: ModelInfo) -> float:
|
||||
"""计算API调用成本
|
||||
使用模型的pri_in和pri_out价格计算输入和输出的成本
|
||||
|
||||
Args:
|
||||
prompt_tokens: 输入token数量
|
||||
completion_tokens: 输出token数量
|
||||
model_info: 模型信息
|
||||
|
||||
Returns:
|
||||
float: 总成本(元)
|
||||
"""
|
||||
# 使用模型的pri_in和pri_out计算成本
|
||||
input_cost = (prompt_tokens / 1000000) * model_info.price_in
|
||||
output_cost = (completion_tokens / 1000000) * model_info.price_out
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
def create_usage(
|
||||
self,
|
||||
model_name: str,
|
||||
task_name: str = "N/A",
|
||||
request_type: ReqType = ReqType.CHAT,
|
||||
user_id: str = "system",
|
||||
endpoint: str = "/chat/completions",
|
||||
) -> int | None:
|
||||
"""
|
||||
创建模型使用情况记录
|
||||
|
||||
Args:
|
||||
model_name: 模型名
|
||||
task_name: 任务名称
|
||||
request_type: 请求类型,默认为Chat
|
||||
user_id: 用户ID,默认为system
|
||||
endpoint: API端点
|
||||
|
||||
Returns:
|
||||
int | None: 返回记录ID,失败返回None
|
||||
"""
|
||||
try:
|
||||
usage_record = LLMUsage.create(
|
||||
model_name=model_name,
|
||||
user_id=user_id,
|
||||
request_type=request_type.value,
|
||||
endpoint=endpoint,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
cost=0.0,
|
||||
status=UsageCallStatus.PROCESSING.value,
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
|
||||
# logger.trace(
|
||||
# f"创建了一条模型使用情况记录 - 模型: {model_name}, "
|
||||
# f"子任务: {task_name}, 类型: {request_type.value}, "
|
||||
# f"用户: {user_id}, 记录ID: {usage_record.id}"
|
||||
# )
|
||||
|
||||
return usage_record.id
|
||||
except Exception as e:
|
||||
logger.error(f"创建模型使用情况记录失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def update_usage(
|
||||
self,
|
||||
record_id: int | None,
|
||||
model_info: ModelInfo,
|
||||
usage_data: Tuple[int, int, int] | None = None,
|
||||
stat: UsageCallStatus = UsageCallStatus.SUCCESS,
|
||||
ext_msg: str | None = None,
|
||||
):
|
||||
"""
|
||||
更新模型使用情况
|
||||
|
||||
Args:
|
||||
record_id: 记录ID
|
||||
model_info: 模型信息
|
||||
usage_data: 使用情况数据(输入token数量, 输出token数量, 总token数量)
|
||||
stat: 任务调用状态
|
||||
ext_msg: 额外信息
|
||||
"""
|
||||
if not record_id:
|
||||
logger.error("更新模型使用情况失败: record_id不能为空")
|
||||
return
|
||||
|
||||
if usage_data and len(usage_data) != 3:
|
||||
logger.error("更新模型使用情况失败: usage_data的长度不正确,应该为3个元素")
|
||||
return
|
||||
|
||||
# 提取使用情况数据
|
||||
prompt_tokens = usage_data[0] if usage_data else 0
|
||||
completion_tokens = usage_data[1] if usage_data else 0
|
||||
total_tokens = usage_data[2] if usage_data else 0
|
||||
|
||||
try:
|
||||
# 使用Peewee更新记录
|
||||
update_query = LLMUsage.update(
|
||||
status=stat.value,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost=self._calculate_cost(prompt_tokens, completion_tokens, model_info) if usage_data else 0.0,
|
||||
).where(LLMUsage.id == record_id) # type: ignore
|
||||
|
||||
updated_count = update_query.execute()
|
||||
|
||||
if updated_count == 0:
|
||||
logger.warning(f"记录ID {record_id} 不存在,无法更新")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Token使用情况 - 模型: {model_info.name}, "
|
||||
f"记录ID: {record_id}, "
|
||||
f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, "
|
||||
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
|
||||
f"总计: {total_tokens}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"记录token使用情况失败: {str(e)}")
|
||||
@@ -2,16 +2,19 @@ import base64
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||
from src.common.database.database_model import LLMUsage
|
||||
from src.config.api_ada_configs import ModelInfo
|
||||
from .payload_content.message import Message, MessageBuilder
|
||||
from .model_client.base_client import UsageRecord
|
||||
|
||||
logger = get_logger("消息压缩工具")
|
||||
|
||||
|
||||
def compress_messages(
|
||||
messages: list[Message], img_target_size: int = 1 * 1024 * 1024
|
||||
) -> list[Message]:
|
||||
def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]:
|
||||
"""
|
||||
压缩消息列表中的图片
|
||||
:param messages: 消息列表
|
||||
@@ -28,14 +31,10 @@ def compress_messages(
|
||||
try:
|
||||
image = Image.open(image_data)
|
||||
|
||||
if image.format and (
|
||||
image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]
|
||||
):
|
||||
if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
|
||||
# 静态图像,转换为JPEG格式
|
||||
reformated_image_data = io.BytesIO()
|
||||
image.save(
|
||||
reformated_image_data, format="JPEG", quality=95, optimize=True
|
||||
)
|
||||
image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
|
||||
image_data = reformated_image_data.getvalue()
|
||||
|
||||
return image_data
|
||||
@@ -43,9 +42,7 @@ def compress_messages(
|
||||
logger.error(f"图片转换格式失败: {str(e)}")
|
||||
return image_data
|
||||
|
||||
def rescale_image(
|
||||
image_data: bytes, scale: float
|
||||
) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
|
||||
def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
|
||||
"""
|
||||
缩放图片
|
||||
:param image_data: 图片数据
|
||||
@@ -86,9 +83,7 @@ def compress_messages(
|
||||
else:
|
||||
# 静态图片,直接缩放保存
|
||||
resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
resized_image.save(
|
||||
output_buffer, format="JPEG", quality=95, optimize=True
|
||||
)
|
||||
resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)
|
||||
|
||||
return output_buffer.getvalue(), original_size, new_size
|
||||
|
||||
@@ -99,9 +94,7 @@ def compress_messages(
|
||||
logger.error(traceback.format_exc())
|
||||
return image_data, None, None
|
||||
|
||||
def compress_base64_image(
|
||||
base64_data: str, target_size: int = 1 * 1024 * 1024
|
||||
) -> str:
|
||||
def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str:
|
||||
original_b64_data_size = len(base64_data) # 计算原始数据大小
|
||||
|
||||
image_data = base64.b64decode(base64_data)
|
||||
@@ -111,9 +104,7 @@ def compress_messages(
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
if len(base64_data) <= target_size:
|
||||
# 如果转换后小于目标大小,直接返回
|
||||
logger.info(
|
||||
f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB"
|
||||
)
|
||||
logger.info(f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB")
|
||||
return base64_data
|
||||
|
||||
# 如果转换后仍然大于目标大小,进行尺寸压缩
|
||||
@@ -139,9 +130,7 @@ def compress_messages(
|
||||
# 图片,进行压缩
|
||||
message_builder.add_image_content(
|
||||
content_item[0],
|
||||
compress_base64_image(
|
||||
content_item[1], target_size=img_target_size
|
||||
),
|
||||
compress_base64_image(content_item[1], target_size=img_target_size),
|
||||
)
|
||||
else:
|
||||
message_builder.add_text_content(content_item)
|
||||
@@ -150,3 +139,48 @@ def compress_messages(
|
||||
compressed_messages.append(message)
|
||||
|
||||
return compressed_messages
|
||||
|
||||
|
||||
class LLMUsageRecorder:
|
||||
"""
|
||||
LLM使用情况记录器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
# 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
|
||||
db.create_tables([LLMUsage], safe=True)
|
||||
# logger.debug("LLMUsage 表已初始化/确保存在。")
|
||||
except Exception as e:
|
||||
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
|
||||
|
||||
def record_usage_to_database(
|
||||
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str
|
||||
):
|
||||
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
|
||||
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
|
||||
total_cost = round(input_cost + output_cost, 6)
|
||||
try:
|
||||
# 使用 Peewee 模型创建记录
|
||||
LLMUsage.create(
|
||||
model_name=model_info.model_identifier,
|
||||
user_id=user_id,
|
||||
request_type=request_type,
|
||||
endpoint=endpoint,
|
||||
prompt_tokens=model_usage.prompt_tokens or 0,
|
||||
completion_tokens=model_usage.completion_tokens or 0,
|
||||
total_tokens=model_usage.total_tokens or 0,
|
||||
cost=total_cost or 0.0,
|
||||
status="success",
|
||||
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
||||
)
|
||||
logger.debug(
|
||||
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
||||
f"用户: {user_id}, 类型: {request_type}, "
|
||||
f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
|
||||
f"总计: {model_usage.total_tokens}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"记录token使用情况失败: {str(e)}")
|
||||
|
||||
llm_usage_recorder = LLMUsageRecorder()
|
||||
@@ -1,34 +1,20 @@
import re
import copy
import asyncio
from datetime import datetime
from typing import Tuple, Union, List, Dict, Optional, Callable, Any
from src.common.logger import get_logger
import base64
from PIL import Image
from enum import Enum
import io
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from src.config.config import global_config, model_config
from src.config.api_ada_configs import APIProvider, ModelInfo
from rich.traceback import install

from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any

from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall
from .model_client.base_client import BaseClient, APIResponse, UsageRecord, client_registry
from .utils import compress_messages

from .exceptions import (
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
PayLoadTooLargeError,
RequestAbortException,
PermissionDeniedException,
)
from .model_client.base_client import BaseClient, APIResponse, client_registry
from .utils import compress_messages, llm_usage_recorder
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException

install(extra_lines=3)

@@ -57,27 +43,9 @@ class RequestType(Enum):
class LLMRequest:
"""LLM请求类"""

# 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
"o1",
"o1-2024-12-17",
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"o1-pro",
"o1-pro-2025-03-19",
"o3",
"o3-2025-04-16",
"o3-mini",
"o3-mini-2025-01-31",
"o4-mini",
"o4-mini-2025-04-16",
]

def __init__(self, task_name: str, request_type: str = "") -> None:
self.task_name = task_name
self.model_for_task = model_config.model_task_config.get_task(task_name)
def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
self.task_name = request_type
self.model_for_task = model_set
self.request_type = request_type
self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list}
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整"""
@@ -85,18 +53,6 @@ class LLMRequest:
self.pri_in = 0
self.pri_out = 0

self._init_database()

@staticmethod
def _init_database():
"""初始化数据库集合"""
try:
# 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")

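(Illustration, not part of the diff: under the new signature, call sites pick a model set from model_config.model_task_config instead of passing a global_config.model entry, and sampling options such as temperature move to the individual generate calls. A minimal sketch that mirrors the call sites updated elsewhere in this commit; the "relationship" request type is just one of the values that appears in this diff.)

from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest

# Construct against a TaskConfig rather than a single model entry
llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relationship")
# temperature is now supplied per call, e.g.:
# response, _ = await llm.generate_response_async(prompt, temperature=0.7)
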
async def generate_response_for_image(
self,
prompt: str,
@@ -104,7 +60,7 @@ class LLMRequest:
image_format: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
"""
为图像生成响应
Args:
@@ -112,7 +68,7 @@ class LLMRequest:
image_base64 (str): 图像的Base64编码字符串
image_format (str): 图像格式(如 'png', 'jpeg' 等)
Returns:

(Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
# 请求体构建
message_builder = MessageBuilder()
@@ -141,25 +97,25 @@ class LLMRequest:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
self.pri_in = model_info.price_in
self.pri_out = model_info.price_out
self._record_usage(
model_name=model_info.name,
prompt_tokens=usage.prompt_tokens or 0,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens or 0,
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
)
return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
return content, (
reasoning_content,
model_info.name,
self._convert_tool_calls(tool_calls) if tool_calls else None,
)

async def generate_response_for_voice(self):
pass

async def generate_response_async(
self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None
) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
"""
异步生成响应
Args:
@@ -167,7 +123,7 @@ class LLMRequest:
temperature (float, optional): 温度参数
max_tokens (int, optional): 最大token数
Returns:
Tuple[str, str, Optional[List[Dict[str, Any]]]]: 响应内容、推理内容和工具调用列表
(Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
# 请求体构建
message_builder = MessageBuilder()
@@ -195,13 +151,9 @@ class LLMRequest:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
self.pri_in = model_info.price_in
self.pri_out = model_info.price_out
self._record_usage(
model_name=model_info.name,
prompt_tokens=usage.prompt_tokens or 0,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens or 0,
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
@@ -209,10 +161,19 @@ class LLMRequest:
if not content:
raise RuntimeError("获取LLM生成内容失败")

return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
return content, (
reasoning_content,
model_info.name,
self._convert_tool_calls(tool_calls) if tool_calls else None,
)

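(Illustration, not part of the diff: generate_response_async now returns the content plus a (reasoning_content, model_name, tool_calls) tuple, so callers unpack it as sketched below. This mirrors the call-site updates later in this diff; llm stands for any LLMRequest instance.)

content, (reasoning_content, model_name, tool_calls) = await llm.generate_response_async(
    prompt=prompt, temperature=0.7
)
# Callers that only need the text can still discard the rest:
# content, _ = await llm.generate_response_async(prompt=prompt)
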
async def get_embedding(self, embedding_input: str) -> List[float]:
"""获取嵌入向量"""
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量
Args:
embedding_input (str): 获取嵌入的目标
Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
"""
# 无需构建消息体,直接使用输入文本
model_info, api_provider, client = self._select_model()

@@ -227,14 +188,10 @@ class LLMRequest:

embedding = response.embedding

if response.usage:
self.pri_in = model_info.price_in
self.pri_out = model_info.price_out
self._record_usage(
model_name=model_info.name,
prompt_tokens=response.usage.prompt_tokens or 0,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens or 0,
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/embeddings",
@@ -243,7 +200,7 @@ class LLMRequest:
if not embedding:
raise RuntimeError("获取embedding失败")

return embedding
return embedding, model_info.name

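(Illustration, not part of the diff: get_embedding now also reports which model produced the vector, so existing callers need a matching unpack. A minimal sketch under that assumption; variable names are illustrative.)

embedding, model_name = await llm.get_embedding(embedding_input)
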
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
|
||||
"""
|
||||
@@ -305,12 +262,13 @@ class LLMRequest:
|
||||
# 处理异常
|
||||
total_tokens, penalty = self.model_usage[model_info.name]
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty + 1)
|
||||
|
||||
wait_interval, compressed_messages = self._default_exception_handler(
|
||||
e,
|
||||
self.task_name,
|
||||
model_name=model_info.name,
|
||||
remain_try=retry_remain,
|
||||
messages=(message_list, compressed_messages is not None),
|
||||
messages=(message_list, compressed_messages is not None) if message_list else None,
|
||||
)
|
||||
|
||||
if wait_interval == -1:
|
||||
@@ -321,9 +279,7 @@ class LLMRequest:
|
||||
finally:
|
||||
# 放在finally防止死循环
|
||||
retry_remain -= 1
|
||||
logger.error(
|
||||
f"任务 '{self.task_name}' 模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次"
|
||||
)
|
||||
logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
|
||||
raise RuntimeError("请求失败,已达到最大重试次数")
|
||||
|
||||
def _default_exception_handler(
|
||||
@@ -481,65 +437,3 @@ class LLMRequest:
|
||||
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
|
||||
reasoning = match[1].strip() if match else ""
|
||||
return content, reasoning
|
||||
|
||||
def _record_usage(
|
||||
self,
|
||||
model_name: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
user_id: str = "system",
|
||||
request_type: str | None = None,
|
||||
endpoint: str = "/chat/completions",
|
||||
):
|
||||
"""记录模型使用情况到数据库
|
||||
Args:
|
||||
prompt_tokens: 输入token数
|
||||
completion_tokens: 输出token数
|
||||
total_tokens: 总token数
|
||||
user_id: 用户ID,默认为system
|
||||
request_type: 请求类型
|
||||
endpoint: API端点
|
||||
"""
|
||||
# 如果 request_type 为 None,则使用实例变量中的值
|
||||
if request_type is None:
|
||||
request_type = self.request_type
|
||||
|
||||
try:
|
||||
# 使用 Peewee 模型创建记录
|
||||
LLMUsage.create(
|
||||
model_name=model_name,
|
||||
user_id=user_id,
|
||||
request_type=request_type,
|
||||
endpoint=endpoint,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost=self._calculate_cost(prompt_tokens, completion_tokens),
|
||||
status="success",
|
||||
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
||||
)
|
||||
logger.debug(
|
||||
f"Token使用情况 - 模型: {model_name}, "
|
||||
f"用户: {user_id}, 类型: {request_type}, "
|
||||
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
|
||||
f"总计: {total_tokens}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"记录token使用情况失败: {str(e)}")
|
||||
|
||||
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
|
||||
"""计算API调用成本
|
||||
使用模型的pri_in和pri_out价格计算输入和输出的成本
|
||||
|
||||
Args:
|
||||
prompt_tokens: 输入token数量
|
||||
completion_tokens: 输出token数量
|
||||
|
||||
Returns:
|
||||
float: 总成本(元)
|
||||
"""
|
||||
# 使用模型的pri_in和pri_out计算成本
|
||||
input_cost = (prompt_tokens / 1000000) * self.pri_in
|
||||
output_cost = (completion_tokens / 1000000) * self.pri_out
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
@@ -2,13 +2,15 @@ from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
import time
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.config.config import model_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
@@ -32,10 +34,8 @@ def init_prompt():
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
class MaiThinking:
|
||||
def __init__(self,chat_id):
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.platform = self.chat_stream.platform
|
||||
@@ -60,16 +60,12 @@ class MaiThinking:
|
||||
self.sender = ""
|
||||
self.target = ""
|
||||
|
||||
self.thinking_model = LLMRequest(
|
||||
model=global_config.model.replyer_1,
|
||||
request_type="thinking",
|
||||
)
|
||||
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer_1, request_type="thinking")
|
||||
|
||||
async def do_think_before_response(self):
|
||||
pass
|
||||
|
||||
async def do_think_after_response(self,reponse:str):
|
||||
|
||||
async def do_think_after_response(self, reponse: str):
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"after_response_think_prompt",
|
||||
mind=self.mind,
|
||||
@@ -93,17 +89,14 @@ class MaiThinking:
|
||||
# logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}")
|
||||
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
|
||||
|
||||
|
||||
msg_recv = await self.build_internal_message_recv(self.mind)
|
||||
await self.s4u_message_processor.process_message(msg_recv)
|
||||
internal_manager.set_internal_state(self.mind)
|
||||
|
||||
|
||||
async def do_think_when_receive_message(self):
|
||||
pass
|
||||
|
||||
async def build_internal_message_recv(self,message_text:str):
|
||||
|
||||
async def build_internal_message_recv(self, message_text: str):
|
||||
msg_id = f"internal_{time.time()}"
|
||||
|
||||
message_dict = {
|
||||
@@ -155,13 +148,11 @@ class MaiThinking:
|
||||
return msg_recv
|
||||
|
||||
|
||||
|
||||
|
||||
class MaiThinkingManager:
|
||||
def __init__(self):
|
||||
self.mai_think_list = []
|
||||
|
||||
def get_mai_think(self,chat_id):
|
||||
def get_mai_think(self, chat_id):
|
||||
for mai_think in self.mai_think_list:
|
||||
if mai_think.chat_id == chat_id:
|
||||
return mai_think
|
||||
@@ -169,15 +160,8 @@ class MaiThinkingManager:
|
||||
self.mai_think_list.append(mai_think)
|
||||
return mai_think
|
||||
|
||||
|
||||
mai_thinking_manager = MaiThinkingManager()
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
|
||||
logger = get_logger("action")
|
||||
@@ -32,7 +34,7 @@ BODY_CODE = {
|
||||
"帅气的姿势": "010_0190",
|
||||
"另一个帅气的姿势": "010_0191",
|
||||
"手掌朝前可爱": "010_0210",
|
||||
"平静,双手后放":"平静,双手后放",
|
||||
"平静,双手后放": "平静,双手后放",
|
||||
"思考": "思考",
|
||||
"优雅,左手放在腰上": "优雅,左手放在腰上",
|
||||
"一般": "一般",
|
||||
@@ -94,15 +96,11 @@ class ChatAction:
|
||||
self.body_action_cooldown: dict[str, int] = {}
|
||||
|
||||
print(s4u_config.models.motion)
|
||||
print(global_config.model.emotion)
|
||||
print(model_config.model_task_config.emotion)
|
||||
|
||||
self.action_model = LLMRequest(
|
||||
model=global_config.model.emotion,
|
||||
temperature=0.7,
|
||||
request_type="motion",
|
||||
)
|
||||
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||
|
||||
self.last_change_time = 0
|
||||
self.last_change_time: float = 0
|
||||
|
||||
async def send_action_update(self):
|
||||
"""发送动作更新到前端"""
|
||||
@@ -116,12 +114,10 @@ class ChatAction:
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def update_action_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time = message.message_info.time
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
@@ -163,18 +159,17 @@ class ChatAction:
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt)
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
action_data = json.loads(repair_json(response))
|
||||
|
||||
if action_data:
|
||||
if action_data := json.loads(repair_json(response)):
|
||||
# 记录原动作,切换后进入冷却
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action:
|
||||
if prev_body_action:
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 3
|
||||
self.body_action = new_body_action
|
||||
self.head_action = action_data.get("head_action", self.head_action)
|
||||
@@ -213,7 +208,6 @@ class ChatAction:
|
||||
prompt_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
try:
|
||||
|
||||
# 冷却池处理:过滤掉冷却中的动作
|
||||
self._update_body_action_cooldown()
|
||||
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
|
||||
@@ -228,16 +222,16 @@ class ChatAction:
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt)
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
action_data = json.loads(repair_json(response))
|
||||
if action_data:
|
||||
if action_data := json.loads(repair_json(response)):
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action:
|
||||
if prev_body_action:
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 6
|
||||
self.body_action = new_body_action
|
||||
# 发送动作更新
|
||||
@@ -306,9 +300,6 @@ class ActionManager:
|
||||
return new_action_state
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
action_manager = ActionManager()
|
||||
|
||||
@@ -137,7 +137,7 @@ class MessageSenderContainer:
|
||||
await self.storage.store_message(bot_message, self.chat_stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
||||
logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
||||
|
||||
@@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
@@ -114,18 +114,12 @@ class ChatMood:
|
||||
|
||||
self.regression_count: int = 0
|
||||
|
||||
self.mood_model = LLMRequest(
|
||||
model=global_config.model.emotion,
|
||||
temperature=0.7,
|
||||
request_type="mood_text",
|
||||
)
|
||||
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
|
||||
self.mood_model_numerical = LLMRequest(
|
||||
model=global_config.model.emotion,
|
||||
temperature=0.4,
|
||||
request_type="mood_numerical",
|
||||
model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
|
||||
)
|
||||
|
||||
self.last_change_time = 0
|
||||
self.last_change_time: float = 0
|
||||
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(self.send_emotion_update(self.mood_values))
|
||||
@@ -164,7 +158,7 @@ class ChatMood:
|
||||
async def update_mood_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time = message.message_info.time
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
@@ -199,7 +193,9 @@ class ChatMood:
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text mood prompt: {prompt}")
|
||||
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text mood response: {response}")
|
||||
logger.debug(f"text mood reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
@@ -216,8 +212,8 @@ class ChatMood:
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical mood prompt: {prompt}")
|
||||
response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt, temperature=0.4
|
||||
)
|
||||
logger.info(f"numerical mood response: {response}")
|
||||
logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
|
||||
@@ -276,7 +272,9 @@ class ChatMood:
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text regress prompt: {prompt}")
|
||||
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text regress response: {response}")
|
||||
logger.debug(f"text regress reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
@@ -293,8 +291,9 @@ class ChatMood:
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical regress prompt: {prompt}")
|
||||
response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.4,
|
||||
)
|
||||
logger.info(f"numerical regress response: {response}")
|
||||
logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
|
||||
@@ -447,6 +446,7 @@ class MoodManager:
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
|
||||
|
||||
|
||||
if ENABLE_S4U:
|
||||
init_prompt()
|
||||
mood_manager = MoodManager()
|
||||
|
||||
@@ -161,8 +161,7 @@ class PromptBuilder:
|
||||
relation_info_list = await asyncio.gather(
|
||||
*[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
|
||||
)
|
||||
relation_info = "".join(relation_info_list)
|
||||
if relation_info:
|
||||
if relation_info := "".join(relation_info_list):
|
||||
relation_prompt = await global_prompt_manager.format_prompt(
|
||||
"relation_prompt", relation_info=relation_info
|
||||
)
|
||||
@@ -188,7 +187,7 @@ class PromptBuilder:
|
||||
)
|
||||
|
||||
|
||||
talk_type = message.message_info.platform + ":" + str(message.chat_stream.user_info.user_id)
|
||||
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
|
||||
|
||||
core_dialogue_list = []
|
||||
background_dialogue_list = []
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from src.mais4u.openai_client import AsyncOpenAIClient
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
||||
from src.common.logger import get_logger
|
||||
@@ -14,24 +14,27 @@ logger = get_logger("s4u_stream_generator")
|
||||
|
||||
class S4UStreamGenerator:
|
||||
def __init__(self):
|
||||
replyer_1_config = global_config.model.replyer_1
|
||||
provider = replyer_1_config.get("provider")
|
||||
if not provider:
|
||||
logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段")
|
||||
raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段")
|
||||
replyer_1_config = model_config.model_task_config.replyer_1
|
||||
model_to_use = replyer_1_config.model_list[0]
|
||||
model_info = model_config.get_model_info(model_to_use)
|
||||
if not model_info:
|
||||
logger.error(f"模型 {model_to_use} 在配置中未找到")
|
||||
raise ValueError(f"模型 {model_to_use} 在配置中未找到")
|
||||
provider_name = model_info.api_provider
|
||||
provider_info = model_config.get_provider(provider_name)
|
||||
if not provider_info:
|
||||
logger.error("`replyer_1` 找不到对应的Provider")
|
||||
raise ValueError("`replyer_1` 找不到对应的Provider")
|
||||
|
||||
api_key = os.environ.get(f"{provider.upper()}_KEY")
|
||||
base_url = os.environ.get(f"{provider.upper()}_BASE_URL")
|
||||
api_key = provider_info.api_key
|
||||
base_url = provider_info.base_url
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"环境变量 {provider.upper()}_KEY 未设置")
|
||||
raise ValueError(f"环境变量 {provider.upper()}_KEY 未设置")
|
||||
logger.error(f"{provider_name}没有配置API KEY")
|
||||
raise ValueError(f"{provider_name}没有配置API KEY")
|
||||
|
||||
self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
|
||||
self.model_1_name = replyer_1_config.get("name")
|
||||
if not self.model_1_name:
|
||||
logger.error("`replyer_1` 在配置文件中缺少 `model_name` 字段")
|
||||
raise ValueError("`replyer_1` 在配置文件中缺少 `model_name` 字段")
|
||||
self.model_1_name = model_to_use
|
||||
self.replyer_1_config = replyer_1_config
|
||||
|
||||
self.current_model_name = "unknown model"
|
||||
@@ -45,9 +48,9 @@ class S4UStreamGenerator:
|
||||
re.UNICODE | re.DOTALL,
|
||||
)
|
||||
|
||||
self.chat_stream =None
|
||||
self.chat_stream = None
|
||||
|
||||
async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""):
|
||||
async def build_last_internal_message(self, message: MessageRecvS4U, previous_reply_context: str = ""):
|
||||
# person_id = PersonInfoManager.get_person_id(
|
||||
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
# )
|
||||
@@ -71,14 +74,10 @@ class S4UStreamGenerator:
|
||||
[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:
|
||||
{message.processed_plain_text}
|
||||
"""
|
||||
return True,message_txt
|
||||
return True, message_txt
|
||||
else:
|
||||
message_txt = message.processed_plain_text
|
||||
return False,message_txt
|
||||
|
||||
|
||||
|
||||
|
||||
return False, message_txt
|
||||
|
||||
async def generate_response(
|
||||
self, message: MessageRecvS4U, previous_reply_context: str = ""
|
||||
@@ -88,7 +87,7 @@ class S4UStreamGenerator:
|
||||
self.partial_response = ""
|
||||
message_txt = message.processed_plain_text
|
||||
if not message.is_internal:
|
||||
interupted,message_txt_added = await self.build_last_internal_message(message,previous_reply_context)
|
||||
interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context)
|
||||
if interupted:
|
||||
message_txt = message_txt_added
|
||||
|
||||
@@ -105,7 +104,6 @@ class S4UStreamGenerator:
|
||||
current_client = self.client_1
|
||||
self.current_model_name = self.model_1_name
|
||||
|
||||
|
||||
extra_kwargs = {}
|
||||
if self.replyer_1_config.get("enable_thinking") is not None:
|
||||
extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking")
|
||||
|
||||
@@ -221,9 +221,7 @@ class SuperChatManager:
|
||||
# 限制显示数量
|
||||
display_superchats = superchats[:max_count]
|
||||
|
||||
lines = []
|
||||
lines.append("📢 当前有效超级弹幕:")
|
||||
|
||||
lines = ["📢 当前有效超级弹幕:"]
|
||||
for i, sc in enumerate(display_superchats, 1):
|
||||
remaining_minutes = int(sc.remaining_time() / 60)
|
||||
remaining_seconds = int(sc.remaining_time() % 60)
|
||||
@@ -232,7 +230,7 @@ class SuperChatManager:
|
||||
|
||||
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
|
||||
if len(line) > 100: # 限制单行长度
|
||||
line = line[:97] + "..."
|
||||
line = f"{line[:97]}..."
|
||||
line += f" (剩余{time_display})"
|
||||
lines.append(line)
|
||||
|
||||
@@ -251,7 +249,7 @@ class SuperChatManager:
|
||||
for sc in superchats:
|
||||
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
|
||||
if len(single_sc_str) > 100:
|
||||
single_sc_str = single_sc_str[:97] + "..."
|
||||
single_sc_str = f"{single_sc_str[:97]}..."
|
||||
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
|
||||
lines.append(single_sc_str)
|
||||
|
||||
@@ -287,7 +285,7 @@ class SuperChatManager:
|
||||
"lowest_amount": min(amounts)
|
||||
}
|
||||
|
||||
async def shutdown(self):
|
||||
async def shutdown(self): # sourcery skip: use-contextlib-suppress
|
||||
"""关闭管理器,清理资源"""
|
||||
if self._cleanup_task and not self._cleanup_task.done():
|
||||
self._cleanup_task.cancel()
|
||||
@@ -300,6 +298,7 @@ class SuperChatManager:
|
||||
|
||||
|
||||
|
||||
# sourcery skip: assign-if-exp
|
||||
if ENABLE_S4U:
|
||||
super_chat_manager = SuperChatManager()
|
||||
else:
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import model_config
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
head_actions_list = [
|
||||
"不做额外动作",
|
||||
"点头一次",
|
||||
"点头两次",
|
||||
"摇头",
|
||||
"歪脑袋",
|
||||
"低头望向一边"
|
||||
]
|
||||
head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"]
|
||||
|
||||
async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat_id: str = ""):
|
||||
|
||||
async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""):
|
||||
prompt = f"""
|
||||
{chat_history}
|
||||
以上是对方的发言:
|
||||
@@ -30,22 +25,14 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat
|
||||
低头望向一边
|
||||
|
||||
请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。"""
|
||||
model = LLMRequest(
|
||||
model=global_config.model.emotion,
|
||||
temperature=0.7,
|
||||
request_type="motion",
|
||||
)
|
||||
model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||
|
||||
try:
|
||||
# logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, model_name) = await model.generate_response_async(prompt=prompt)
|
||||
response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7)
|
||||
logger.info(f"response: {response}")
|
||||
|
||||
if response in head_actions_list:
|
||||
head_action = response
|
||||
else:
|
||||
head_action = "不做额外动作"
|
||||
|
||||
head_action = response if response in head_actions_list else "不做额外动作"
|
||||
await send_api.custom_to_stream(
|
||||
message_type="head_action",
|
||||
content=head_action,
|
||||
@@ -54,10 +41,6 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"yes_or_no_head error: {e}")
|
||||
return "不做额外动作"
|
||||
|
||||
|
||||
|
||||
@@ -3,13 +3,14 @@ import random
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
logger = get_logger("mood")
|
||||
|
||||
@@ -59,11 +60,7 @@ class ChatMood:
|
||||
|
||||
self.regression_count: int = 0
|
||||
|
||||
self.mood_model = LLMRequest(
|
||||
model=global_config.model.emotion,
|
||||
temperature=0.7,
|
||||
request_type="mood",
|
||||
)
|
||||
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood")
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
@@ -83,12 +80,16 @@ class ChatMood:
|
||||
logger.debug(
|
||||
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
|
||||
)
|
||||
update_probability = global_config.mood.mood_update_threshold * min(1.0, base_probability * time_multiplier * interest_multiplier)
|
||||
update_probability = global_config.mood.mood_update_threshold * min(
|
||||
1.0, base_probability * time_multiplier * interest_multiplier
|
||||
)
|
||||
|
||||
if random.random() > update_probability:
|
||||
return
|
||||
|
||||
logger.debug(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}")
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}"
|
||||
)
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
@@ -124,7 +125,9 @@ class ChatMood:
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
|
||||
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
||||
logger.info(f"{self.log_prefix} response: {response}")
|
||||
@@ -171,7 +174,9 @@ class ChatMood:
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
|
||||
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
||||
|
||||
@@ -11,7 +11,7 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
|
||||
"""
|
||||
@@ -54,11 +54,7 @@ person_info_default = {
|
||||
class PersonInfoManager:
|
||||
def __init__(self):
|
||||
self.person_name_list = {}
|
||||
# TODO: API-Adapter修改标记
|
||||
self.qv_name_llm = LLMRequest(
|
||||
model=global_config.model.utils,
|
||||
request_type="relation.qv_name",
|
||||
)
|
||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 设置连接池参数
|
||||
@@ -376,7 +372,7 @@ class PersonInfoManager:
|
||||
"nickname": "昵称",
|
||||
"reason": "理由"
|
||||
}"""
|
||||
response, (reasoning_content, model_name) = await self.qv_name_llm.generate_response_async(qv_name_prompt)
|
||||
response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt)
|
||||
# logger.info(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
|
||||
result = self._extract_json_from_text(response)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import List, Dict, Any
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
@@ -73,14 +73,12 @@ class RelationshipFetcher:
|
||||
|
||||
# LLM模型配置
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
request_type="relation.fetcher",
|
||||
model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher"
|
||||
)
|
||||
|
||||
# 小模型用于即时信息提取
|
||||
self.instant_llm_model = LLMRequest(
|
||||
model=global_config.model.utils_small,
|
||||
request_type="relation.fetch",
|
||||
model_set=model_config.model_task_config.utils_small, request_type="relation.fetch"
|
||||
)
|
||||
|
||||
name = get_chat_manager().get_stream_name(self.chat_id)
|
||||
@@ -96,7 +94,7 @@ class RelationshipFetcher:
|
||||
if not self.info_fetched_cache[person_id]:
|
||||
del self.info_fetched_cache[person_id]
|
||||
|
||||
async def build_relation_info(self, person_id, points_num = 3):
|
||||
async def build_relation_info(self, person_id, points_num=3):
|
||||
# 清理过期的信息缓存
|
||||
self._cleanup_expired_cache()
|
||||
|
||||
@@ -361,7 +359,6 @@ class RelationshipFetcher:
|
||||
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
|
||||
# sourcery skip: use-next
|
||||
"""将提取到的信息保存到 person_info 的 info_list 字段中
|
||||
|
||||
@@ -3,7 +3,7 @@ from .person_info import PersonInfoManager, get_person_info_manager
|
||||
import time
|
||||
import random
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
import json
|
||||
from json_repair import repair_json
|
||||
@@ -20,9 +20,8 @@ logger = get_logger("relation")
|
||||
class RelationshipManager:
|
||||
def __init__(self):
|
||||
self.relationship_llm = LLMRequest(
|
||||
model=global_config.model.utils,
|
||||
request_type="relationship", # 用于动作规划
|
||||
)
|
||||
model_set=model_config.model_task_config.utils, request_type="relationship"
|
||||
) # 用于动作规划
|
||||
|
||||
@staticmethod
|
||||
async def is_known_some_one(platform, user_id):
|
||||
@@ -188,10 +187,6 @@ class RelationshipManager:
|
||||
elif isinstance(points_data, str) and points_data.lower() == "none":
|
||||
points_list = []
|
||||
elif isinstance(points_data, list):
|
||||
# 正确格式:数组格式 [{"point": "...", "weight": 10}, ...]
|
||||
if not points_data: # 空数组
|
||||
points_list = []
|
||||
else:
|
||||
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
|
||||
else:
|
||||
# 错误格式,直接跳过不解析
|
||||
|
||||
@@ -12,6 +12,7 @@ import traceback
|
||||
from typing import Tuple, Any, Dict, List, Optional
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
@@ -31,7 +32,7 @@ logger = get_logger("generator_api")
|
||||
def get_replyer(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
"""获取回复器对象
|
||||
@@ -58,7 +59,7 @@ def get_replyer(
|
||||
return replyer_manager.get_replyer(
|
||||
chat_stream=chat_stream,
|
||||
chat_id=chat_id,
|
||||
model_configs=model_configs,
|
||||
model_set_with_weight=model_set_with_weight,
|
||||
request_type=request_type,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -83,7 +84,7 @@ async def generate_reply(
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
return_prompt: bool = False,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "generator_api",
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
"""生成回复
|
||||
@@ -106,7 +107,7 @@ async def generate_reply(
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
@@ -154,7 +155,7 @@ async def rewrite_reply(
|
||||
chat_id: Optional[str] = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
raw_reply: str = "",
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
@@ -179,7 +180,7 @@ async def rewrite_reply(
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
@@ -245,17 +246,17 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
|
||||
async def generate_response_custom(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.debug("[GeneratorAPI] 开始生成自定义回复")
|
||||
response = await replyer.llm_generate_content(prompt)
|
||||
response, _, _, _ = await replyer.llm_generate_content(prompt)
|
||||
if response:
|
||||
logger.debug("[GeneratorAPI] 自定义回复生成成功")
|
||||
return response
|
||||
|
||||
@@ -7,10 +7,11 @@
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
"""

from typing import Tuple, Dict, Any
from typing import Tuple, Dict
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.config.api_ada_configs import TaskConfig

logger = get_logger("llm_api")

@@ -19,9 +20,7 @@ logger = get_logger("llm_api")
# =============================================================================

def get_available_models() -> Dict[str, Any]:
def get_available_models() -> Dict[str, TaskConfig]:
"""获取所有可用的模型配置

Returns:
@@ -33,14 +32,14 @@ def get_available_models() -> Dict[str, Any]:
return {}

# 自动获取所有属性并转换为字典形式
rets = {}
models = global_config.model
models = model_config.model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
for attr in attrs:
if not attr.startswith("__"):
try:
value = getattr(models, attr)
if not callable(value): # 排除方法
if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value
except Exception as e:
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
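Because get_available_models() now returns TaskConfig objects keyed by task name, plugin code can no longer treat the entries as plain dicts. A short usage sketch (the llm_api import path is assumed; model_list is the attribute this same diff reads in generate_with_model below):

from src.plugin_system.apis import llm_api  # assumed import path

models = llm_api.get_available_models()  # Dict[str, TaskConfig]
for task_name, task_cfg in models.items():
    # Attribute access replaces dict-style access such as cfg.get("name").
    print(task_name, task_cfg.model_list)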
@@ -53,8 +52,8 @@ def get_available_models() -> Dict[str, Any]:

async def generate_with_model(
prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str]:
prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容

Args:
@@ -67,17 +66,16 @@ async def generate_with_model(
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
model_name = model_config.get("name")
logger.info(f"[LLMAPI] 使用模型 {model_name} 生成内容")
model_name_list = model_config.model_list
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")

llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
llm_request = LLMRequest(model_set=model_config, request_type=request_type, **kwargs)

# TODO: 复活这个_
response, _ = await llm_request.generate_response_async(prompt)
return True, response
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt)
return True, response, reasoning_content, model_name

except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg
return False, error_msg, "", ""
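Existing callers of generate_with_model therefore need two adjustments: pass a TaskConfig (for example one taken from get_available_models()) and unpack four return values instead of two. A hedged calling sketch; the "utils_small" key mirrors the emoji action change further below and is otherwise an assumption:

from src.plugin_system.apis import llm_api  # assumed import path

async def ask_utils_model(prompt: str) -> str:
    models = llm_api.get_available_models()
    task_cfg = models.get("utils_small")  # assumed task key
    if not task_cfg:
        return ""
    # Old shape: success, response = await llm_api.generate_with_model(...)
    success, response, reasoning, model_name = await llm_api.generate_with_model(
        prompt, model_config=task_cfg, request_type="plugin.generate"
    )
    return response if success else ""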
@@ -335,7 +335,7 @@ async def command_to_stream(

async def custom_to_stream(
message_type: str,
content: str,
content: str | dict,
stream_id: str,
display_message: str = "",
typing: bool = False,
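custom_to_stream now also accepts a dict payload for content. A hedged sketch; the send_api module path and the dict keys are purely illustrative, since the accepted structure depends on the message_type and the downstream adapter, which this hunk does not define:

from src.plugin_system.apis import send_api  # assumed import path

async def notify(stream_id: str):
    # Plain text works as before.
    await send_api.custom_to_stream("text", "hello", stream_id)
    # A dict payload is now type-legal as well; the keys below are hypothetical.
    await send_api.custom_to_stream(
        "my_custom_type",
        {"title": "demo", "body": "structured content"},
        stream_id,
        display_message="sent a structured message",
    )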
@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import process_llm_tool_calls
from src.chat.message_receive.chat_stream import get_chat_manager
@@ -52,10 +52,7 @@ class ToolExecutor:
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"

self.llm_model = LLMRequest(
model=global_config.model.tool_use,
request_type="tool_executor",
)
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")

# 缓存配置
self.enable_cache = enable_cache
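Plugin components that construct their own LLMRequest follow the same migration as ToolExecutor above: model= becomes model_set=, and the task configuration is read from model_config.model_task_config rather than global_config.model. A sketch; the class name, request_type string, and the unused fields of the details tuple are illustrative:

from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest

class MyToolRunner:
    def __init__(self):
        # Old form: LLMRequest(model=global_config.model.tool_use, request_type=...)
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.tool_use,
            request_type="my_plugin.tool",
        )

    async def run(self, prompt: str) -> str:
        # generate_response_async returns the text plus a details tuple
        # (reasoning content, model name, ...), as unpacked in generate_with_model above.
        response, (_reasoning, _model_name, _) = await self.llm_model.generate_response_async(prompt)
        return response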
@@ -58,6 +58,7 @@ class EmojiAction(BaseAction):
associated_types = ["emoji"]

async def execute(self) -> Tuple[bool, str]:
# sourcery skip: assign-if-exp, introduce-default-else, swap-if-else-branches, use-named-expression
"""执行表情动作"""
logger.info(f"{self.log_prefix} 决定发送表情")

@@ -120,7 +121,7 @@ class EmojiAction(BaseAction):
logger.error(f"{self.log_prefix} 未找到'utils_small'模型配置,无法调用LLM")
return False, "未找到'utils_small'模型配置"

success, chosen_emotion = await llm_api.generate_with_model(
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="emoji"
)