Adjust the corresponding call sites

UnCLAS-Prommer
2025-07-30 17:07:55 +08:00
parent 3c40ceda4c
commit 6c0edd0ad7
40 changed files with 580 additions and 1236 deletions

View File

@@ -8,15 +8,15 @@ import traceback
import io
import re
import binascii
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
from src.llm_models.utils_model import LLMRequest
@@ -379,9 +379,9 @@ class EmojiManager:
self._scan_task = None
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
self.llm_emotion_judge = LLMRequest(
model=global_config.model.utils, max_tokens=600, request_type="emoji"
model_set=model_config.model_task_config.utils, request_type="emoji"
) # 更高的温度，更少的token，后续可以根据情绪来调整温度
self.emoji_num = 0
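
The pair of replacements above is the template for the whole commit: the model is now bound through a task-level model set, and sampling parameters (temperature, max_tokens) move out of the LLMRequest constructor to the individual generate_* calls. A minimal before/after sketch, assuming only the signatures visible in this diff:

from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest

# Before: sampling parameters were frozen at construction time.
vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")

# After: the constructor only selects a task-level model set; temperature and
# max_tokens now travel with each request (see the generate_response_for_image
# calls further down in this file).
vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")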
@@ -492,6 +492,7 @@ class EmojiManager:
return None
def _levenshtein_distance(self, s1: str, s2: str) -> int:
# sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison
"""计算两个字符串的编辑距离
Args:
@@ -629,11 +630,11 @@ class EmojiManager:
if success:
# 注册成功则跳出循环
break
else:
# 注册失败则删除对应文件
file_path = os.path.join(EMOJI_DIR, filename)
os.remove(file_path)
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
# 注册失败则删除对应文件
file_path = os.path.join(EMOJI_DIR, filename)
os.remove(file_path)
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
@@ -694,6 +695,7 @@ class EmojiManager:
return []
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
# sourcery skip: use-next
"""从内存中的 emoji_objects 列表获取表情包
参数:
@@ -709,10 +711,10 @@ class EmojiManager:
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
"""根据哈希值获取已注册表情包的描述
Args:
emoji_hash: 表情包的哈希值
Returns:
Optional[str]: 表情包描述，如果未找到则返回None
"""
@@ -722,7 +724,7 @@ class EmojiManager:
if emoji and emoji.description:
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
return emoji.description
# 如果内存中没有,从数据库查找
self._ensure_db()
try:
@@ -732,9 +734,9 @@ class EmojiManager:
return emoji_record.description
except Exception as e:
logger.error(f"从数据库查询表情包描述时出错: {e}")
return None
except Exception as e:
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
return None
@@ -779,6 +781,7 @@ class EmojiManager:
return False
async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
# sourcery skip: use-getitem-for-re-match-groups
"""替换一个表情包
Args:
@@ -820,7 +823,7 @@ class EmojiManager:
)
# 调用大模型进行决策
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8)
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8, max_tokens=600)
logger.info(f"[决策] 结果: {decision}")
# 解析决策结果
@@ -828,9 +831,7 @@ class EmojiManager:
logger.info("[决策] 不删除任何表情包")
return False
# 尝试从决策中提取表情包编号
match = re.search(r"删除编号(\d+)", decision)
if match:
if match := re.search(r"删除编号(\d+)", decision):
emoji_index = int(match.group(1)) - 1 # 转换为0-based索引
# 检查索引是否有效
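
Several hunks in this commit collapse the match-then-test idiom into an assignment expression (the := walrus operator, Python 3.8+). A self-contained illustration of the pattern used above:

import re

decision = "删除编号3"
# old shape: match = re.search(...); then test `if match:`
# new shape: bind and test in one step
if match := re.search(r"删除编号(\d+)", decision):
    emoji_index = int(match.group(1)) - 1  # 1-based LLM answer -> 0-based index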
@@ -889,6 +890,7 @@ class EmojiManager:
existing_description = None
try:
from src.common.database.database_model import Images
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
if existing_image and existing_image.description:
existing_description = existing_image.description
@@ -902,15 +904,21 @@ class EmojiManager:
logger.info("[优化] 复用已有的详细描述跳过VLM调用")
else:
logger.info("[VLM分析] 生成新的详细描述")
if image_format == "gif" or image_format == "GIF":
if image_format in ["gif", "GIF"]:
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
if not image_base64:
raise RuntimeError("GIF表情包转换失败")
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000
)
else:
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
prompt = (
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
)
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
)
# 审核表情包
if global_config.emoji.content_filtration:
@@ -922,7 +930,9 @@ class EmojiManager:
4. 不要出现5个以上文字
请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
'''
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
content, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
)
if content == "":
return "", []
@@ -933,7 +943,9 @@ class EmojiManager:
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
"""
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7)
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
emotion_prompt, temperature=0.7, max_tokens=600
)
# 处理情感列表
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
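
The final post-processing step splits the model's comma-separated emotion list and drops empty items; a quick worked instance of that comprehension:

emotions_text = "无语, 嘲讽, ,装傻 "
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
# -> ['无语', '嘲讽', '装傻']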

View File

@@ -7,12 +7,12 @@ from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import model_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.database.database_model import Expression
MAX_EXPRESSION_COUNT = 300
@@ -80,11 +80,8 @@ def init_prompt() -> None:
class ExpressionLearner:
def __init__(self) -> None:
# TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest(
model=global_config.model.replyer_1,
temperature=0.3,
request_type="expressor.learner",
model_set=model_config.model_task_config.replyer_1, request_type="expressor.learner"
)
self.llm_model = None
self._ensure_expression_directories()
@@ -101,7 +98,7 @@ class ExpressionLearner:
os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"),
]
for directory in directories_to_create:
try:
os.makedirs(directory, exist_ok=True)
@@ -116,7 +113,7 @@ class ExpressionLearner:
"""
base_dir = os.path.join("data", "expression")
done_flag = os.path.join(base_dir, "done.done")
# 确保基础目录存在
try:
os.makedirs(base_dir, exist_ok=True)
@@ -124,28 +121,28 @@ class ExpressionLearner:
except Exception as e:
logger.error(f"创建表达方式目录失败: {e}")
return
if os.path.exists(done_flag):
logger.info("表达方式JSON已迁移无需重复迁移。")
return
logger.info("开始迁移表达方式JSON到数据库...")
migrated_count = 0
for type in ["learnt_style", "learnt_grammar"]:
type_str = "style" if type == "learnt_style" else "grammar"
type_dir = os.path.join(base_dir, type)
if not os.path.exists(type_dir):
logger.debug(f"目录不存在,跳过: {type_dir}")
continue
try:
chat_ids = os.listdir(type_dir)
logger.debug(f"{type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
except Exception as e:
logger.error(f"读取目录失败 {type_dir}: {e}")
continue
for chat_id in chat_ids:
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
if not os.path.exists(expr_file):
@@ -153,24 +150,24 @@ class ExpressionLearner:
try:
with open(expr_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
continue
for expr in expressions:
if not isinstance(expr, dict):
continue
situation = expr.get("situation")
style_val = expr.get("style")
count = expr.get("count", 1)
last_active_time = expr.get("last_active_time", time.time())
if not situation or not style_val:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
# 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression
@@ -201,7 +198,7 @@ class ExpressionLearner:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
@@ -209,7 +206,7 @@ class ExpressionLearner:
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
@@ -229,13 +226,13 @@ class ExpressionLearner:
# 查找所有create_date为空的表达方式
old_expressions = Expression.select().where(Expression.create_date.is_null())
updated_count = 0
for expr in old_expressions:
# 使用last_active_time作为create_date
expr.create_date = expr.last_active_time
expr.save()
updated_count += 1
if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e:
@@ -287,25 +284,29 @@ class ExpressionLearner:
获取指定chat_id的表达方式创建信息，按创建日期排序
"""
try:
expressions = (Expression.select()
.where(Expression.chat_id == chat_id)
.order_by(Expression.create_date.desc())
.limit(limit))
expressions = (
Expression.select()
.where(Expression.chat_id == chat_id)
.order_by(Expression.create_date.desc())
.limit(limit)
)
result = []
for expr in expressions:
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
result.append({
"situation": expr.situation,
"style": expr.style,
"type": expr.type,
"count": expr.count,
"create_date": create_date,
"create_date_formatted": format_create_date(create_date),
"last_active_time": expr.last_active_time,
"last_active_formatted": format_create_date(expr.last_active_time),
})
result.append(
{
"situation": expr.situation,
"style": expr.style,
"type": expr.type,
"count": expr.count,
"create_date": create_date,
"create_date_formatted": format_create_date(create_date),
"last_active_time": expr.last_active_time,
"last_active_formatted": format_create_date(expr.last_active_time),
}
)
return result
except Exception as e:
logger.error(f"获取表达方式创建信息失败: {e}")
@@ -355,19 +356,19 @@ class ExpressionLearner:
try:
# 获取所有表达方式
all_expressions = Expression.select()
updated_count = 0
deleted_count = 0
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
# 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01:
# 如果count太小，删除这个表达方式
expr.delete_instance()
@@ -377,10 +378,10 @@ class ExpressionLearner:
expr.count = new_count
expr.save()
updated_count += 1
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
logger.error(f"数据库全局衰减失败: {e}")
@@ -527,7 +528,7 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的prompt: {prompt}")
try:
response, _ = await self.express_learn_model.generate_response_async(prompt)
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习{type_str}失败: {e}")
return None

View File

@@ -1,16 +1,17 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Tuple, Optional, Any
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .expression_learner import get_expression_learner
from src.common.database.database_model import Expression
logger = get_logger("expression_selector")
@@ -75,10 +76,8 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis
class ExpressionSelector:
def __init__(self):
self.expression_learner = get_expression_learner()
# TODO: API-Adapter修改标记
self.llm_model = LLMRequest(
model=global_config.model.utils_small,
request_type="expression.selector",
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
)
@staticmethod
@@ -92,7 +91,6 @@ class ExpressionSelector:
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
import hashlib
if is_group:
components = [platform, str(id_str)]
else:
@@ -108,8 +106,7 @@ class ExpressionSelector:
for group in groups:
group_chat_ids = []
for stream_config_str in group:
chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str)
if chat_id_candidate:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids
@@ -118,9 +115,10 @@ class ExpressionSelector:
def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化：一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
@@ -128,7 +126,7 @@ class ExpressionSelector:
grammar_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
)
style_exprs = [
{
"situation": expr.situation,
@@ -138,9 +136,10 @@ class ExpressionSelector:
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in style_query
}
for expr in style_query
]
grammar_exprs = [
{
"situation": expr.situation,
@@ -150,9 +149,10 @@ class ExpressionSelector:
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in grammar_query
}
for expr in grammar_query
]
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
# 按权重抽样，使用count作为权重
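
Sampling is weighted by each expression's count, so frequently reinforced expressions surface more often. The file's own weighted_sample helper (its signature is visible in a hunk header above) samples without replacement; random.choices stands in for it in this simplified sketch:

import random

style_exprs = [
    {"situation": "被调侃", "style": "阴阳怪气地回敬", "count": 5.0},
    {"situation": "被夸奖", "style": "害羞装高冷", "count": 1.0},
]
weights = [e["count"] for e in style_exprs]  # count doubles as the sampling weight
picked = random.choices(style_exprs, weights=weights, k=1)[0]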
@@ -174,22 +174,22 @@ class ExpressionSelector:
return
updates_by_key = {}
for expr in expressions_to_update:
source_id = expr.get("source_id")
expr_type = expr.get("type", "style")
situation = expr.get("situation")
style = expr.get("style")
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
situation: str = expr.get("situation") # type: ignore
style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, expr_type, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
for (chat_id, expr_type, situation, style), _expr in updates_by_key.items():
for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id) &
(Expression.type == expr_type) &
(Expression.situation == situation) &
(Expression.style == style)
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
@@ -264,7 +264,7 @@ class ExpressionSelector:
# 4. 调用LLM
try:
content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix} LLM返回结果: {content}")

View File

@@ -5,25 +5,27 @@ import random
import time
import re
import json
from itertools import combinations
import jieba
import networkx as nx
import numpy as np
from itertools import combinations
from typing import List, Tuple, Coroutine, Any, Dict, Set
from collections import Counter
from ...llm_models.utils_model import LLMRequest
from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from src.common.logger import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from ..utils.chat_message_builder import (
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
build_readable_messages,
get_raw_msg_by_timestamp_with_chat,
) # 导入 build_readable_messages
from ..utils.utils import translate_timestamp_to_human_readable
from rich.traceback import install
from src.chat.utils.utils import translate_timestamp_to_human_readable
from ...config.config import global_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
install(extra_lines=3)
@@ -198,8 +200,7 @@ class Hippocampus:
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
# TODO: API-Adapter修改标记
self.model_summary = LLMRequest(global_config.model.memory, request_type="memory.builder")
self.model_summary = LLMRequest(model_set=model_config.model_task_config.memory, request_type="memory.builder")
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
@@ -339,9 +340,7 @@ class Hippocampus:
else:
topic_num = 5 # 51+字符: 5个关键词 (其余长文本)
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
self.find_topic_llm(text, topic_num)
)
topics_response, _ = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)
@@ -353,12 +352,11 @@ class Hippocampus:
for keyword in ",".join(keywords).replace("", ",").replace("", ",").replace(" ", ",").split(",")
if keyword.strip()
]
if keywords:
logger.info(f"提取关键词: {keywords}")
return keywords
return keywords
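
A worked pass through the keyword pipeline above: the model is asked to wrap topics in angle brackets, re.findall pulls them out, and the join/replace/split chain normalizes mixed CJK and ASCII separators:

import re

topics_response = "<考试，熬夜> <星期五>"
keywords = re.findall(r"<([^>]+)>", topics_response)  # ['考试，熬夜', '星期五']
keywords = [
    keyword.strip()
    for keyword in ",".join(keywords).replace("，", ",").replace("、", ",").replace(" ", ",").split(",")
    if keyword.strip()
]
# -> ['考试', '熬夜', '星期五']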
async def get_memory_from_text(
self,
@@ -1245,7 +1243,7 @@ class ParahippocampalGyrus:
# 2. 使用LLM提取关键主题
topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
topics_response, (reasoning_content, model_name) = await self.hippocampus.model_summary.generate_response_async(
topics_response, _ = await self.hippocampus.model_summary.generate_response_async(
self.hippocampus.find_topic_llm(input_text, topic_num)
)
@@ -1269,7 +1267,7 @@ class ParahippocampalGyrus:
logger.debug(f"过滤后话题: {filtered_topics}")
# 4. 创建所有话题的摘要生成任务
tasks = []
tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List[Dict[str, Any]] | None]]]]] = []
for topic in filtered_topics:
# 调用修改后的 topic_what不再需要 time_info
topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
@@ -1281,7 +1279,7 @@ class ParahippocampalGyrus:
continue
# 等待所有任务完成
compressed_memory = set()
compressed_memory: Set[Tuple[str, str]] = set()
similar_topics_dict = {}
for topic, task in tasks:

View File

@@ -3,13 +3,16 @@ import time
import re
import json
import ast
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
import traceback
from src.config.config import global_config
from json_repair import repair_json
from datetime import datetime, timedelta
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.common.database.database_model import Memory # Peewee Models导入
from src.config.config import model_config
logger = get_logger(__name__)
@@ -35,8 +38,7 @@ class InstantMemory:
self.chat_id = chat_id
self.last_view_time = time.time()
self.summary_model = LLMRequest(
model=global_config.model.memory,
temperature=0.5,
model_set=model_config.model_task_config.memory,
request_type="memory.summary",
)
@@ -48,14 +50,11 @@ class InstantMemory:
"""
try:
response, _ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt)
print(response)
if "1" in response:
return True
else:
return False
return "1" in response
except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False
@@ -71,9 +70,9 @@ class InstantMemory:
}}
"""
try:
response, _ = await self.summary_model.generate_response_async(prompt)
print(prompt)
print(response)
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
# print(prompt)
# print(response)
if not response:
return None
try:
@@ -142,7 +141,7 @@ class InstantMemory:
请只输出json格式，不要输出其他多余内容
"""
try:
response, _ = await self.summary_model.generate_response_async(prompt)
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt)
print(response)
if not response:
@@ -177,7 +176,7 @@ class InstantMemory:
for mem in query:
# 对每条记忆
mem_keywords = mem.keywords or []
mem_keywords = mem.keywords or ""
parsed = ast.literal_eval(mem_keywords)
if isinstance(parsed, list):
mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
@@ -201,6 +200,7 @@ class InstantMemory:
return None
def _parse_time_range(self, time_str):
# sourcery skip: extract-duplicate-method, use-contextlib-suppress
"""
支持解析如下格式:
- 具体日期时间YYYY-MM-DD HH:MM:SS
@@ -208,8 +208,6 @@ class InstantMemory:
- 相对时间今天昨天前天N天前N个月前
- 空字符串:返回(None, None)
"""
from datetime import datetime, timedelta
now = datetime.now()
if not time_str:
return 0, now
@@ -239,14 +237,12 @@ class InstantMemory:
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
m = re.match(r"(\d+)天前", time_str)
if m:
if m := re.match(r"(\d+)天前", time_str):
days = int(m.group(1))
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
m = re.match(r"(\d+)个月前", time_str)
if m:
if m := re.match(r"(\d+)个月前", time_str):
months = int(m.group(1))
# 近似每月30天
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
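
A runnable miniature of the relative-time branches above, again in the walrus form this commit prefers:

import re
from datetime import datetime, timedelta

now = datetime.now()
if m := re.match(r"(\d+)天前", "3天前"):
    days = int(m.group(1))
    # midnight of that day up to midnight of the next: a one-day window
    start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
    end = start + timedelta(days=1)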

View File

@@ -1,13 +1,15 @@
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from datetime import datetime
from src.chat.memory_system.Hippocampus import hippocampus_manager
from typing import List, Dict
import difflib
import json
from json_repair import repair_json
from typing import List, Dict
from datetime import datetime
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager
logger = get_logger("memory_activator")
@@ -61,11 +63,8 @@ def init_prompt():
class MemoryActivator:
def __init__(self):
# TODO: API-Adapter修改标记
self.key_words_model = LLMRequest(
model=global_config.model.utils_small,
temperature=0.5,
model_set=model_config.model_task_config.utils_small,
request_type="memory.activator",
)
@@ -92,7 +91,9 @@ class MemoryActivator:
# logger.debug(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await self.key_words_model.generate_response_async(prompt)
response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
prompt, temperature=0.5
)
keywords = list(get_keywords_from_json(response))

View File

@@ -203,7 +203,7 @@ class MessageRecvS4U(MessageRecv):
self.is_superchat = False
self.gift_info = None
self.gift_name = None
self.gift_count = None
self.gift_count: Optional[str] = None
self.superchat_info = None
self.superchat_price = None
self.superchat_message_text = None

View File

@@ -1,9 +1,10 @@
from typing import Dict, Optional, Type
from src.plugin_system.base.base_action import BaseAction
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType, ActionInfo
from src.plugin_system.base.base_action import BaseAction
logger = get_logger("action_manager")

View File

@@ -5,7 +5,7 @@ import time
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
from src.chat.planner_actions.action_manager import ActionManager
@@ -36,10 +36,7 @@ class ActionModifier:
self.action_manager = action_manager
# 用于LLM判定的小模型
self.llm_judge = LLMRequest(
model=global_config.model.utils_small,
request_type="action.judge",
)
self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge")
# 缓存相关属性
self._llm_judge_cache = {} # 缓存LLM判定结果
@@ -438,4 +435,4 @@ class ActionModifier:
return True
else:
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
return False
return False

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
@@ -73,10 +73,7 @@ class ActionPlanner:
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
self.action_manager = action_manager
# LLM规划器配置
self.planner_llm = LLMRequest(
model=global_config.model.planner,
request_type="planner", # 用于动作规划
)
self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner") # 用于动作规划
self.last_obs_time_mark = 0.0
@@ -140,7 +137,7 @@ class ActionPlanner:
# --- 调用 LLM (普通文本生成) ---
llm_content = None
try:
llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")

View File

@@ -8,7 +8,8 @@ from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.config.api_ada_configs import TaskConfig
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
@@ -106,31 +107,36 @@ class DefaultReplyer:
def __init__(
self,
chat_stream: ChatStream,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "focus.replyer",
):
self.request_type = request_type
if model_configs:
self.express_model_configs = model_configs
if model_set_with_weight:
# self.express_model_configs = model_configs
self.model_set: List[Tuple[TaskConfig, float]] = model_set_with_weight
else:
# 当未提供配置时,使用默认配置并赋予默认权重
model_config_1 = global_config.model.replyer_1.copy()
model_config_2 = global_config.model.replyer_2.copy()
# model_config_1 = global_config.model.replyer_1.copy()
# model_config_2 = global_config.model.replyer_2.copy()
prob_first = global_config.chat.replyer_random_probability
model_config_1["weight"] = prob_first
model_config_2["weight"] = 1.0 - prob_first
# model_config_1["weight"] = prob_first
# model_config_2["weight"] = 1.0 - prob_first
self.express_model_configs = [model_config_1, model_config_2]
# self.express_model_configs = [model_config_1, model_config_2]
self.model_set = [
(model_config.model_task_config.replyer_1, prob_first),
(model_config.model_task_config.replyer_2, 1.0 - prob_first),
]
if not self.express_model_configs:
logger.warning("未找到有效的模型配置,回复生成可能会失败。")
# 提供一个最终的回退,以防止在空列表上调用 random.choice
fallback_config = global_config.model.replyer_1.copy()
fallback_config.setdefault("weight", 1.0)
self.express_model_configs = [fallback_config]
# if not self.express_model_configs:
# logger.warning("未找到有效的模型配置,回复生成可能会失败。")
# # 提供一个最终的回退,以防止在空列表上调用 random.choice
# fallback_config = global_config.model.replyer_1.copy()
# fallback_config.setdefault("weight", 1.0)
# self.express_model_configs = [fallback_config]
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
@@ -139,14 +145,15 @@ class DefaultReplyer:
self.memory_activator = MemoryActivator()
self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id)
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor，不然会循环依赖
from src.plugin_system.core.tool_use import ToolExecutor  # 延迟导入ToolExecutor，不然会循环依赖
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
def _select_weighted_model_config(self) -> Dict[str, Any]:
def _select_weighted_models_config(self) -> Tuple[TaskConfig, float]:
"""使用加权随机选择来挑选一个模型配置"""
configs = self.express_model_configs
configs = self.model_set
# 提取权重,如果模型配置中没有'weight'键则默认为1.0
weights = [config.get("weight", 1.0) for config in configs]
weights = [weight for _, weight in configs]
return random.choices(population=configs, weights=weights, k=1)[0]
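
The selection logic reduces to one weighted draw over (TaskConfig, weight) pairs; strings stand in for TaskConfig objects in this sketch:

import random

model_set = [("replyer_1", 0.7), ("replyer_2", 0.3)]  # (task config, weight)
weights = [weight for _, weight in model_set]
task_config, weight = random.choices(population=model_set, weights=weights, k=1)[0]
# roughly 70% of replies draw replyer_1, 30% replyer_2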
@@ -188,12 +195,11 @@ class DefaultReplyer:
# 4. 调用 LLM 生成回复
content = None
# TODO: 复活这里
# reasoning_content = None
# model_name = "unknown_model"
reasoning_content = None
model_name = "unknown_model"
try:
content = await self.llm_generate_content(prompt)
content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
logger.debug(f"replyer生成内容: {content}")
except Exception as llm_e:
@@ -236,15 +242,14 @@ class DefaultReplyer:
)
content = None
# TODO: 复活这里
# reasoning_content = None
# model_name = "unknown_model"
reasoning_content = None
model_name = "unknown_model"
if not prompt:
logger.error("Prompt 构建失败,无法生成回复。")
return False, None, None
try:
content = await self.llm_generate_content(prompt)
content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
except Exception as llm_e:
@@ -843,7 +848,7 @@ class DefaultReplyer:
raw_reply: str,
reason: str,
reply_to: str,
) -> str:
) -> str: # sourcery skip: remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)
@@ -977,30 +982,23 @@ class DefaultReplyer:
display_message=display_message,
)
async def llm_generate_content(self, prompt: str) -> str:
async def llm_generate_content(self, prompt: str):
with Timer("LLM生成", {}): # 内部计时器,可选保留
# 加权随机选择一个模型配置
selected_model_config = self._select_weighted_model_config()
model_display_name = selected_model_config.get('model_name') or selected_model_config.get('name', 'N/A')
logger.info(
f"使用模型生成回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
)
selected_model_config, weight = self._select_weighted_models_config()
logger.info(f"使用模型集生成回复: {selected_model_config} (选中概率: {weight})")
express_model = LLMRequest(
model=selected_model_config,
request_type=self.request_type,
)
express_model = LLMRequest(model_set=selected_model_config, request_type=self.request_type)
if global_config.debug.show_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\n{prompt}\n")
# TODO: 这里的_应该做出替换
content, _ = await express_model.generate_response_async(prompt)
content, (reasoning_content, model_name, tool_calls) = await express_model.generate_response_async(prompt)
logger.debug(f"replyer生成内容: {content}")
return content
return content, reasoning_content, model_name, tool_calls
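
llm_generate_content now hands back four values instead of a bare string, matching the unpacking at the call sites earlier in this file; a sketch of the consuming side:

async def reply_once(replyer: "DefaultReplyer", prompt: str) -> str:
    # tool_calls is unused at the current call sites, hence the underscore
    content, reasoning_content, model_name, _ = await replyer.llm_generate_content(prompt)
    return content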
def weighted_sample_no_replacement(items, weights, k) -> list:

View File

@@ -1,6 +1,7 @@
from typing import Dict, Any, Optional, List
from typing import Dict, Optional, List, Tuple
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer
@@ -15,7 +16,7 @@ class ReplyerManager:
self,
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
"""
@@ -49,7 +50,7 @@ class ReplyerManager:
# model_configs 只在此时(初始化时)生效
replyer = DefaultReplyer(
chat_stream=target_stream,
model_configs=model_configs, # 可以是None此时使用默认模型
model_set_with_weight=model_set_with_weight, # 可以是None此时使用默认模型
request_type=request_type,
)
self._repliers[stream_id] = replyer

View File

@@ -11,7 +11,7 @@ from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.llm_models.utils_model import LLMRequest
@@ -109,13 +109,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
return is_mentioned, reply_probability
async def get_embedding(text, request_type="embedding"):
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
"""获取文本的embedding向量"""
# TODO: API-Adapter修改标记
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
# return llm.get_embedding_sync(text)
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
try:
embedding = await llm.get_embedding(text)
embedding, _ = await llm.get_embedding(text)
except Exception as e:
logger.error(f"获取embedding失败: {str(e)}")
embedding = None
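
get_embedding now advertises Optional[List[float]], so callers must guard against the None failure path before doing vector math; a hedged usage sketch:

import math

async def cosine_similarity(a: str, b: str) -> float:
    # get_embedding returns None when the request fails, so guard first
    va, vb = await get_embedding(a), await get_embedding(b)
    if va is None or vb is None:
        return 0.0
    dot = sum(x * y for x, y in zip(va, vb))
    norm = math.sqrt(sum(x * x for x in va)) * math.sqrt(sum(x * x for x in vb))
    return dot / norm if norm else 0.0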

View File

@@ -14,7 +14,7 @@ from rich.traceback import install
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
install(extra_lines=3)
@@ -37,7 +37,7 @@ class ImageManager:
self._ensure_image_dir()
self._initialized = True
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
try:
db.connect(reuse_if_open=True)
@@ -107,6 +107,7 @@ class ImageManager:
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
if cached_emoji_description:
@@ -116,13 +117,12 @@ class ImageManager:
logger.debug(f"查询EmojiManager时出错: {e}")
# 查询ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
if cached_description := self._get_description_from_db(image_hash, "emoji"):
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[表情包:{cached_description}]"
# === 二步走识别流程 ===
# 第一步VLM视觉分析 - 生成详细描述
if image_format in ["gif", "GIF"]:
image_base64_processed = self.transform_gif(image_base64)
@@ -130,10 +130,16 @@ class ImageManager:
logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]"
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
detailed_description, _ = await self.vlm.generate_response_for_image(
vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
)
else:
vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format)
vlm_prompt = (
"这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
)
detailed_description, _ = await self.vlm.generate_response_for_image(
vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if detailed_description is None:
logger.warning("VLM未能生成表情包详细描述")
@@ -150,31 +156,32 @@ class ImageManager:
3. 输出简短精准,不要解释
4. 如果有多个词用逗号分隔
"""
# 使用较低温度确保输出稳定
emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji")
emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt)
emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
emotion_result, _ = await emotion_llm.generate_response_async(
emotion_prompt, temperature=0.3, max_tokens=50
)
if emotion_result is None:
logger.warning("LLM未能生成情感标签使用详细描述的前几个词")
# 降级处理:从详细描述中提取关键词
import jieba
words = list(jieba.cut(detailed_description))
emotion_result = "".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
# 处理情感结果取前1-2个最重要的标签
emotions = [e.strip() for e in emotion_result.replace("，", ",").split(",") if e.strip()]
final_emotion = emotions[0] if emotions else "表情"
# 如果有第二个情感且不重复,也包含进来
if len(emotions) > 1 and emotions[1] != emotions[0]:
final_emotion = f"{emotions[0]}{emotions[1]}"
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
# 再次检查缓存,防止并发写入时重复生成
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
if cached_description := self._get_description_from_db(image_hash, "emoji"):
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
@@ -242,9 +249,7 @@ class ImageManager:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
# 查询ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[图片:{cached_description}]"
@@ -252,7 +257,9 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None:
logger.warning("AI未能生成图片描述")
@@ -445,10 +452,7 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查图片是否已存在
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
if existing_image:
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
@@ -524,9 +528,7 @@ class ImageManager:
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) &
(Images.description.is_null(False)) &
(Images.description != "")
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
)
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
@@ -538,8 +540,7 @@ class ImageManager:
return
# 检查ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
@@ -554,15 +555,15 @@ class ImageManager:
# 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None:
logger.warning("VLM未能生成图片描述")
description = "无法生成描述"
# 再次检查缓存,防止并发写入时重复生成
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
if cached_description := self._get_description_from_db(image_hash, "image"):
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description
@@ -606,7 +607,7 @@ def image_path_to_base64(image_path: str) -> str:
raise FileNotFoundError(f"图片文件不存在: {image_path}")
with open(image_path, "rb") as f:
image_data = f.read()
if not image_data:
if image_data := f.read():
return base64.b64encode(image_data).decode("utf-8")
else:
raise IOError(f"读取图片文件失败: {image_path}")
return base64.b64encode(image_data).decode("utf-8")

View File

@@ -1,6 +1,6 @@
import base64
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
@@ -20,7 +20,7 @@ async def get_voice_text(voice_base64: str) -> str:
if isinstance(voice_base64, str):
voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
voice_bytes = base64.b64decode(voice_base64)
_llm = LLMRequest(model=global_config.model.voice, request_type="voice")
_llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="voice")
text = await _llm.generate_response_for_voice(voice_bytes)
if text is None:
logger.warning("未能生成语音文本")

View File

@@ -19,13 +19,13 @@ Mxp 模式:梦溪畔独家赞助
下下策是询问一个菜鸟(@梦溪畔)
"""
from .willing_manager import BaseWillingManager
from typing import Dict
import asyncio
import time
import math
from src.chat.message_receive.chat_stream import ChatStream
from .willing_manager import BaseWillingManager
class MxpWillingManager(BaseWillingManager):

View File

@@ -60,268 +60,6 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
MMC_VERSION = "0.10.0-snapshot.2"
# def _get_config_version(toml: Dict) -> Version:
# """提取配置文件的 SpecifierSet 版本数据
# Args:
# toml[dict]: 输入的配置文件字典
# Returns:
# Version
# """
# if "inner" in toml and "version" in toml["inner"]:
# config_version: str = toml["inner"]["version"]
# else:
# raise InvalidVersion("配置文件缺少版本信息,请检查配置文件。")
# try:
# return version.parse(config_version)
# except InvalidVersion as e:
# logger.error(
# "配置文件中 inner段 的 version 键是错误的版本描述\n"
# f"请检查配置文件,当前 version 键: {config_version}\n"
# f"错误信息: {e}"
# )
# raise e
# def _request_conf(parent: Dict, config: ModuleConfig):
# request_conf_config = parent.get("request_conf")
# config.req_conf.max_retry = request_conf_config.get(
# "max_retry", config.req_conf.max_retry
# )
# config.req_conf.timeout = request_conf_config.get(
# "timeout", config.req_conf.timeout
# )
# config.req_conf.retry_interval = request_conf_config.get(
# "retry_interval", config.req_conf.retry_interval
# )
# config.req_conf.default_temperature = request_conf_config.get(
# "default_temperature", config.req_conf.default_temperature
# )
# config.req_conf.default_max_tokens = request_conf_config.get(
# "default_max_tokens", config.req_conf.default_max_tokens
# )
# def _api_providers(parent: Dict, config: ModuleConfig):
# api_providers_config = parent.get("api_providers")
# for provider in api_providers_config:
# name = provider.get("name", None)
# base_url = provider.get("base_url", None)
# api_key = provider.get("api_key", None)
# api_keys = provider.get("api_keys", []) # 新增支持多个API Key
# client_type = provider.get("client_type", "openai")
# if name in config.api_providers: # 查重
# logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
# raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
# if name and base_url:
# # 处理API Key配置支持单个api_key或多个api_keys
# if api_keys:
# # 使用新格式api_keys列表
# logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key")
# elif api_key:
# # 向后兼容使用单个api_key
# api_keys = [api_key]
# logger.debug(f"API提供商 '{name}' 使用单个API Key向后兼容模式")
# else:
# logger.warning(f"API提供商 '{name}' 没有配置API Key某些功能可能不可用")
# config.api_providers[name] = APIProvider(
# name=name,
# base_url=base_url,
# api_key=api_key, # 保留向后兼容
# api_keys=api_keys, # 新格式
# client_type=client_type,
# )
# else:
# logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
# raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
# def _models(parent: Dict, config: ModuleConfig):
# models_config = parent.get("models")
# for model in models_config:
# model_identifier = model.get("model_identifier", None)
# name = model.get("name", model_identifier)
# api_provider = model.get("api_provider", None)
# price_in = model.get("price_in", 0.0)
# price_out = model.get("price_out", 0.0)
# force_stream_mode = model.get("force_stream_mode", False)
# task_type = model.get("task_type", "")
# capabilities = model.get("capabilities", [])
# if name in config.models: # 查重
# logger.error(f"重复的模型名称: {name},请检查配置文件。")
# raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
# if model_identifier and api_provider:
# # 检查API提供商是否存在
# if api_provider not in config.api_providers:
# logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
# raise ValueError(
# f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
# )
# config.models[name] = ModelInfo(
# name=name,
# model_identifier=model_identifier,
# api_provider=api_provider,
# price_in=price_in,
# price_out=price_out,
# force_stream_mode=force_stream_mode,
# task_type=task_type,
# capabilities=capabilities,
# )
# else:
# logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
# raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
# def _task_model_usage(parent: Dict, config: ModuleConfig):
# model_usage_configs = parent.get("task_model_usage")
# config.task_model_arg_map = {}
# for task_name, item in model_usage_configs.items():
# if task_name in config.task_model_arg_map:
# logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
# raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
# usage = []
# if isinstance(item, Dict):
# if "model" in item:
# usage.append(
# ModelUsageArgConfigItem(
# name=item["model"],
# temperature=item.get("temperature", None),
# max_tokens=item.get("max_tokens", None),
# max_retry=item.get("max_retry", None),
# )
# )
# else:
# logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
# raise ValueError(
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
# )
# elif isinstance(item, List):
# for model in item:
# if isinstance(model, Dict):
# usage.append(
# ModelUsageArgConfigItem(
# name=model["model"],
# temperature=model.get("temperature", None),
# max_tokens=model.get("max_tokens", None),
# max_retry=model.get("max_retry", None),
# )
# )
# elif isinstance(model, str):
# usage.append(
# ModelUsageArgConfigItem(
# name=model,
# temperature=None,
# max_tokens=None,
# max_retry=None,
# )
# )
# else:
# logger.error(
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
# )
# raise ValueError(
# f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
# )
# elif isinstance(item, str):
# usage.append(
# ModelUsageArgConfigItem(
# name=item,
# temperature=None,
# max_tokens=None,
# max_retry=None,
# )
# )
# config.task_model_arg_map[task_name] = ModelUsageArgConfig(
# name=task_name,
# usage=usage,
# )
# def api_ada_load_config(config_path: str) -> ModuleConfig:
# """从TOML配置文件加载配置"""
# config = ModuleConfig()
# include_configs: Dict[str, Dict[str, Any]] = {
# "request_conf": {
# "func": _request_conf,
# "support": ">=0.0.0",
# "necessary": False,
# },
# "api_providers": {"func": _api_providers, "support": ">=0.0.0"},
# "models": {"func": _models, "support": ">=0.0.0"},
# "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
# }
# if os.path.exists(config_path):
# with open(config_path, "rb") as f:
# try:
# toml_dict = tomlkit.load(f)
# except tomlkit.TOMLDecodeError as e:
# logger.critical(
# f"配置文件model_list.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}"
# )
# exit(1)
# # 获取配置文件版本
# config.INNER_VERSION = _get_config_version(toml_dict)
# # 检查版本
# if config.INNER_VERSION > Version(NEWEST_VER):
# logger.warning(
# f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
# )
# # 解析配置文件
# # 如果在配置中找到了需要的项,调用对应项的闭包函数处理
# for key in include_configs:
# if key in toml_dict:
# group_specifier_set: SpecifierSet = SpecifierSet(
# include_configs[key]["support"]
# )
# # 检查配置文件版本是否在支持范围内
# if config.INNER_VERSION in group_specifier_set:
# # 如果版本在支持范围内,检查是否存在通知
# if "notice" in include_configs[key]:
# logger.warning(include_configs[key]["notice"])
# # 调用闭包函数处理配置
# (include_configs[key]["func"])(toml_dict, config)
# else:
# # 如果版本不在支持范围内,崩溃并提示用户
# logger.error(
# f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
# f"当前程序仅支持以下版本范围: {group_specifier_set}"
# )
# raise InvalidVersion(
# f"当前程序仅支持以下版本范围: {group_specifier_set}"
# )
# # 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
# elif (
# "necessary" in include_configs[key]
# and include_configs[key].get("necessary") is False
# ):
# # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
# if key == "keywords_reaction":
# pass
# else:
# # 如果用户根本没有需要的配置项,提示缺少配置
# logger.error(f"配置文件中缺少必需的字段: '{key}'")
# raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
# logger.info(f"成功加载配置文件: {config_path}")
# return config
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
@@ -626,9 +364,19 @@ class APIAdapterConfig(ConfigBase):
"""API提供商列表"""
def __post_init__(self):
# 检查API提供商名称是否重复
provider_names = [provider.name for provider in self.api_providers]
if len(provider_names) != len(set(provider_names)):
raise ValueError("API提供商名称存在重复请检查配置文件。")
# 检查模型名称是否重复
model_names = [model.name for model in self.models]
if len(model_names) != len(set(model_names)):
raise ValueError("模型名称存在重复,请检查配置文件。")
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
def get_model_info(self, model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息"""
if not model_name:
@@ -636,7 +384,7 @@ class APIAdapterConfig(ConfigBase):
if model_name not in self.models_dict:
raise KeyError(f"模型 '{model_name}' 不存在")
return self.models_dict[model_name]
def get_provider(self, provider_name: str) -> APIProvider:
"""根据提供商名称获取API提供商信息"""
if not provider_name:
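
The new __post_init__ rejects duplicate names with a plain length-versus-set comparison before building the lookup dicts that get_model_info and get_provider use; the core of that check in isolation:

model_names = ["gpt-x", "qwen-y", "gpt-x"]  # illustrative names
if len(model_names) != len(set(model_names)):  # the set silently drops the duplicate
    raise ValueError("模型名称存在重复,请检查配置文件。")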

View File

@@ -4,7 +4,7 @@ import hashlib
import time
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import get_person_info_manager
from rich.traceback import install
@@ -23,10 +23,7 @@ class Individuality:
self.meta_info_file_path = "data/personality/meta.json"
self.personality_data_file_path = "data/personality/personality_data.json"
self.model = LLMRequest(
model=global_config.model.utils,
request_type="individuality.compress",
)
self.model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="individuality.compress")
async def initialize(self) -> None:
"""初始化个体特征"""
@@ -35,7 +32,6 @@ class Individuality:
personality_side = global_config.personality.personality_side
identity = global_config.personality.identity
person_info_manager = get_person_info_manager()
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
self.name = bot_nickname
@@ -85,16 +81,16 @@ class Individuality:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
# 从文件获取 short_impression
personality, identity = self._get_personality_from_file()
# 确保short_impression是列表格式且有足够的元素
if not personality or not identity:
logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值")
personality = "友好活泼"
identity = "人类"
prompt_personality = f"{personality}\n{identity}"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
@@ -215,7 +211,7 @@ class Individuality:
def _get_personality_from_file(self) -> tuple[str, str]:
"""从文件获取personality数据
Returns:
tuple: (personality, identity)
"""
@@ -226,7 +222,7 @@ class Individuality:
def _save_personality_to_file(self, personality: str, identity: str):
"""保存personality数据到文件
Args:
personality: 压缩后的人格描述
identity: 压缩后的身份描述
@@ -235,7 +231,7 @@ class Individuality:
"personality": personality,
"identity": identity,
"bot_nickname": self.name,
"last_updated": int(time.time())
"last_updated": int(time.time()),
}
self._save_personality_data(personality_data)
@@ -269,7 +265,7 @@ class Individuality:
2. 尽量简洁不超过30字
3. 直接输出压缩后的内容,不要解释"""
response, (_, _) = await self.model.generate_response_async(
response, _ = await self.model.generate_response_async(
prompt=prompt,
)
@@ -281,7 +277,7 @@ class Individuality:
# 压缩失败时使用原始内容
if personality_side:
personality_parts.append(personality_side)
if personality_parts:
personality_result = "".join(personality_parts)
else:
@@ -308,7 +304,7 @@ class Individuality:
2. 尽量简洁不超过30字
3. 直接输出压缩后的内容,不要解释"""
response, (_, _) = await self.model.generate_response_async(
response, _ = await self.model.generate_response_async(
prompt=prompt,
)

View File

@@ -1,12 +0,0 @@
import importlib
from typing import Dict
from src.config.config import model_config
from src.common.logger import get_logger
from .model_client import ModelRequestHandler, BaseClient
logger = get_logger("模型管理器")
class ModelManager:

View File

@@ -1,92 +0,0 @@
import importlib
from typing import Dict
from src.config.config import model_config
from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig
from src.common.logger import get_logger
from .model_client import ModelRequestHandler, BaseClient
logger = get_logger("模型管理器")
class ModelManager:
# TODO: 添加读写锁,防止异步刷新配置时发生数据竞争
def __init__(
self,
config: ModuleConfig,
):
self.config: ModuleConfig = config
"""配置信息"""
self.api_client_map: Dict[str, BaseClient] = {}
"""API客户端映射表"""
self._request_handler_cache: Dict[str, ModelRequestHandler] = {}
"""ModelRequestHandler缓存避免重复创建"""
for provider_name, api_provider in self.config.api_providers.items():
# 初始化API客户端
try:
# 根据配置动态加载实现
client_module = importlib.import_module(
f".model_client.{api_provider.client_type}_client", __package__
)
client_class = getattr(
client_module, f"{api_provider.client_type.capitalize()}Client"
)
if not issubclass(client_class, BaseClient):
raise TypeError(
f"'{client_class.__name__}' is not a subclass of 'BaseClient'"
)
self.api_client_map[api_provider.name] = client_class(
api_provider
) # 实例化放入api_client_map
except ImportError as e:
logger.error(f"Failed to import client module: {e}")
raise ImportError(
f"Failed to import client module for '{provider_name}': {e}"
) from e
def __getitem__(self, task_name: str) -> ModelRequestHandler:
"""
获取任务所需的模型客户端(封装)
使用缓存机制避免重复创建ModelRequestHandler
:param task_name: 任务名称
:return: 模型客户端
"""
if task_name not in self.config.task_model_arg_map:
raise KeyError(f"'{task_name}' not registered in ModelManager")
# 检查缓存中是否已存在
if task_name in self._request_handler_cache:
logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}")
return self._request_handler_cache[task_name]
# 创建新的ModelRequestHandler并缓存
logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}")
handler = ModelRequestHandler(
task_name=task_name,
config=self.config,
api_client_map=self.api_client_map,
)
self._request_handler_cache[task_name] = handler
return handler
def __setitem__(self, task_name: str, value: ModelUsageArgConfig):
"""
注册任务的模型使用配置
:param task_name: 任务名称
:param value: 模型使用配置
"""
self.config.task_model_arg_map[task_name] = value
def __contains__(self, task_name: str):
"""
判断任务是否已注册
:param task_name: 任务名称
:return: 是否在模型列表中
"""
return task_name in self.config.task_model_arg_map

View File

@@ -1,169 +0,0 @@
from datetime import datetime
from enum import Enum
from typing import Tuple
from src.common.logger import get_logger
from src.config.api_ada_configs import ModelInfo
from src.common.database.database_model import LLMUsage
logger = get_logger("模型使用统计")
class ReqType(Enum):
"""
请求类型
"""
CHAT = "chat" # 对话请求
EMBEDDING = "embedding" # 嵌入请求
class UsageCallStatus(Enum):
"""
任务调用状态
"""
PROCESSING = "processing" # 处理中
SUCCESS = "success" # 成功
FAILURE = "failure" # 失败
CANCELED = "canceled" # 取消
class ModelUsageStatistic:
"""
模型使用统计类 - 使用SQLite+Peewee
"""
def __init__(self):
"""
初始化统计类
由于使用Peewee ORM不需要传入数据库实例
"""
# 确保表已经创建
try:
from src.common.database.database import db
db.create_tables([LLMUsage], safe=True)
except Exception as e:
logger.error(f"创建LLMUsage表失败: {e}")
@staticmethod
def _calculate_cost(prompt_tokens: int, completion_tokens: int, model_info: ModelInfo) -> float:
"""计算API调用成本
使用模型的pri_in和pri_out价格计算输入和输出的成本
Args:
prompt_tokens: 输入token数量
completion_tokens: 输出token数量
model_info: 模型信息
Returns:
float: 总成本(元)
"""
# 使用模型的pri_in和pri_out计算成本
input_cost = (prompt_tokens / 1000000) * model_info.price_in
output_cost = (completion_tokens / 1000000) * model_info.price_out
return round(input_cost + output_cost, 6)
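# Worked example of the pricing above (price_in/price_out are per 1M tokens; the
# figures are made up): with price_in=2.0, price_out=8.0, prompt_tokens=1500 and
# completion_tokens=500, the cost is
# (1500 / 1_000_000) * 2.0 + (500 / 1_000_000) * 8.0 = 0.003 + 0.004 = 0.007 yuan.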
def create_usage(
self,
model_name: str,
task_name: str = "N/A",
request_type: ReqType = ReqType.CHAT,
user_id: str = "system",
endpoint: str = "/chat/completions",
) -> int | None:
"""
创建模型使用情况记录
Args:
model_name: 模型名
task_name: 任务名称
request_type: 请求类型默认为Chat
user_id: 用户ID默认为system
endpoint: API端点
Returns:
int | None: 返回记录ID失败返回None
"""
try:
usage_record = LLMUsage.create(
model_name=model_name,
user_id=user_id,
request_type=request_type.value,
endpoint=endpoint,
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
cost=0.0,
status=UsageCallStatus.PROCESSING.value,
timestamp=datetime.now(),
)
# logger.trace(
# f"创建了一条模型使用情况记录 - 模型: {model_name}, "
# f"子任务: {task_name}, 类型: {request_type.value}, "
# f"用户: {user_id}, 记录ID: {usage_record.id}"
# )
return usage_record.id
except Exception as e:
logger.error(f"创建模型使用情况记录失败: {str(e)}")
return None
def update_usage(
self,
record_id: int | None,
model_info: ModelInfo,
usage_data: Tuple[int, int, int] | None = None,
stat: UsageCallStatus = UsageCallStatus.SUCCESS,
ext_msg: str | None = None,
):
"""
更新模型使用情况
Args:
record_id: 记录ID
model_info: 模型信息
usage_data: 使用情况数据(输入token数量, 输出token数量, 总token数量)
stat: 任务调用状态
ext_msg: 额外信息
"""
if not record_id:
logger.error("更新模型使用情况失败: record_id不能为空")
return
if usage_data and len(usage_data) != 3:
logger.error("更新模型使用情况失败: usage_data的长度不正确应该为3个元素")
return
# 提取使用情况数据
prompt_tokens = usage_data[0] if usage_data else 0
completion_tokens = usage_data[1] if usage_data else 0
total_tokens = usage_data[2] if usage_data else 0
try:
# 使用Peewee更新记录
update_query = LLMUsage.update(
status=stat.value,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost=self._calculate_cost(prompt_tokens, completion_tokens, model_info) if usage_data else 0.0,
).where(LLMUsage.id == record_id) # type: ignore
updated_count = update_query.execute()
if updated_count == 0:
logger.warning(f"记录ID {record_id} 不存在,无法更新")
return
logger.debug(
f"Token使用情况 - 模型: {model_info.name}, "
f"记录ID: {record_id}, "
f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")

View File

@@ -2,16 +2,19 @@ import base64
import io
from PIL import Image
from datetime import datetime
from src.common.logger import get_logger
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage
from src.config.api_ada_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord
logger = get_logger("消息压缩工具")
def compress_messages(
messages: list[Message], img_target_size: int = 1 * 1024 * 1024
) -> list[Message]:
def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]:
"""
压缩消息列表中的图片
:param messages: 消息列表
@@ -28,14 +31,10 @@ def compress_messages(
try:
image = Image.open(image_data)
if image.format and (
image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]
):
if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
# 静态图像转换为JPEG格式
reformated_image_data = io.BytesIO()
image.save(
reformated_image_data, format="JPEG", quality=95, optimize=True
)
image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
image_data = reformated_image_data.getvalue()
return image_data
@@ -43,9 +42,7 @@ def compress_messages(
logger.error(f"图片转换格式失败: {str(e)}")
return image_data
def rescale_image(
image_data: bytes, scale: float
) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
"""
缩放图片
:param image_data: 图片数据
@@ -86,9 +83,7 @@ def compress_messages(
else:
# 静态图片,直接缩放保存
resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
resized_image.save(
output_buffer, format="JPEG", quality=95, optimize=True
)
resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)
return output_buffer.getvalue(), original_size, new_size
@@ -99,9 +94,7 @@ def compress_messages(
logger.error(traceback.format_exc())
return image_data, None, None
def compress_base64_image(
base64_data: str, target_size: int = 1 * 1024 * 1024
) -> str:
def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str:
original_b64_data_size = len(base64_data) # 计算原始数据大小
image_data = base64.b64decode(base64_data)
@@ -111,9 +104,7 @@ def compress_messages(
base64_data = base64.b64encode(image_data).decode("utf-8")
if len(base64_data) <= target_size:
# 如果转换后小于目标大小,直接返回
logger.info(
f"成功将图片转为JPEG格式编码后大小: {len(base64_data) / 1024:.1f}KB"
)
logger.info(f"成功将图片转为JPEG格式编码后大小: {len(base64_data) / 1024:.1f}KB")
return base64_data
# 如果转换后仍然大于目标大小,进行尺寸压缩
@@ -139,9 +130,7 @@ def compress_messages(
# 图片,进行压缩
message_builder.add_image_content(
content_item[0],
compress_base64_image(
content_item[1], target_size=img_target_size
),
compress_base64_image(content_item[1], target_size=img_target_size),
)
else:
message_builder.add_text_content(content_item)
@@ -150,3 +139,48 @@ def compress_messages(
compressed_messages.append(message)
return compressed_messages
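
A usage sketch for the compressor above (the build() method name on MessageBuilder is an assumption; image payloads are shrunk toward img_target_size, text passes through untouched):

builder = MessageBuilder()
builder.add_text_content("describe this image")
builder.add_image_content("jpeg", some_base64_payload)  # some_base64_payload: hypothetical
small_messages = compress_messages([builder.build()], img_target_size=512 * 1024)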
class LLMUsageRecorder:
"""
LLM使用情况记录器
"""
def __init__(self):
try:
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def record_usage_to_database(
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str
):
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
total_cost = round(input_cost + output_cost, 6)
try:
# 使用 Peewee 模型创建记录
LLMUsage.create(
model_name=model_info.model_identifier,
user_id=user_id,
request_type=request_type,
endpoint=endpoint,
prompt_tokens=model_usage.prompt_tokens or 0,
completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0,
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
)
logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
f"总计: {model_usage.total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
llm_usage_recorder = LLMUsageRecorder()
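
This module-level singleton replaces the per-instance _record_usage helper removed from LLMRequest below; callers hand it the ModelInfo and the client's UsageRecord directly:

llm_usage_recorder.record_usage_to_database(
    model_info=model_info,      # ModelInfo for the model that served the call
    model_usage=usage,          # UsageRecord returned by the client
    user_id="system",
    request_type="emoji",
    endpoint="/chat/completions",
)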

View File

@@ -1,34 +1,20 @@
import re
import copy
import asyncio
from datetime import datetime
from typing import Tuple, Union, List, Dict, Optional, Callable, Any
from src.common.logger import get_logger
import base64
from PIL import Image
from enum import Enum
import io
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
from src.config.config import global_config, model_config
from src.config.api_ada_configs import APIProvider, ModelInfo
from rich.traceback import install
from enum import Enum
from rich.traceback import install
from typing import Tuple, List, Dict, Optional, Callable, Any
from src.common.logger import get_logger
from src.config.config import model_config
from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
from .payload_content.tool_option import ToolOption, ToolCall
from .model_client.base_client import BaseClient, APIResponse, UsageRecord, client_registry
from .utils import compress_messages
from .exceptions import (
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
PayLoadTooLargeError,
RequestAbortException,
PermissionDeniedException,
)
from .model_client.base_client import BaseClient, APIResponse, client_registry
from .utils import compress_messages, llm_usage_recorder
from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
install(extra_lines=3)
@@ -57,45 +43,15 @@ class RequestType(Enum):
class LLMRequest:
"""LLM请求类"""
# 定义需要转换的模型列表,作为类变量避免重复
MODELS_NEEDING_TRANSFORMATION = [
"o1",
"o1-2024-12-17",
"o1-mini",
"o1-mini-2024-09-12",
"o1-preview",
"o1-preview-2024-09-12",
"o1-pro",
"o1-pro-2025-03-19",
"o3",
"o3-2025-04-16",
"o3-mini",
"o3-mini-2025-01-31",
"o4-mini",
"o4-mini-2025-04-16",
]
def __init__(self, task_name: str, request_type: str = "") -> None:
self.task_name = task_name
self.model_for_task = model_config.model_task_config.get_task(task_name)
def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
self.task_name = request_type
self.model_for_task = model_set
self.request_type = request_type
self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list}
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整"""
self.pri_in = 0
self.pri_out = 0
self._init_database()
@staticmethod
def _init_database():
"""初始化数据库集合"""
try:
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
async def generate_response_for_image(
self,
@@ -104,7 +60,7 @@ class LLMRequest:
image_format: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
"""
为图像生成响应
Args:
@@ -112,7 +68,7 @@ class LLMRequest:
image_base64 (str): 图像的Base64编码字符串
image_format (str): 图像格式(如 'png', 'jpeg' 等)
Returns:
(Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]): 响应内容、(推理内容, 模型名称, 工具调用列表)

"""
# 请求体构建
message_builder = MessageBuilder()
@@ -141,25 +97,25 @@ class LLMRequest:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
self.pri_in = model_info.price_in
self.pri_out = model_info.price_out
self._record_usage(
model_name=model_info.name,
prompt_tokens=usage.prompt_tokens or 0,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens or 0,
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
)
return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
return content, (
reasoning_content,
model_info.name,
self._convert_tool_calls(tool_calls) if tool_calls else None,
)
async def generate_response_for_voice(self):
pass
async def generate_response_async(
self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None
) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]:
) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]:
"""
异步生成响应
Args:
@@ -167,7 +123,7 @@ class LLMRequest:
temperature (float, optional): 温度参数
max_tokens (int, optional): 最大token数
Returns:
Tuple[str, str, Optional[List[Dict[str, Any]]]]: 响应内容、推理内容工具调用列表
(Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]): 响应内容、(推理内容, 模型名称, 工具调用列表)
"""
# 请求体构建
message_builder = MessageBuilder()
@@ -195,13 +151,9 @@ class LLMRequest:
content, extracted_reasoning = self._extract_reasoning(content)
reasoning_content = extracted_reasoning
if usage := response.usage:
self.pri_in = model_info.price_in
self.pri_out = model_info.price_out
self._record_usage(
model_name=model_info.name,
prompt_tokens=usage.prompt_tokens or 0,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens or 0,
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/chat/completions",
@@ -209,10 +161,19 @@ class LLMRequest:
if not content:
raise RuntimeError("获取LLM生成内容失败")
return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None
return content, (
reasoning_content,
model_info.name,
self._convert_tool_calls(tool_calls) if tool_calls else None,
)
async def get_embedding(self, embedding_input: str) -> List[float]:
"""获取嵌入向量"""
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
"""获取嵌入向量
Args:
embedding_input (str): 获取嵌入的目标
Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
"""
# 无需构建消息体,直接使用输入文本
model_info, api_provider, client = self._select_model()
@@ -227,14 +188,10 @@ class LLMRequest:
embedding = response.embedding
if response.usage:
self.pri_in = model_info.price_in
self.pri_out = model_info.price_out
self._record_usage(
model_name=model_info.name,
prompt_tokens=response.usage.prompt_tokens or 0,
completion_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens or 0,
if usage := response.usage:
llm_usage_recorder.record_usage_to_database(
model_info=model_info,
model_usage=usage,
user_id="system",
request_type=self.request_type,
endpoint="/embeddings",
@@ -243,7 +200,7 @@ class LLMRequest:
if not embedding:
raise RuntimeError("获取embedding失败")
return embedding
return embedding, model_info.name
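
Callers of the embedding API now receive the serving model's name alongside the vector; a one-line sketch (the "embedding" task key in model_task_config is an assumption, construction follows the new model_set style):

llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
vector, used_model = await llm.get_embedding("你好")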
def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
"""
@@ -305,12 +262,13 @@ class LLMRequest:
# 处理异常
total_tokens, penalty = self.model_usage[model_info.name]
self.model_usage[model_info.name] = (total_tokens, penalty + 1)
wait_interval, compressed_messages = self._default_exception_handler(
e,
self.task_name,
model_name=model_info.name,
remain_try=retry_remain,
messages=(message_list, compressed_messages is not None),
messages=(message_list, compressed_messages is not None) if message_list else None,
)
if wait_interval == -1:
@@ -321,9 +279,7 @@ class LLMRequest:
finally:
# 放在finally防止死循环
retry_remain -= 1
logger.error(
f"任务 '{self.task_name}' 模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}"
)
logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry}")
raise RuntimeError("请求失败,已达到最大重试次数")
def _default_exception_handler(
@@ -481,65 +437,3 @@ class LLMRequest:
content = re.sub(r"(?:<think>)?.*?</think>", "", content, flags=re.DOTALL, count=1).strip()
reasoning = match[1].strip() if match else ""
return content, reasoning
def _record_usage(
self,
model_name: str,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
user_id: str = "system",
request_type: str | None = None,
endpoint: str = "/chat/completions",
):
"""记录模型使用情况到数据库
Args:
prompt_tokens: 输入token数
completion_tokens: 输出token数
total_tokens: 总token数
user_id: 用户ID默认为system
request_type: 请求类型
endpoint: API端点
"""
# 如果 request_type 为 None则使用实例变量中的值
if request_type is None:
request_type = self.request_type
try:
# 使用 Peewee 模型创建记录
LLMUsage.create(
model_name=model_name,
user_id=user_id,
request_type=request_type,
endpoint=endpoint,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
cost=self._calculate_cost(prompt_tokens, completion_tokens),
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
)
logger.debug(
f"Token使用情况 - 模型: {model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本
使用模型的pri_in和pri_out价格计算输入和输出的成本
Args:
prompt_tokens: 输入token数量
completion_tokens: 输出token数量
Returns:
float: 总成本(元)
"""
# 使用模型的pri_in和pri_out计算成本
input_cost = (prompt_tokens / 1000000) * self.pri_in
output_cost = (completion_tokens / 1000000) * self.pri_out
return round(input_cost + output_cost, 6)

View File

@@ -2,13 +2,15 @@ from src.chat.message_receive.chat_stream import get_chat_manager
import time
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.logger import get_logger
logger = get_logger(__name__)
def init_prompt():
Prompt(
"""
@@ -32,10 +34,8 @@ def init_prompt():
)
class MaiThinking:
def __init__(self,chat_id):
def __init__(self, chat_id):
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.platform = self.chat_stream.platform
@@ -44,11 +44,11 @@ class MaiThinking:
self.is_group = True
else:
self.is_group = False
self.s4u_message_processor = S4UMessageProcessor()
self.mind = ""
self.memory_block = ""
self.relation_info_block = ""
self.time_block = ""
@@ -59,17 +59,13 @@ class MaiThinking:
self.identity = ""
self.sender = ""
self.target = ""
self.thinking_model = LLMRequest(
model=global_config.model.replyer_1,
request_type="thinking",
)
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer_1, request_type="thinking")
async def do_think_before_response(self):
pass
async def do_think_after_response(self,reponse:str):
async def do_think_after_response(self, response: str):
prompt = await global_prompt_manager.format_prompt(
"after_response_think_prompt",
mind=self.mind,
@@ -85,47 +81,44 @@ class MaiThinking:
sender=self.sender,
target=self.target,
)
result, _ = await self.thinking_model.generate_response_async(prompt)
self.mind = result
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
# logger.info(f"[{self.chat_id}] 思考前prompt{prompt}")
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
msg_recv = await self.build_internal_message_recv(self.mind)
await self.s4u_message_processor.process_message(msg_recv)
internal_manager.set_internal_state(self.mind)
async def do_think_when_receive_message(self):
pass
async def build_internal_message_recv(self,message_text:str):
async def build_internal_message_recv(self, message_text: str):
msg_id = f"internal_{time.time()}"
message_dict = {
"message_info": {
"message_id": msg_id,
"time": time.time(),
"user_info": {
"user_id": "internal", # 内部用户ID
"user_nickname": "内心", # 内部昵称
"platform": self.platform, # 平台标记为 internal
"user_id": "internal", # 内部用户ID
"user_nickname": "内心", # 内部昵称
"platform": self.platform, # 平台标记为 internal
# 其他 user_info 字段按需补充
},
"platform": self.platform, # 平台
"platform": self.platform, # 平台
# 其他 message_info 字段按需补充
},
"message_segment": {
"type": "text", # 消息类型
"data": message_text, # 消息内容
"type": "text", # 消息类型
"data": message_text, # 消息内容
# 其他 segment 字段按需补充
},
"raw_message": message_text, # 原始消息内容
"processed_plain_text": message_text, # 处理后的纯文本
"raw_message": message_text, # 原始消息内容
"processed_plain_text": message_text, # 处理后的纯文本
# 下面这些字段可选,根据 MessageRecv 需要
"is_emoji": False,
"has_emoji": False,
@@ -139,45 +132,36 @@ class MaiThinking:
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
"interest_value": 1.0,
}
if self.is_group:
message_dict["message_info"]["group_info"] = {
"platform": self.platform,
"group_id": self.chat_stream.group_info.group_id,
"group_name": self.chat_stream.group_info.group_name,
}
msg_recv = MessageRecvS4U(message_dict)
msg_recv.chat_info = self.chat_info
msg_recv.chat_stream = self.chat_stream
msg_recv.is_internal = True
return msg_recv
class MaiThinkingManager:
def __init__(self):
self.mai_think_list = []
def get_mai_think(self,chat_id):
def get_mai_think(self, chat_id):
for mai_think in self.mai_think_list:
if mai_think.chat_id == chat_id:
return mai_think
mai_think = MaiThinking(chat_id)
self.mai_think_list.append(mai_think)
return mai_think
mai_thinking_manager = MaiThinkingManager()
init_prompt()
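
A small usage sketch of the manager's lazy cache (the chat id is illustrative):

think = mai_thinking_manager.get_mai_think("chat_42")            # created on first access
assert mai_thinking_manager.get_mai_think("chat_42") is think    # reused afterwards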

View File

@@ -1,14 +1,16 @@
import json
import time
from json_repair import repair_json
from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
from json_repair import repair_json
from src.mais4u.s4u_config import s4u_config
logger = get_logger("action")
@@ -32,7 +34,7 @@ BODY_CODE = {
"帅气的姿势": "010_0190",
"另一个帅气的姿势": "010_0191",
"手掌朝前可爱": "010_0210",
"平静,双手后放":"平静,双手后放",
"平静,双手后放": "平静,双手后放",
"思考": "思考",
"优雅,左手放在腰上": "优雅,左手放在腰上",
"一般": "一般",
@@ -94,19 +96,15 @@ class ChatAction:
self.body_action_cooldown: dict[str, int] = {}
print(s4u_config.models.motion)
print(global_config.model.emotion)
self.action_model = LLMRequest(
model=global_config.model.emotion,
temperature=0.7,
request_type="motion",
)
print(model_config.model_task_config.emotion)
self.last_change_time = 0
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
self.last_change_time: float = 0
async def send_action_update(self):
"""发送动作更新到前端"""
body_code = BODY_CODE.get(self.body_action, "")
await send_api.custom_to_stream(
message_type="body_action",
@@ -115,13 +113,11 @@ class ChatAction:
storage_message=False,
show_log=True,
)
async def update_action_by_message(self, message: MessageRecv):
self.regression_count = 0
message_time = message.message_info.time
message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
@@ -147,13 +143,13 @@ class ChatAction:
prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt(
"change_action_prompt",
chat_talking_prompt=chat_talking_prompt,
@@ -163,19 +159,18 @@ class ChatAction:
)
logger.info(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt)
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
action_data = json.loads(repair_json(response))
if action_data:
if action_data := json.loads(repair_json(response)):
# 记录原动作,切换后进入冷却
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action:
if prev_body_action:
self.body_action_cooldown[prev_body_action] = 3
if new_body_action != prev_body_action and prev_body_action:
self.body_action_cooldown[prev_body_action] = 3
self.body_action = new_body_action
self.head_action = action_data.get("head_action", self.head_action)
# 发送动作更新
@@ -213,7 +208,6 @@ class ChatAction:
prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
@@ -228,17 +222,17 @@ class ChatAction:
)
logger.info(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt)
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
action_data = json.loads(repair_json(response))
if action_data:
if action_data := json.loads(repair_json(response)):
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action:
if prev_body_action:
self.body_action_cooldown[prev_body_action] = 6
if new_body_action != prev_body_action and prev_body_action:
self.body_action_cooldown[prev_body_action] = 6
self.body_action = new_body_action
# 发送动作更新
await self.send_action_update()
@@ -306,9 +300,6 @@ class ActionManager:
return new_action_state
init_prompt()
action_manager = ActionManager()

View File

@@ -137,7 +137,7 @@ class MessageSenderContainer:
await self.storage.store_message(bot_message, self.chat_stream)
except Exception as e:
logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True)
logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
finally:
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.

View File

@@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
@@ -114,18 +114,12 @@ class ChatMood:
self.regression_count: int = 0
self.mood_model = LLMRequest(
model=global_config.model.emotion,
temperature=0.7,
request_type="mood_text",
)
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
self.mood_model_numerical = LLMRequest(
model=global_config.model.emotion,
temperature=0.4,
request_type="mood_numerical",
model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
)
self.last_change_time = 0
self.last_change_time: float = 0
# 发送初始情绪状态到ws端
asyncio.create_task(self.send_emotion_update(self.mood_values))
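
As at the other S4U call sites in this commit, temperature now rides on each request instead of the constructor, so one task config serves both the 0.7 text pass and the 0.4 numerical pass:

text_resp, _ = await self.mood_model.generate_response_async(prompt=prompt, temperature=0.7)
num_resp, _ = await self.mood_model_numerical.generate_response_async(prompt=prompt, temperature=0.4)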
@@ -164,7 +158,7 @@ class ChatMood:
async def update_mood_by_message(self, message: MessageRecv):
self.regression_count = 0
message_time = message.message_info.time
message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
@@ -199,7 +193,9 @@ class ChatMood:
mood_state=self.mood_state,
)
logger.debug(f"text mood prompt: {prompt}")
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"text mood response: {response}")
logger.debug(f"text mood reasoning_content: {reasoning_content}")
return response
@@ -216,8 +212,8 @@ class ChatMood:
fear=self.mood_values["fear"],
)
logger.debug(f"numerical mood prompt: {prompt}")
response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async(
prompt=prompt
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
prompt=prompt, temperature=0.4
)
logger.info(f"numerical mood response: {response}")
logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
@@ -276,7 +272,9 @@ class ChatMood:
mood_state=self.mood_state,
)
logger.debug(f"text regress prompt: {prompt}")
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"text regress response: {response}")
logger.debug(f"text regress reasoning_content: {reasoning_content}")
return response
@@ -293,8 +291,9 @@ class ChatMood:
fear=self.mood_values["fear"],
)
logger.debug(f"numerical regress prompt: {prompt}")
response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async(
prompt=prompt
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
prompt=prompt,
temperature=0.4,
)
logger.info(f"numerical regress response: {response}")
logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
@@ -447,6 +446,7 @@ class MoodManager:
# 发送初始情绪状态到ws端
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
if ENABLE_S4U:
init_prompt()
mood_manager = MoodManager()

View File

@@ -150,19 +150,18 @@ class PromptBuilder:
relation_prompt = ""
if global_config.relationship.enable_relationship and who_chat_in_group:
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id)
# 将 (platform, user_id, nickname) 转换为 person_id
person_ids = []
for person in who_chat_in_group:
person_id = PersonInfoManager.get_person_id(person[0], person[1])
person_ids.append(person_id)
# 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为
relation_info_list = await asyncio.gather(
*[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
)
relation_info = "".join(relation_info_list)
if relation_info:
if relation_info := "".join(relation_info_list):
relation_prompt = await global_prompt_manager.format_prompt(
"relation_prompt", relation_info=relation_info
)
@@ -186,9 +185,9 @@ class PromptBuilder:
timestamp=time.time(),
limit=300,
)
talk_type = message.message_info.platform + ":" + str(message.chat_stream.user_info.user_id)
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
core_dialogue_list = []
background_dialogue_list = []
@@ -258,19 +257,19 @@ class PromptBuilder:
all_msg_seg_list.append(msg_seg_str)
for msg in all_msg_seg_list:
core_msg_str += msg
all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=20,
)
)
all_dialogue_prompt_str = build_readable_messages(
all_dialogue_prompt,
timestamp_mode="normal_no_YMD",
show_pic=False,
)
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str

View File

@@ -1,7 +1,7 @@
import os
from typing import AsyncGenerator
from src.mais4u.openai_client import AsyncOpenAIClient
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
from src.common.logger import get_logger
@@ -14,24 +14,27 @@ logger = get_logger("s4u_stream_generator")
class S4UStreamGenerator:
def __init__(self):
replyer_1_config = global_config.model.replyer_1
provider = replyer_1_config.get("provider")
if not provider:
logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段")
raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段")
replyer_1_config = model_config.model_task_config.replyer_1
model_to_use = replyer_1_config.model_list[0]
model_info = model_config.get_model_info(model_to_use)
if not model_info:
logger.error(f"模型 {model_to_use} 在配置中未找到")
raise ValueError(f"模型 {model_to_use} 在配置中未找到")
provider_name = model_info.api_provider
provider_info = model_config.get_provider(provider_name)
if not provider_info:
logger.error("`replyer_1` 找不到对应的Provider")
raise ValueError("`replyer_1` 找不到对应的Provider")
api_key = os.environ.get(f"{provider.upper()}_KEY")
base_url = os.environ.get(f"{provider.upper()}_BASE_URL")
api_key = provider_info.api_key
base_url = provider_info.base_url
if not api_key:
logger.error(f"环境变量 {provider.upper()}_KEY 未设置")
raise ValueError(f"环境变量 {provider.upper()}_KEY 未设置")
logger.error(f"{provider_name}没有配置API KEY")
raise ValueError(f"{provider_name}没有配置API KEY")
self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
self.model_1_name = replyer_1_config.get("name")
if not self.model_1_name:
logger.error("`replyer_1` 在配置文件中缺少 `model_name` 字段")
raise ValueError("`replyer_1` 在配置文件中缺少 `model_name` 字段")
self.model_1_name = model_to_use
self.replyer_1_config = replyer_1_config
self.current_model_name = "unknown model"
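
Condensed, the new resolution chain goes task config, then model info, then provider credentials, replacing the old environment-variable lookup:

task = model_config.model_task_config.replyer_1
info = model_config.get_model_info(task.model_list[0])
provider = model_config.get_provider(info.api_provider)
client = AsyncOpenAIClient(api_key=provider.api_key, base_url=provider.base_url)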
@@ -44,10 +47,10 @@ class S4UStreamGenerator:
r'[^.。!?\n\r]+(?:[.。!?\n\r](?![\'"])|$))', # 匹配直到句子结束符
re.UNICODE | re.DOTALL,
)
self.chat_stream =None
async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""):
self.chat_stream = None
async def build_last_internal_message(self, message: MessageRecvS4U, previous_reply_context: str = ""):
# person_id = PersonInfoManager.get_person_id(
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
# )
@@ -71,14 +74,10 @@ class S4UStreamGenerator:
[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:
{message.processed_plain_text}
"""
return True,message_txt
return True, message_txt
else:
message_txt = message.processed_plain_text
return False,message_txt
return False, message_txt
async def generate_response(
self, message: MessageRecvS4U, previous_reply_context: str = ""
@@ -88,7 +87,7 @@ class S4UStreamGenerator:
self.partial_response = ""
message_txt = message.processed_plain_text
if not message.is_internal:
interupted,message_txt_added = await self.build_last_internal_message(message,previous_reply_context)
interrupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context)
if interrupted:
message_txt = message_txt_added
@@ -105,7 +104,6 @@ class S4UStreamGenerator:
current_client = self.client_1
self.current_model_name = self.model_1_name
extra_kwargs = {}
if self.replyer_1_config.get("enable_thinking") is not None:
extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking")

View File

@@ -214,51 +214,49 @@ class SuperChatManager:
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
"""构建SuperChat显示字符串"""
superchats = self.get_superchats_by_chat(chat_id)
if not superchats:
return ""
# 限制显示数量
display_superchats = superchats[:max_count]
lines = []
lines.append("📢 当前有效超级弹幕:")
lines = ["📢 当前有效超级弹幕:"]
for i, sc in enumerate(display_superchats, 1):
remaining_minutes = int(sc.remaining_time() / 60)
remaining_seconds = int(sc.remaining_time() % 60)
time_display = f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}"
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
if len(line) > 100: # 限制单行长度
line = line[:97] + "..."
line = f"{line[:97]}..."
line += f" (剩余{time_display})"
lines.append(line)
if len(superchats) > max_count:
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
return "\n".join(lines)
def build_superchat_summary_string(self, chat_id: str) -> str:
"""构建SuperChat摘要字符串"""
superchats = self.get_superchats_by_chat(chat_id)
if not superchats:
return "当前没有有效的超级弹幕"
lines = []
for sc in superchats:
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
if len(single_sc_str) > 100:
single_sc_str = single_sc_str[:97] + "..."
single_sc_str = f"{single_sc_str[:97]}..."
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
lines.append(single_sc_str)
total_amount = sum(sc.price for sc in superchats)
count = len(superchats)
highest_amount = max(sc.price for sc in superchats)
final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}"
if lines:
final_str += "\n" + "\n".join(lines)
@@ -287,7 +285,7 @@ class SuperChatManager:
"lowest_amount": min(amounts)
}
async def shutdown(self):
async def shutdown(self): # sourcery skip: use-contextlib-suppress
"""关闭管理器,清理资源"""
if self._cleanup_task and not self._cleanup_task.done():
self._cleanup_task.cancel()
@@ -300,6 +298,7 @@ class SuperChatManager:
# sourcery skip: assign-if-exp
if ENABLE_S4U:
super_chat_manager = SuperChatManager()
else:

View File

@@ -1,19 +1,14 @@
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import model_config
from src.plugin_system.apis import send_api
logger = get_logger(__name__)
head_actions_list = [
"不做额外动作",
"点头一次",
"点头两次",
"摇头",
"歪脑袋",
"低头望向一边"
]
head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"]
async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat_id: str = ""):
async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""):
prompt = f"""
{chat_history}
以上是对方的发言:
@@ -30,22 +25,14 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat
低头望向一边
请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。"""
model = LLMRequest(
model=global_config.model.emotion,
temperature=0.7,
request_type="motion",
)
model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
try:
# logger.info(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await model.generate_response_async(prompt=prompt)
response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7)
logger.info(f"response: {response}")
if response in head_actions_list:
head_action = response
else:
head_action = "不做额外动作"
head_action = response if response in head_actions_list else "不做额外动作"
await send_api.custom_to_stream(
message_type="head_action",
content=head_action,
@@ -53,11 +40,7 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat
storage_message=False,
show_log=True,
)
except Exception as e:
logger.error(f"yes_or_no_head error: {e}")
return "不做额外动作"

View File

@@ -3,13 +3,14 @@ import random
import time
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("mood")
@@ -49,7 +50,7 @@ class ChatMood:
chat_manager = get_chat_manager()
self.chat_stream = chat_manager.get_stream(self.chat_id)
if not self.chat_stream:
raise ValueError(f"Chat stream for chat_id {chat_id} not found")
@@ -59,11 +60,7 @@ class ChatMood:
self.regression_count: int = 0
self.mood_model = LLMRequest(
model=global_config.model.emotion,
temperature=0.7,
request_type="mood",
)
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood")
self.last_change_time: float = 0
@@ -83,12 +80,16 @@ class ChatMood:
logger.debug(
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
)
update_probability = global_config.mood.mood_update_threshold * min(1.0, base_probability * time_multiplier * interest_multiplier)
update_probability = global_config.mood.mood_update_threshold * min(
1.0, base_probability * time_multiplier * interest_multiplier
)
if random.random() > update_probability:
return
logger.debug(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}")
logger.debug(
f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}"
)
message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
@@ -124,7 +125,9 @@ class ChatMood:
mood_state=self.mood_state,
)
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} prompt: {prompt}")
logger.info(f"{self.log_prefix} response: {response}")
@@ -171,7 +174,9 @@ class ChatMood:
mood_state=self.mood_state,
)
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} prompt: {prompt}")

View File

@@ -11,7 +11,7 @@ from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
"""
@@ -54,11 +54,7 @@ person_info_default = {
class PersonInfoManager:
def __init__(self):
self.person_name_list = {}
# TODO: API-Adapter修改标记
self.qv_name_llm = LLMRequest(
model=global_config.model.utils,
request_type="relation.qv_name",
)
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try:
db.connect(reuse_if_open=True)
# 设置连接池参数
@@ -199,7 +195,7 @@ class PersonInfoManager:
if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
return True
# 尝试创建
PersonInfo.create(**p_data)
return True
@@ -376,7 +372,7 @@ class PersonInfoManager:
"nickname": "昵称",
"reason": "理由"
}"""
response, (reasoning_content, model_name) = await self.qv_name_llm.generate_response_async(qv_name_prompt)
response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt)
# logger.info(f"取名提示词:{qv_name_prompt}\n取名回复{response}")
result = self._extract_json_from_text(response)
@@ -592,7 +588,7 @@ class PersonInfoManager:
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return record, False # 记录存在,未创建
# 记录不存在,尝试创建
try:
PersonInfo.create(**init_data)
@@ -622,7 +618,7 @@ class PersonInfoManager:
"points": [],
"forgotten_points": [],
}
# 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
@@ -630,12 +626,12 @@ class PersonInfoManager:
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
elif initial_data[key] is None:
initial_data[key] = json.dumps([], ensure_ascii=False)
model_fields = PersonInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
if was_created:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")

View File

@@ -7,7 +7,7 @@ from typing import List, Dict, Any
from json_repair import repair_json
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
@@ -73,14 +73,12 @@ class RelationshipFetcher:
# LLM模型配置
self.llm_model = LLMRequest(
model=global_config.model.utils_small,
request_type="relation.fetcher",
model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher"
)
# 小模型用于即时信息提取
self.instant_llm_model = LLMRequest(
model=global_config.model.utils_small,
request_type="relation.fetch",
model_set=model_config.model_task_config.utils_small, request_type="relation.fetch"
)
name = get_chat_manager().get_stream_name(self.chat_id)
@@ -96,7 +94,7 @@ class RelationshipFetcher:
if not self.info_fetched_cache[person_id]:
del self.info_fetched_cache[person_id]
async def build_relation_info(self, person_id, points_num = 3):
async def build_relation_info(self, person_id, points_num=3):
# 清理过期的信息缓存
self._cleanup_expired_cache()
@@ -361,7 +359,6 @@ class RelationshipFetcher:
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
logger.error(traceback.format_exc())
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
# sourcery skip: use-next
"""将提取到的信息保存到 person_info 的 info_list 字段中

View File

@@ -3,7 +3,7 @@ from .person_info import PersonInfoManager, get_person_info_manager
import time
import random
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.utils.chat_message_builder import build_readable_messages
import json
from json_repair import repair_json
@@ -20,9 +20,8 @@ logger = get_logger("relation")
class RelationshipManager:
def __init__(self):
self.relationship_llm = LLMRequest(
model=global_config.model.utils,
request_type="relationship", # 用于动作规划
)
model_set=model_config.model_task_config.utils, request_type="relationship"
) # 用于动作规划
@staticmethod
async def is_known_some_one(platform, user_id):
@@ -181,18 +180,14 @@ class RelationshipManager:
try:
points = repair_json(points)
points_data = json.loads(points)
# 只处理正确的格式,错误格式直接跳过
if points_data == "none" or not points_data:
points_list = []
elif isinstance(points_data, str) and points_data.lower() == "none":
points_list = []
elif isinstance(points_data, list):
# 正确格式:数组格式 [{"point": "...", "weight": 10}, ...]
if not points_data: # 空数组
points_list = []
else:
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
else:
# 错误格式,直接跳过不解析
logger.warning(f"LLM返回了错误的JSON格式跳过解析: {type(points_data)}, 内容: {points_data}")

View File

@@ -12,6 +12,7 @@ import traceback
from typing import Tuple, Any, Dict, List, Optional
from rich.traceback import install
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
@@ -31,7 +32,7 @@ logger = get_logger("generator_api")
def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
"""获取回复器对象
@@ -58,7 +59,7 @@ def get_replyer(
return replyer_manager.get_replyer(
chat_stream=chat_stream,
chat_id=chat_id,
model_configs=model_configs,
model_set_with_weight=model_set_with_weight,
request_type=request_type,
)
except Exception as e:
@@ -83,7 +84,7 @@ async def generate_reply(
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
return_prompt: bool = False,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复
@@ -106,7 +107,7 @@ async def generate_reply(
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
@@ -154,7 +155,7 @@ async def rewrite_reply(
chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
@@ -179,7 +180,7 @@ async def rewrite_reply(
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
@@ -245,17 +246,17 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
async def generate_response_custom(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
prompt: str = "",
) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return None
try:
logger.debug("[GeneratorAPI] 开始生成自定义回复")
response = await replyer.llm_generate_content(prompt)
response, _, _, _ = await replyer.llm_generate_content(prompt)
if response:
logger.debug("[GeneratorAPI] 自定义回复生成成功")
return response
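
With the new weighted model sets, callers pass (TaskConfig, weight) pairs instead of raw config dicts; a sketch (reading the weight as a relative selection weight, which this hunk does not spell out):

task = model_config.model_task_config.replyer_1
ok, segments, prompt = await generate_reply(
    chat_id="stream_1",                      # illustrative id
    model_set_with_weight=[(task, 1.0)],
    request_type="generator_api",
)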

View File

@@ -7,10 +7,11 @@
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
"""
from typing import Tuple, Dict, Any
from typing import Tuple, Dict
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.config.api_ada_configs import TaskConfig
logger = get_logger("llm_api")
@@ -19,9 +20,7 @@ logger = get_logger("llm_api")
# =============================================================================
def get_available_models() -> Dict[str, Any]:
def get_available_models() -> Dict[str, TaskConfig]:
"""获取所有可用的模型配置
Returns:
@@ -33,14 +32,14 @@ def get_available_models() -> Dict[str, Any]:
return {}
# 自动获取所有属性并转换为字典形式
rets = {}
models = global_config.model
models = model_config.model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
for attr in attrs:
if not attr.startswith("__"):
try:
value = getattr(models, attr)
if not callable(value): # 排除方法
if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value
except Exception as e:
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
@@ -53,8 +52,8 @@ def get_available_models() -> Dict[str, Any]:
async def generate_with_model(
prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str]:
prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
Args:
@@ -67,17 +66,16 @@ async def generate_with_model(
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
model_name = model_config.get("name")
logger.info(f"[LLMAPI] 使用模型 {model_name} 生成内容")
model_name_list = model_config.model_list
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
llm_request = LLMRequest(model_set=model_config, request_type=request_type, **kwargs)
# TODO: 复活这个_
response, _ = await llm_request.generate_response_async(prompt)
return True, response
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt)
return True, response, reasoning_content, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg
return False, error_msg, "", ""
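
End-to-end, the plugin-facing flow now looks like this (the "utils_small" task key is assumed to be registered in model_task_config, as the EmojiAction below expects):

models = get_available_models()
if task := models.get("utils_small"):
    success, text, reasoning, model_name = await generate_with_model(
        "用一句话介绍你自己", task, request_type="plugin.generate"
    )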

View File

@@ -335,7 +335,7 @@ async def command_to_stream(
async def custom_to_stream(
message_type: str,
content: str,
content: str | dict,
stream_id: str,
display_message: str = "",
typing: bool = False,

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import process_llm_tool_calls
from src.chat.message_receive.chat_stream import get_chat_manager
@@ -52,10 +52,7 @@ class ToolExecutor:
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(
model=global_config.model.tool_use,
request_type="tool_executor",
)
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 缓存配置
self.enable_cache = enable_cache
@@ -137,7 +134,7 @@ class ToolExecutor:
return tool_results, used_tools, prompt
else:
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)

View File

@@ -58,6 +58,7 @@ class EmojiAction(BaseAction):
associated_types = ["emoji"]
async def execute(self) -> Tuple[bool, str]:
# sourcery skip: assign-if-exp, introduce-default-else, swap-if-else-branches, use-named-expression
"""执行表情动作"""
logger.info(f"{self.log_prefix} 决定发送表情")
@@ -120,7 +121,7 @@ class EmojiAction(BaseAction):
logger.error(f"{self.log_prefix} 未找到'utils_small'模型配置无法调用LLM")
return False, "未找到'utils_small'模型配置"
success, chosen_emotion = await llm_api.generate_with_model(
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="emoji"
)