diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py
index 918b83969..6d50d890e 100644
--- a/src/chat/emoji_system/emoji_manager.py
+++ b/src/chat/emoji_system/emoji_manager.py
@@ -8,15 +8,15 @@ import traceback
 import io
 import re
 import binascii
+
 from typing import Optional, Tuple, List, Any
 from PIL import Image
 from rich.traceback import install
-
 from src.common.database.database_model import Emoji
 from src.common.database.database import db as peewee_db
 from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
 from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
 from src.llm_models.utils_model import LLMRequest
@@ -379,9 +379,9 @@ class EmojiManager:
 
         self._scan_task = None
 
-        self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
+        self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
         self.llm_emotion_judge = LLMRequest(
-            model=global_config.model.utils, max_tokens=600, request_type="emoji"
+            model_set=model_config.model_task_config.utils, request_type="emoji"
         )  # 更高的温度,更少的token(后续可以根据情绪来调整温度)
 
         self.emoji_num = 0
@@ -492,6 +492,7 @@ class EmojiManager:
         return None
 
     def _levenshtein_distance(self, s1: str, s2: str) -> int:
+        # sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison
        """计算两个字符串的编辑距离
 
         Args:
@@ -629,11 +630,11 @@ class EmojiManager:
                     if success:
                         # 注册成功则跳出循环
                         break
-                    else:
-                        # 注册失败则删除对应文件
-                        file_path = os.path.join(EMOJI_DIR, filename)
-                        os.remove(file_path)
-                        logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
+
+                    # 注册失败则删除对应文件
+                    file_path = os.path.join(EMOJI_DIR, filename)
+                    os.remove(file_path)
+                    logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
 
         except Exception as e:
             logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
@@ -694,6 +695,7 @@ class EmojiManager:
             return []
 
     async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
+        # sourcery skip: use-next
        """从内存中的 emoji_objects 列表获取表情包
 
         参数:
@@ -709,10 +711,10 @@ class EmojiManager:
 
     async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
        """根据哈希值获取已注册表情包的描述
-        
+
         Args:
             emoji_hash: 表情包的哈希值
-            
+
         Returns:
             Optional[str]: 表情包描述,如果未找到则返回None
        """
@@ -722,7 +724,7 @@ class EmojiManager:
             if emoji and emoji.description:
                 logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
                 return emoji.description
-            
+
             # 如果内存中没有,从数据库查找
             self._ensure_db()
             try:
@@ -732,9 +734,9 @@ class EmojiManager:
                     return emoji_record.description
             except Exception as e:
                 logger.error(f"从数据库查询表情包描述时出错: {e}")
-            
+
             return None
-            
+
         except Exception as e:
             logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
             return None
@@ -779,6 +781,7 @@ class EmojiManager:
             return False
 
     async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
+        # sourcery skip: use-getitem-for-re-match-groups
        """替换一个表情包
 
         Args:
@@ -820,7 +823,7 @@ class EmojiManager:
             )
 
             # 调用大模型进行决策
-            decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8)
+            decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8, max_tokens=600)
             logger.info(f"[决策] 结果: {decision}")
 
             # 解析决策结果
@@ -828,9 +831,7 @@ class EmojiManager:
                 logger.info("[决策] 不删除任何表情包")
                 return False
 
-            # 尝试从决策中提取表情包编号
-            match = re.search(r"删除编号(\d+)", decision)
-            if match:
+            if match := re.search(r"删除编号(\d+)", decision):
                 emoji_index = int(match.group(1)) - 1  # 转换为0-based索引
 
                 # 检查索引是否有效
@@ -889,6 +890,7 @@ class EmojiManager:
         existing_description = None
         try:
             from src.common.database.database_model import Images
+
             existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
             if existing_image and existing_image.description:
                 existing_description = existing_image.description
@@ -902,15 +904,21 @@ class EmojiManager:
             logger.info("[优化] 复用已有的详细描述,跳过VLM调用")
         else:
             logger.info("[VLM分析] 生成新的详细描述")
-            if image_format == "gif" or image_format == "GIF":
+            if image_format in ["gif", "GIF"]:
                 image_base64 = get_image_manager().transform_gif(image_base64)  # type: ignore
                 if not image_base64:
                     raise RuntimeError("GIF表情包转换失败")
                 prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
-                description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
+                description, _ = await self.vlm.generate_response_for_image(
+                    prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000
+                )
             else:
-                prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
-                description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
+                prompt = (
+                    "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
+                )
+                description, _ = await self.vlm.generate_response_for_image(
+                    prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
+                )
 
         # 审核表情包
         if global_config.emoji.content_filtration:
@@ -922,7 +930,9 @@ class EmojiManager:
                 4. 不要出现5个以上文字
                 请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
                 '''
-            content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
+            content, _ = await self.vlm.generate_response_for_image(
+                prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
+            )
 
             if content == "否":
                 return "", []
@@ -933,7 +943,9 @@ class EmojiManager:
         你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
         请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
        """
-        emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7)
+        emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
+            emotion_prompt, temperature=0.7, max_tokens=600
+        )
 
         # 处理情感列表
         emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py
index 1870c470a..a98085038 100644
--- a/src/chat/express/expression_learner.py
+++ b/src/chat/express/expression_learner.py
@@ -7,12 +7,12 @@ from datetime import datetime
 from typing import List, Dict, Optional, Any, Tuple
 
 from src.common.logger import get_logger
+from src.common.database.database_model import Expression
 from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import model_config
 from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from src.chat.message_receive.chat_stream import get_chat_manager
-from src.common.database.database_model import Expression
 
 MAX_EXPRESSION_COUNT = 300
@@ -80,11 +80,8 @@ def init_prompt() -> None:
 
 class ExpressionLearner:
     def __init__(self) -> None:
-        # TODO: API-Adapter修改标记
         self.express_learn_model: LLMRequest = LLMRequest(
-            model=global_config.model.replyer_1,
-            temperature=0.3,
-            request_type="expressor.learner",
+            model_set=model_config.model_task_config.replyer_1, request_type="expressor.learner"
         )
         self.llm_model = None
         self._ensure_expression_directories()
@@ -101,7 +98,7 @@ class ExpressionLearner:
             os.path.join(base_dir, "learnt_style"),
             os.path.join(base_dir, "learnt_grammar"),
         ]
-        
+
         for directory in directories_to_create:
             try:
                 os.makedirs(directory, exist_ok=True)
@@ -116,7 +113,7 @@ class ExpressionLearner:
        """
         base_dir = os.path.join("data", "expression")
         done_flag = os.path.join(base_dir, "done.done")
-        
+
         # 确保基础目录存在
         try:
             os.makedirs(base_dir, exist_ok=True)
@@ -124,28 +121,28 @@ class ExpressionLearner:
         except Exception as e:
             logger.error(f"创建表达方式目录失败: {e}")
             return
-        
+
         if os.path.exists(done_flag):
             logger.info("表达方式JSON已迁移,无需重复迁移。")
             return
-        
+
         logger.info("开始迁移表达方式JSON到数据库...")
         migrated_count = 0
-        
+
         for type in ["learnt_style", "learnt_grammar"]:
             type_str = "style" if type == "learnt_style" else "grammar"
             type_dir = os.path.join(base_dir, type)
             if not os.path.exists(type_dir):
                 logger.debug(f"目录不存在,跳过: {type_dir}")
                 continue
-            
+
             try:
                 chat_ids = os.listdir(type_dir)
                 logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
             except Exception as e:
                 logger.error(f"读取目录失败 {type_dir}: {e}")
                 continue
-            
+
             for chat_id in chat_ids:
                 expr_file = os.path.join(type_dir, chat_id, "expressions.json")
                 if not os.path.exists(expr_file):
@@ -153,24 +150,24 @@ class ExpressionLearner:
                 try:
                     with open(expr_file, "r", encoding="utf-8") as f:
                         expressions = json.load(f)
-                    
+
                     if not isinstance(expressions, list):
                         logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
                         continue
-                    
+
                     for expr in expressions:
                         if not isinstance(expr, dict):
                             continue
-                        
+
                         situation = expr.get("situation")
                         style_val = expr.get("style")
                         count = expr.get("count", 1)
                         last_active_time = expr.get("last_active_time", time.time())
-                        
+
                         if not situation or not style_val:
                             logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
                             continue
-                        
+
                         # 查重:同chat_id+type+situation+style
                         from src.common.database.database_model import Expression
 
@@ -201,7 +198,7 @@ class ExpressionLearner:
                 logger.error(f"JSON解析失败 {expr_file}: {e}")
             except Exception as e:
                 logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
-        
+
         # 标记迁移完成
         try:
             # 确保done.done文件的父目录存在
@@ -209,7 +206,7 @@ class ExpressionLearner:
             if not os.path.exists(done_parent_dir):
                 os.makedirs(done_parent_dir, exist_ok=True)
                 logger.debug(f"为done.done创建父目录: {done_parent_dir}")
-            
+
             with open(done_flag, "w", encoding="utf-8") as f:
                 f.write("done\n")
             logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
@@ -229,13 +226,13 @@ class ExpressionLearner:
             # 查找所有create_date为空的表达方式
             old_expressions = Expression.select().where(Expression.create_date.is_null())
             updated_count = 0
-            
+
             for expr in old_expressions:
                 # 使用last_active_time作为create_date
                 expr.create_date = expr.last_active_time
                 expr.save()
                 updated_count += 1
-            
+
             if updated_count > 0:
                 logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
         except Exception as e:
@@ -287,25 +284,29 @@ class ExpressionLearner:
         获取指定chat_id的表达方式创建信息,按创建日期排序
        """
         try:
-            expressions = (Expression.select()
-                          .where(Expression.chat_id == chat_id)
-                          .order_by(Expression.create_date.desc())
-                          .limit(limit))
-            
+            expressions = (
+                Expression.select()
+                .where(Expression.chat_id == chat_id)
+                .order_by(Expression.create_date.desc())
+                .limit(limit)
+            )
+
             result = []
             for expr in expressions:
                 create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
-                result.append({
-                    "situation": expr.situation,
-                    "style": expr.style,
-                    "type": expr.type,
-                    "count": expr.count,
-                    "create_date": create_date,
-                    "create_date_formatted": format_create_date(create_date),
-                    "last_active_time": expr.last_active_time,
-                    "last_active_formatted": format_create_date(expr.last_active_time),
-                })
-            
+                result.append(
+                    {
+                        "situation": expr.situation,
+                        "style": expr.style,
+                        "type": expr.type,
+                        "count": expr.count,
+                        "create_date": create_date,
+                        "create_date_formatted": format_create_date(create_date),
+                        "last_active_time": expr.last_active_time,
+                        "last_active_formatted": format_create_date(expr.last_active_time),
+                    }
+                )
+
             return result
         except Exception as e:
             logger.error(f"获取表达方式创建信息失败: {e}")
@@ -355,19 +356,19 @@ class ExpressionLearner:
         try:
             # 获取所有表达方式
             all_expressions = Expression.select()
-            
+
             updated_count = 0
             deleted_count = 0
-            
+
             for expr in all_expressions:
                 # 计算时间差
                 last_active = expr.last_active_time
                 time_diff_days = (current_time - last_active) / (24 * 3600)  # 转换为天
-                
+
                 # 计算衰减值
                 decay_value = self.calculate_decay_factor(time_diff_days)
                 new_count = max(0.01, expr.count - decay_value)
-                
+
                 if new_count <= 0.01:
                     # 如果count太小,删除这个表达方式
                     expr.delete_instance()
@@ -377,10 +378,10 @@ class ExpressionLearner:
                     expr.count = new_count
                     expr.save()
                     updated_count += 1
-            
+
             if updated_count > 0 or deleted_count > 0:
                 logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
-            
+
         except Exception as e:
             logger.error(f"数据库全局衰减失败: {e}")
 
@@ -527,7 +528,7 @@ class ExpressionLearner:
         logger.debug(f"学习{type_str}的prompt: {prompt}")
 
         try:
-            response, _ = await self.express_learn_model.generate_response_async(prompt)
+            response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
         except Exception as e:
             logger.error(f"学习{type_str}失败: {e}")
             return None
diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py
index 910b43c24..111225c83 100644
--- a/src/chat/express/expression_selector.py
+++ b/src/chat/express/expression_selector.py
@@ -1,16 +1,17 @@
 import json
 import time
 import random
+import hashlib
 from typing import List, Dict, Tuple, Optional, Any
 from json_repair import repair_json
 
 from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import global_config, model_config
 from src.common.logger import get_logger
+from src.common.database.database_model import Expression
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from .expression_learner import get_expression_learner
-from src.common.database.database_model import Expression
 
 logger = get_logger("expression_selector")
@@ -75,10 +76,8 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis
 class ExpressionSelector:
     def __init__(self):
         self.expression_learner = get_expression_learner()
-        # TODO: API-Adapter修改标记
         self.llm_model = LLMRequest(
-            model=global_config.model.utils_small,
-            request_type="expression.selector",
+            model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
         )
 
     @staticmethod
@@ -92,7 +91,6 @@ class ExpressionSelector:
             id_str = parts[1]
             stream_type = parts[2]
             is_group = stream_type == "group"
-            import hashlib
 
             if is_group:
                 components = [platform, str(id_str)]
             else:
@@ -108,8 +106,7 @@ class ExpressionSelector:
         for group in groups:
             group_chat_ids = []
             for stream_config_str in group:
-                chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str)
-                if chat_id_candidate:
+                if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                     group_chat_ids.append(chat_id_candidate)
             if chat_id in group_chat_ids:
                 return group_chat_ids
@@ -118,9 +115,10 @@ class ExpressionSelector:
     def get_random_expressions(
         self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
     ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+        # sourcery skip: extract-duplicate-method, move-assign
         # 支持多chat_id合并抽选
         related_chat_ids = self.get_related_chat_ids(chat_id)
-        
+
         # 优化:一次性查询所有相关chat_id的表达方式
         style_query = Expression.select().where(
             (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
         )
         grammar_query = Expression.select().where(
             (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
         )
-        
+
         style_exprs = [
             {
                 "situation": expr.situation,
                 "style": expr.style,
                 "count": expr.count,
                 "last_active_time": expr.last_active_time,
                 "source_id": expr.chat_id,
                 "type": "style",
                 "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
-            } for expr in style_query
+            }
+            for expr in style_query
         ]
-        
+
         grammar_exprs = [
             {
                 "situation": expr.situation,
                 "style": expr.style,
                 "count": expr.count,
                 "last_active_time": expr.last_active_time,
                 "source_id": expr.chat_id,
                 "type": "grammar",
                 "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
-            } for expr in grammar_query
+            }
+            for expr in grammar_query
         ]
-        
+
         style_num = int(total_num * style_percentage)
         grammar_num = int(total_num * grammar_percentage)
         # 按权重抽样(使用count作为权重)
@@ -174,22 +174,22 @@ class ExpressionSelector:
             return
         updates_by_key = {}
         for expr in expressions_to_update:
-            source_id = expr.get("source_id")
-            expr_type = expr.get("type", "style")
-            situation = expr.get("situation")
-            style = expr.get("style")
+            source_id: str = expr.get("source_id")  # type: ignore
+            expr_type: str = expr.get("type", "style")
+            situation: str = expr.get("situation")  # type: ignore
+            style: str = expr.get("style")  # type: ignore
             if not source_id or not situation or not style:
                 logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
                 continue
             key = (source_id, expr_type, situation, style)
             if key not in updates_by_key:
                 updates_by_key[key] = expr
-        for (chat_id, expr_type, situation, style), _expr in updates_by_key.items():
+        for chat_id, expr_type, situation, style in updates_by_key:
             query = Expression.select().where(
-                (Expression.chat_id == chat_id) &
-                (Expression.type == expr_type) &
-                (Expression.situation == situation) &
-                (Expression.style == style)
+                (Expression.chat_id == chat_id)
+                & (Expression.type == expr_type)
+                & (Expression.situation == situation)
+                & (Expression.style == style)
             )
             if query.exists():
                 expr_obj = query.get()
@@ -264,7 +264,7 @@ class ExpressionSelector:
 
         # 4. 调用LLM
         try:
-            content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
+            content, _ = await self.llm_model.generate_response_async(prompt=prompt)
 
             # logger.info(f"{self.log_prefix} LLM返回结果: {content}")
diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py
index 26660e5c3..af1723047 100644
--- a/src/chat/memory_system/Hippocampus.py
+++ b/src/chat/memory_system/Hippocampus.py
@@ -5,25 +5,27 @@ import random
 import time
 import re
 import json
-from itertools import combinations
-
 import jieba
 import networkx as nx
 import numpy as np
+
+from itertools import combinations
+from typing import List, Tuple, Coroutine, Any, Dict, Set
 from collections import Counter
-from ...llm_models.utils_model import LLMRequest
+from rich.traceback import install
+
+from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config, model_config
+from src.common.database.database_model import Messages, GraphNodes, GraphEdges  # Peewee Models导入
 from src.common.logger import get_logger
 from src.chat.memory_system.sample_distribution import MemoryBuildScheduler  # 分布生成器
-from ..utils.chat_message_builder import (
+from src.chat.utils.chat_message_builder import (
     get_raw_msg_by_timestamp,
     build_readable_messages,
     get_raw_msg_by_timestamp_with_chat,
 )  # 导入 build_readable_messages
-from ..utils.utils import translate_timestamp_to_human_readable
-from rich.traceback import install
+from src.chat.utils.utils import translate_timestamp_to_human_readable
 
-from ...config.config import global_config
-from src.common.database.database_model import Messages, GraphNodes, GraphEdges  # Peewee Models导入
 
 install(extra_lines=3)
@@ -198,8 +200,7 @@ class Hippocampus:
         self.parahippocampal_gyrus = ParahippocampalGyrus(self)
         # 从数据库加载记忆图
         self.entorhinal_cortex.sync_memory_from_db()
-        # TODO: API-Adapter修改标记
-        self.model_summary = LLMRequest(global_config.model.memory, request_type="memory.builder")
+        self.model_summary = LLMRequest(model_set=model_config.model_task_config.memory, request_type="memory.builder")
 
     def get_all_node_names(self) -> list:
        """获取记忆图中所有节点的名字列表"""
@@ -339,9 +340,7 @@ class Hippocampus:
         else:
             topic_num = 5  # 51+字符: 5个关键词 (其余长文本)
 
-        topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
-            self.find_topic_llm(text, topic_num)
-        )
+        topics_response, _ = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
 
         # 提取关键词
         keywords = re.findall(r"<([^>]+)>", topics_response)
@@ -353,12 +352,11 @@ class Hippocampus:
             for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
             if keyword.strip()
         ]
-        
+
         if keywords:
             logger.info(f"提取关键词: {keywords}")
-        
-        return keywords
-        
+
+        return keywords
 
     async def get_memory_from_text(
         self,
@@ -1245,7 +1243,7 @@ class ParahippocampalGyrus:
 
         # 2. 使用LLM提取关键主题
         topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
-        topics_response, (reasoning_content, model_name) = await self.hippocampus.model_summary.generate_response_async(
+        topics_response, _ = await self.hippocampus.model_summary.generate_response_async(
             self.hippocampus.find_topic_llm(input_text, topic_num)
         )
 
@@ -1269,7 +1267,7 @@ class ParahippocampalGyrus:
         logger.debug(f"过滤后话题: {filtered_topics}")
 
         # 4. 创建所有话题的摘要生成任务
-        tasks = []
+        tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List[Dict[str, Any]] | None]]]]] = []
         for topic in filtered_topics:
             # 调用修改后的 topic_what,不再需要 time_info
             topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
@@ -1281,7 +1279,7 @@ class ParahippocampalGyrus:
                 continue
 
         # 等待所有任务完成
-        compressed_memory = set()
+        compressed_memory: Set[Tuple[str, str]] = set()
         similar_topics_dict = {}
 
         for topic, task in tasks:
diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py
index f7e54f8e9..a702a87ed 100644
--- a/src/chat/memory_system/instant_memory.py
+++ b/src/chat/memory_system/instant_memory.py
@@ -3,13 +3,16 @@ import time
 import re
 import json
 import ast
-from json_repair import repair_json
-from src.llm_models.utils_model import LLMRequest
-from src.common.logger import get_logger
 import traceback
-from src.config.config import global_config
 
+from json_repair import repair_json
+from datetime import datetime, timedelta
+
+from src.llm_models.utils_model import LLMRequest
+from src.common.logger import get_logger
 from src.common.database.database_model import Memory  # Peewee Models导入
+from src.config.config import model_config
+
 
 logger = get_logger(__name__)
@@ -35,8 +38,7 @@ class InstantMemory:
         self.chat_id = chat_id
         self.last_view_time = time.time()
         self.summary_model = LLMRequest(
-            model=global_config.model.memory,
-            temperature=0.5,
+            model_set=model_config.model_task_config.memory,
             request_type="memory.summary",
         )
@@ -48,14 +50,11 @@ class InstantMemory:
        """
 
         try:
-            response, _ = await self.summary_model.generate_response_async(prompt)
+            response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
             print(prompt)
             print(response)
 
-            if "1" in response:
-                return True
-            else:
-                return False
+            return "1" in response
         except Exception as e:
             logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
             return False
@@ -71,9 +70,9 @@ class InstantMemory:
         }}
        """
         try:
-            response, _ = await self.summary_model.generate_response_async(prompt)
-            print(prompt)
-            print(response)
+            response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
+            # print(prompt)
+            # print(response)
             if not response:
                 return None
             try:
@@ -142,7 +141,7 @@ class InstantMemory:
         请只输出json格式,不要输出其他多余内容
        """
         try:
-            response, _ = await self.summary_model.generate_response_async(prompt)
+            response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
             print(prompt)
             print(response)
             if not response:
@@ -177,7 +176,7 @@ class InstantMemory:
 
         for mem in query:
             # 对每条记忆
-            mem_keywords = mem.keywords or []
+            mem_keywords = mem.keywords or ""
             parsed = ast.literal_eval(mem_keywords)
             if isinstance(parsed, list):
                 mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
@@ -201,6 +200,7 @@ class InstantMemory:
         return None
 
     def _parse_time_range(self, time_str):
+        # sourcery skip: extract-duplicate-method, use-contextlib-suppress
        """
         支持解析如下格式:
         - 具体日期时间:YYYY-MM-DD HH:MM:SS
@@ -208,8 +208,6 @@ class InstantMemory:
         - 相对时间:今天,昨天,前天,N天前,N个月前
         - 空字符串:返回(None, None)
        """
-        from datetime import datetime, timedelta
-
         now = datetime.now()
         if not time_str:
             return 0, now
@@ -239,14 +237,12 @@ class InstantMemory:
             start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
             end = start + timedelta(days=1)
             return start, end
-        m = re.match(r"(\d+)天前", time_str)
-        if m:
+        if m := re.match(r"(\d+)天前", time_str):
             days = int(m.group(1))
             start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
             end = start + timedelta(days=1)
             return start, end
-        m = re.match(r"(\d+)个月前", time_str)
-        if m:
+        if m := re.match(r"(\d+)个月前", time_str):
             months = int(m.group(1))
             # 近似每月30天
             start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
diff --git a/src/chat/memory_system/memory_activator.py b/src/chat/memory_system/memory_activator.py
index 715d9c067..d3cbb5d75 100644
--- a/src/chat/memory_system/memory_activator.py
+++ b/src/chat/memory_system/memory_activator.py
@@ -1,13 +1,15 @@
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
-from src.common.logger import get_logger
-from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
-from datetime import datetime
-from src.chat.memory_system.Hippocampus import hippocampus_manager
-from typing import List, Dict
 import difflib
 import json
+
 from json_repair import repair_json
+from typing import List, Dict
+from datetime import datetime
+
+from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config, model_config
+from src.common.logger import get_logger
+from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
+from src.chat.memory_system.Hippocampus import hippocampus_manager
 
 logger = get_logger("memory_activator")
@@ -61,11 +63,8 @@ def init_prompt():
 
 class MemoryActivator:
     def __init__(self):
-        # TODO: API-Adapter修改标记
-
         self.key_words_model = LLMRequest(
-            model=global_config.model.utils_small,
-            temperature=0.5,
+            model_set=model_config.model_task_config.utils_small,
             request_type="memory.activator",
         )
@@ -92,7 +91,9 @@ class MemoryActivator:
 
         # logger.debug(f"prompt: {prompt}")
 
-        response, (reasoning_content, model_name) = await self.key_words_model.generate_response_async(prompt)
+        response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
+            prompt, temperature=0.5
+        )
 
         keywords = list(get_keywords_from_json(response))
diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py
index 56ccd33d0..58dd6d689 100644
--- a/src/chat/message_receive/message.py
+++ b/src/chat/message_receive/message.py
@@ -203,7 +203,7 @@ class MessageRecvS4U(MessageRecv):
         self.is_superchat = False
         self.gift_info = None
         self.gift_name = None
-        self.gift_count = None
+        self.gift_count: Optional[str] = None
         self.superchat_info = None
         self.superchat_price = None
         self.superchat_message_text = None
diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py
index 21d47c75d..267b7a8ff 100644
--- a/src/chat/planner_actions/action_manager.py
+++ b/src/chat/planner_actions/action_manager.py
@@ -1,9 +1,10 @@
 from typing import Dict, Optional, Type
-from src.plugin_system.base.base_action import BaseAction
+
 from src.chat.message_receive.chat_stream import ChatStream
 from src.common.logger import get_logger
 from src.plugin_system.core.component_registry import component_registry
 from src.plugin_system.base.component_types import ComponentType, ActionInfo
+from src.plugin_system.base.base_action import BaseAction
 
 logger = get_logger("action_manager")
diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py
index da11c54f6..dfa4c79c1 100644
--- a/src/chat/planner_actions/action_modifier.py
+++ b/src/chat/planner_actions/action_modifier.py
@@ -5,7 +5,7 @@ import time
 from typing import List, Any, Dict, TYPE_CHECKING, Tuple
 
 from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
 from src.chat.planner_actions.action_manager import ActionManager
@@ -36,10 +36,7 @@ class ActionModifier:
         self.action_manager = action_manager
 
         # 用于LLM判定的小模型
-        self.llm_judge = LLMRequest(
-            model=global_config.model.utils_small,
-            request_type="action.judge",
-        )
+        self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge")
 
         # 缓存相关属性
         self._llm_judge_cache = {}  # 缓存LLM判定结果
@@ -438,4 +435,4 @@ class ActionModifier:
                 return True
             else:
                 logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
-                return False
\ No newline at end of file
+                return False
diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py
index 0b26a97d0..04e17ad6e 100644
--- a/src/chat/planner_actions/planner.py
+++ b/src/chat/planner_actions/planner.py
@@ -7,7 +7,7 @@ from datetime import datetime
 from json_repair import repair_json
 
 from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import global_config, model_config
 from src.common.logger import get_logger
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from src.chat.utils.chat_message_builder import (
@@ -73,10 +73,7 @@ class ActionPlanner:
         self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
         self.action_manager = action_manager
         # LLM规划器配置
-        self.planner_llm = LLMRequest(
-            model=global_config.model.planner,
-            request_type="planner",  # 用于动作规划
-        )
+        self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner")  # 用于动作规划
 
         self.last_obs_time_mark = 0.0
@@ -140,7 +137,7 @@ class ActionPlanner:
         # --- 调用 LLM (普通文本生成) ---
         llm_content = None
         try:
-            llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
+            llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
 
             if global_config.debug.show_prompt:
                 logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py
index dd691e484..9aacb1ae1 100644
--- a/src/chat/replyer/default_generator.py
+++ b/src/chat/replyer/default_generator.py
@@ -8,7 +8,8 @@ from typing import List, Optional, Dict, Any, Tuple
 from datetime import datetime
 from src.mais4u.mai_think import mai_thinking_manager
 from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
+from src.config.api_ada_configs import TaskConfig
 from src.individuality.individuality import get_individuality
 from src.llm_models.utils_model import LLMRequest
 from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
@@ -106,31 +107,36 @@ class DefaultReplyer:
     def __init__(
         self,
         chat_stream: ChatStream,
-        model_configs: Optional[List[Dict[str, Any]]] = None,
+        model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
         request_type: str = "focus.replyer",
     ):
         self.request_type = request_type
-        if model_configs:
-            self.express_model_configs = model_configs
+        if model_set_with_weight:
+            # self.express_model_configs = model_configs
+            self.model_set: List[Tuple[TaskConfig, float]] = model_set_with_weight
         else:
             # 当未提供配置时,使用默认配置并赋予默认权重
-            model_config_1 = global_config.model.replyer_1.copy()
-            model_config_2 = global_config.model.replyer_2.copy()
+            # model_config_1 = global_config.model.replyer_1.copy()
+            # model_config_2 = global_config.model.replyer_2.copy()
             prob_first = global_config.chat.replyer_random_probability
-            model_config_1["weight"] = prob_first
-            model_config_2["weight"] = 1.0 - prob_first
+            # model_config_1["weight"] = prob_first
+            # model_config_2["weight"] = 1.0 - prob_first
 
-            self.express_model_configs = [model_config_1, model_config_2]
+            # self.express_model_configs = [model_config_1, model_config_2]
+            self.model_set = [
+                (model_config.model_task_config.replyer_1, prob_first),
+                (model_config.model_task_config.replyer_2, 1.0 - prob_first),
+            ]
 
-        if not self.express_model_configs:
-            logger.warning("未找到有效的模型配置,回复生成可能会失败。")
-            # 提供一个最终的回退,以防止在空列表上调用 random.choice
-            fallback_config = global_config.model.replyer_1.copy()
-            fallback_config.setdefault("weight", 1.0)
-            self.express_model_configs = [fallback_config]
+        # if not self.express_model_configs:
+        #     logger.warning("未找到有效的模型配置,回复生成可能会失败。")
+        #     # 提供一个最终的回退,以防止在空列表上调用 random.choice
+        #     fallback_config = global_config.model.replyer_1.copy()
+        #     fallback_config.setdefault("weight", 1.0)
+        #     self.express_model_configs = [fallback_config]
 
         self.chat_stream = chat_stream
         self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
@@ -139,14 +145,15 @@ class DefaultReplyer:
         self.memory_activator = MemoryActivator()
         self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id)
 
-        from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
+        from src.plugin_system.core.tool_use import ToolExecutor  # 延迟导入ToolExecutor,不然会循环依赖
+
         self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
 
-    def _select_weighted_model_config(self) -> Dict[str, Any]:
+    def _select_weighted_models_config(self) -> Tuple[TaskConfig, float]:
        """使用加权随机选择来挑选一个模型配置"""
-        configs = self.express_model_configs
+        configs = self.model_set
         # 提取权重,如果模型配置中没有'weight'键,则默认为1.0
-        weights = [config.get("weight", 1.0) for config in configs]
+        weights = [weight for _, weight in configs]
 
         return random.choices(population=configs, weights=weights, k=1)[0]
@@ -188,12 +195,11 @@ class DefaultReplyer:
 
         # 4. 调用 LLM 生成回复
         content = None
-        # TODO: 复活这里
-        # reasoning_content = None
-        # model_name = "unknown_model"
+        reasoning_content = None
+        model_name = "unknown_model"
 
         try:
-            content = await self.llm_generate_content(prompt)
+            content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
             logger.debug(f"replyer生成内容: {content}")
 
         except Exception as llm_e:
@@ -236,15 +242,14 @@ class DefaultReplyer:
             )
 
         content = None
-        # TODO: 复活这里
-        # reasoning_content = None
-        # model_name = "unknown_model"
+        reasoning_content = None
+        model_name = "unknown_model"
 
         if not prompt:
             logger.error("Prompt 构建失败,无法生成回复。")
             return False, None, None
 
         try:
-            content = await self.llm_generate_content(prompt)
+            content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
             logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
 
         except Exception as llm_e:
@@ -843,7 +848,7 @@ class DefaultReplyer:
         raw_reply: str,
         reason: str,
         reply_to: str,
-    ) -> str:
+    ) -> str:  # sourcery skip: remove-redundant-if
         chat_stream = self.chat_stream
         chat_id = chat_stream.stream_id
         is_group_chat = bool(chat_stream.group_info)
@@ -977,30 +982,23 @@ class DefaultReplyer:
             display_message=display_message,
         )
 
-    async def llm_generate_content(self, prompt: str) -> str:
+    async def llm_generate_content(self, prompt: str):
         with Timer("LLM生成", {}):  # 内部计时器,可选保留
             # 加权随机选择一个模型配置
-            selected_model_config = self._select_weighted_model_config()
-            model_display_name = selected_model_config.get('model_name') or selected_model_config.get('name', 'N/A')
-            logger.info(
-                f"使用模型生成回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
-            )
+            selected_model_config, weight = self._select_weighted_models_config()
+            logger.info(f"使用模型集生成回复: {selected_model_config} (选中概率: {weight})")
 
-            express_model = LLMRequest(
-                model=selected_model_config,
-                request_type=self.request_type,
-            )
+            express_model = LLMRequest(model_set=selected_model_config, request_type=self.request_type)
 
             if global_config.debug.show_prompt:
                 logger.info(f"\n{prompt}\n")
             else:
                 logger.debug(f"\n{prompt}\n")
 
-            # TODO: 这里的_应该做出替换
-            content, _ = await express_model.generate_response_async(prompt)
+            content, (reasoning_content, model_name, tool_calls) = await express_model.generate_response_async(prompt)
 
             logger.debug(f"replyer生成内容: {content}")
 
-        return content
+        return content, reasoning_content, model_name, tool_calls
 
 
 def weighted_sample_no_replacement(items, weights, k) -> list:
diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py
index 3f1c731b4..bb3a313b7 100644
--- a/src/chat/replyer/replyer_manager.py
+++ b/src/chat/replyer/replyer_manager.py
@@ -1,6 +1,7 @@
-from typing import Dict, Any, Optional, List
+from typing import Dict, Optional, List, Tuple
 
 from src.common.logger import get_logger
+from src.config.api_ada_configs import TaskConfig
 from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
 from src.chat.replyer.default_generator import DefaultReplyer
@@ -15,7 +16,7 @@ class ReplyerManager:
         self,
         chat_stream: Optional[ChatStream] = None,
         chat_id: Optional[str] = None,
-        model_configs: Optional[List[Dict[str, Any]]] = None,
+        model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
         request_type: str = "replyer",
     ) -> Optional[DefaultReplyer]:
        """
@@ -49,7 +50,7 @@ class ReplyerManager:
             # model_configs 只在此时(初始化时)生效
             replyer = DefaultReplyer(
                 chat_stream=target_stream,
-                model_configs=model_configs,  # 可以是None,此时使用默认模型
+                model_set_with_weight=model_set_with_weight,  # 可以是None,此时使用默认模型
                 request_type=request_type,
             )
             self._repliers[stream_id] = replyer
diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py
index 3ee4ae7b1..0b9ec7798 100644
--- a/src/chat/utils/utils.py
+++ b/src/chat/utils/utils.py
@@ -11,7 +11,7 @@ from typing import Optional, Tuple, Dict, List, Any
 
 from src.common.logger import get_logger
 from src.common.message_repository import find_messages, count_messages
-from src.config.config import global_config
+from src.config.config import global_config, model_config
 from src.chat.message_receive.message import MessageRecv
 from src.chat.message_receive.chat_stream import get_chat_manager
 from src.llm_models.utils_model import LLMRequest
@@ -109,13 +109,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
     return is_mentioned, reply_probability
 
 
-async def get_embedding(text, request_type="embedding"):
+async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
    """获取文本的embedding向量"""
-    # TODO: API-Adapter修改标记
-    llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
-    # return llm.get_embedding_sync(text)
+    llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
     try:
-        embedding = await llm.get_embedding(text)
+        embedding, _ = await llm.get_embedding(text)
     except Exception as e:
         logger.error(f"获取embedding失败: {str(e)}")
         embedding = None
diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py
index 7f14aa6d4..fcf1c717c 100644
--- a/src/chat/utils/utils_image.py
+++ b/src/chat/utils/utils_image.py
@@ -14,7 +14,7 @@ from rich.traceback import install
 from src.common.logger import get_logger
 from src.common.database.database import db
 from src.common.database.database_model import Images, ImageDescriptions
-from src.config.config import global_config
+from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 
 install(extra_lines=3)
@@ -37,7 +37,7 @@ class ImageManager:
             self._ensure_image_dir()
 
             self._initialized = True
-            self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
+            self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
 
             try:
                 db.connect(reuse_if_open=True)
@@ -107,6 +107,7 @@ class ImageManager:
         # 优先使用EmojiManager查询已注册表情包的描述
         try:
             from src.chat.emoji_system.emoji_manager import get_emoji_manager
+
             emoji_manager = get_emoji_manager()
             cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
             if cached_emoji_description:
@@ -116,13 +117,12 @@ class ImageManager:
             logger.debug(f"查询EmojiManager时出错: {e}")
 
         # 查询ImageDescriptions表的缓存描述
-        cached_description = self._get_description_from_db(image_hash, "emoji")
-        if cached_description:
+        if cached_description := self._get_description_from_db(image_hash, "emoji"):
             logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
             return f"[表情包:{cached_description}]"
 
         # === 二步走识别流程 ===
-        
+
         # 第一步:VLM视觉分析 - 生成详细描述
         if image_format in ["gif", "GIF"]:
             image_base64_processed = self.transform_gif(image_base64)
@@ -130,10 +130,16 @@ class ImageManager:
                 logger.warning("GIF转换失败,无法获取描述")
                 return "[表情包(GIF处理失败)]"
             vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
-            detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
+            detailed_description, _ = await self.vlm.generate_response_for_image(
+                vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
+            )
         else:
-            vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
-            detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format)
+            vlm_prompt = (
+                "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
+            )
+            detailed_description, _ = await self.vlm.generate_response_for_image(
+                vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
+            )
 
         if detailed_description is None:
             logger.warning("VLM未能生成表情包详细描述")
@@ -150,31 +156,32 @@ class ImageManager:
         3. 输出简短精准,不要解释
         4. 如果有多个词用逗号分隔
        """
-        
+
         # 使用较低温度确保输出稳定
-        emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji")
-        emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt)
+        emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
+        emotion_result, _ = await emotion_llm.generate_response_async(
+            emotion_prompt, temperature=0.3, max_tokens=50
+        )
 
         if emotion_result is None:
             logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
             # 降级处理:从详细描述中提取关键词
             import jieba
+
             words = list(jieba.cut(detailed_description))
             emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
 
         # 处理情感结果,取前1-2个最重要的标签
         emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
         final_emotion = emotions[0] if emotions else "表情"
-        
+
         # 如果有第二个情感且不重复,也包含进来
         if len(emotions) > 1 and emotions[1] != emotions[0]:
             final_emotion = f"{emotions[0]},{emotions[1]}"
 
         logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
 
-        # 再次检查缓存,防止并发写入时重复生成
-        cached_description = self._get_description_from_db(image_hash, "emoji")
-        if cached_description:
+        if cached_description := self._get_description_from_db(image_hash, "emoji"):
             logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
             return f"[表情包:{cached_description}]"
@@ -242,9 +249,7 @@ class ImageManager:
                 logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
                 return f"[图片:{existing_image.description}]"
 
-        # 查询ImageDescriptions表的缓存描述
-        cached_description = self._get_description_from_db(image_hash, "image")
-        if cached_description:
+        if cached_description := self._get_description_from_db(image_hash, "image"):
             logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
             return f"[图片:{cached_description}]"
 
@@ -252,7 +257,9 @@ class ImageManager:
         image_format = Image.open(io.BytesIO(image_bytes)).format.lower()  # type: ignore
         prompt = global_config.custom_prompt.image_prompt
         logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
-        description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
+        description, _ = await self.vlm.generate_response_for_image(
+            prompt, image_base64, image_format, temperature=0.4, max_tokens=300
+        )
 
         if description is None:
             logger.warning("AI未能生成图片描述")
@@ -445,10 +452,7 @@ class ImageManager:
             image_bytes = base64.b64decode(image_base64)
             image_hash = hashlib.md5(image_bytes).hexdigest()
 
-            # 检查图片是否已存在
-            existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
-
-            if existing_image:
+            if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
                 # 检查是否缺少必要字段,如果缺少则创建新记录
                 if (
                     not hasattr(existing_image, "image_id")
@@ -524,9 +528,7 @@ class ImageManager:
 
             # 优先检查是否已有其他相同哈希的图片记录包含描述
             existing_with_description = Images.get_or_none(
-                (Images.emoji_hash == image_hash) &
-                (Images.description.is_null(False)) &
-                (Images.description != "")
(Images.description != "") + (Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "") ) if existing_with_description and existing_with_description.id != image.id: logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...") @@ -538,8 +540,7 @@ class ImageManager: return # 检查ImageDescriptions表的缓存描述 - cached_description = self._get_description_from_db(image_hash, "image") - if cached_description: + if cached_description := self._get_description_from_db(image_hash, "image"): logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...") image.description = cached_description image.vlm_processed = True @@ -554,15 +555,15 @@ class ImageManager: # 获取VLM描述 logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)") - description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) + description, _ = await self.vlm.generate_response_for_image( + prompt, image_base64, image_format, temperature=0.4, max_tokens=300 + ) if description is None: logger.warning("VLM未能生成图片描述") description = "无法生成描述" - # 再次检查缓存,防止并发写入时重复生成 - cached_description = self._get_description_from_db(image_hash, "image") - if cached_description: + if cached_description := self._get_description_from_db(image_hash, "image"): logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}") description = cached_description @@ -606,7 +607,7 @@ def image_path_to_base64(image_path: str) -> str: raise FileNotFoundError(f"图片文件不存在: {image_path}") with open(image_path, "rb") as f: - image_data = f.read() - if not image_data: + if image_data := f.read(): + return base64.b64encode(image_data).decode("utf-8") + else: raise IOError(f"读取图片文件失败: {image_path}") - return base64.b64encode(image_data).decode("utf-8") diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py index cf71dc56f..baff40916 100644 --- a/src/chat/utils/utils_voice.py +++ b/src/chat/utils/utils_voice.py @@ -1,6 +1,6 @@ import base64 -from src.config.config import global_config +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger @@ -20,7 +20,7 @@ async def get_voice_text(voice_base64: str) -> str: if isinstance(voice_base64, str): voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii") voice_bytes = base64.b64decode(voice_base64) - _llm = LLMRequest(model=global_config.model.voice, request_type="voice") + _llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="voice") text = await _llm.generate_response_for_voice(voice_bytes) if text is None: logger.warning("未能生成语音文本") diff --git a/src/chat/willing/mode_mxp.py b/src/chat/willing/mode_mxp.py index 5a13a628a..a249cb6f1 100644 --- a/src/chat/willing/mode_mxp.py +++ b/src/chat/willing/mode_mxp.py @@ -19,13 +19,13 @@ Mxp 模式:梦溪畔独家赞助 下下策是询问一个菜鸟(@梦溪畔) """ -from .willing_manager import BaseWillingManager from typing import Dict import asyncio import time import math from src.chat.message_receive.chat_stream import ChatStream +from .willing_manager import BaseWillingManager class MxpWillingManager(BaseWillingManager): diff --git a/src/config/config.py b/src/config/config.py index 298163b07..645a9f179 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -60,268 +60,6 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template") MMC_VERSION = "0.10.0-snapshot.2" -# def _get_config_version(toml: Dict) -> Version: -# """提取配置文件的 SpecifierSet 版本数据 -# 
-#         toml[dict]: 输入的配置文件字典
-#     Returns:
-#         Version
-#     """
-
-#     if "inner" in toml and "version" in toml["inner"]:
-#         config_version: str = toml["inner"]["version"]
-#     else:
-#         raise InvalidVersion("配置文件缺少版本信息,请检查配置文件。")
-
-#     try:
-#         return version.parse(config_version)
-#     except InvalidVersion as e:
-#         logger.error(
-#             "配置文件中 inner段 的 version 键是错误的版本描述\n"
-#             f"请检查配置文件,当前 version 键: {config_version}\n"
-#             f"错误信息: {e}"
-#         )
-#         raise e
-
-
-# def _request_conf(parent: Dict, config: ModuleConfig):
-#     request_conf_config = parent.get("request_conf")
-#     config.req_conf.max_retry = request_conf_config.get(
-#         "max_retry", config.req_conf.max_retry
-#     )
-#     config.req_conf.timeout = request_conf_config.get(
-#         "timeout", config.req_conf.timeout
-#     )
-#     config.req_conf.retry_interval = request_conf_config.get(
-#         "retry_interval", config.req_conf.retry_interval
-#     )
-#     config.req_conf.default_temperature = request_conf_config.get(
-#         "default_temperature", config.req_conf.default_temperature
-#     )
-#     config.req_conf.default_max_tokens = request_conf_config.get(
-#         "default_max_tokens", config.req_conf.default_max_tokens
-#     )
-
-
-# def _api_providers(parent: Dict, config: ModuleConfig):
-#     api_providers_config = parent.get("api_providers")
-#     for provider in api_providers_config:
-#         name = provider.get("name", None)
-#         base_url = provider.get("base_url", None)
-#         api_key = provider.get("api_key", None)
-#         api_keys = provider.get("api_keys", [])  # 新增:支持多个API Key
-#         client_type = provider.get("client_type", "openai")
-
-#         if name in config.api_providers:  # 查重
-#             logger.error(f"重复的API提供商名称: {name},请检查配置文件。")
-#             raise KeyError(f"重复的API提供商名称: {name},请检查配置文件。")
-
-#         if name and base_url:
-#             # 处理API Key配置:支持单个api_key或多个api_keys
-#             if api_keys:
-#                 # 使用新格式:api_keys列表
-#                 logger.debug(f"API提供商 '{name}' 配置了 {len(api_keys)} 个API Key")
-#             elif api_key:
-#                 # 向后兼容:使用单个api_key
-#                 api_keys = [api_key]
-#                 logger.debug(f"API提供商 '{name}' 使用单个API Key(向后兼容模式)")
-#             else:
-#                 logger.warning(f"API提供商 '{name}' 没有配置API Key,某些功能可能不可用")
-
-#             config.api_providers[name] = APIProvider(
-#                 name=name,
-#                 base_url=base_url,
-#                 api_key=api_key,  # 保留向后兼容
-#                 api_keys=api_keys,  # 新格式
-#                 client_type=client_type,
-#             )
-#         else:
-#             logger.error(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
-#             raise ValueError(f"API提供商 '{name}' 的配置不完整,请检查配置文件。")
-
-
-# def _models(parent: Dict, config: ModuleConfig):
-#     models_config = parent.get("models")
-#     for model in models_config:
-#         model_identifier = model.get("model_identifier", None)
-#         name = model.get("name", model_identifier)
-#         api_provider = model.get("api_provider", None)
-#         price_in = model.get("price_in", 0.0)
-#         price_out = model.get("price_out", 0.0)
-#         force_stream_mode = model.get("force_stream_mode", False)
-#         task_type = model.get("task_type", "")
-#         capabilities = model.get("capabilities", [])
-
-#         if name in config.models:  # 查重
-#             logger.error(f"重复的模型名称: {name},请检查配置文件。")
-#             raise KeyError(f"重复的模型名称: {name},请检查配置文件。")
-
-#         if model_identifier and api_provider:
-#             # 检查API提供商是否存在
-#             if api_provider not in config.api_providers:
-#                 logger.error(f"未声明的API提供商 '{api_provider}' ,请检查配置文件。")
-#                 raise ValueError(
-#                     f"未声明的API提供商 '{api_provider}' ,请检查配置文件。"
-#                 )
-#             config.models[name] = ModelInfo(
-#                 name=name,
-#                 model_identifier=model_identifier,
-#                 api_provider=api_provider,
-#                 price_in=price_in,
-#                 price_out=price_out,
-#                 force_stream_mode=force_stream_mode,
-#                 task_type=task_type,
-#                 capabilities=capabilities,
-#             )
-#         else:
-#             logger.error(f"模型 '{name}' 的配置不完整,请检查配置文件。")
-#             raise ValueError(f"模型 '{name}' 的配置不完整,请检查配置文件。")
-
-
-# def _task_model_usage(parent: Dict, config: ModuleConfig):
-#     model_usage_configs = parent.get("task_model_usage")
-#     config.task_model_arg_map = {}
-#     for task_name, item in model_usage_configs.items():
-#         if task_name in config.task_model_arg_map:
-#             logger.error(f"子任务 {task_name} 已存在,请检查配置文件。")
-#             raise KeyError(f"子任务 {task_name} 已存在,请检查配置文件。")
-
-#         usage = []
-#         if isinstance(item, Dict):
-#             if "model" in item:
-#                 usage.append(
-#                     ModelUsageArgConfigItem(
-#                         name=item["model"],
-#                         temperature=item.get("temperature", None),
-#                         max_tokens=item.get("max_tokens", None),
-#                         max_retry=item.get("max_retry", None),
-#                     )
-#                 )
-#             else:
-#                 logger.error(f"子任务 {task_name} 的模型配置不合法,请检查配置文件。")
-#                 raise ValueError(
-#                     f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
-#                 )
-#         elif isinstance(item, List):
-#             for model in item:
-#                 if isinstance(model, Dict):
-#                     usage.append(
-#                         ModelUsageArgConfigItem(
-#                             name=model["model"],
-#                             temperature=model.get("temperature", None),
-#                             max_tokens=model.get("max_tokens", None),
-#                             max_retry=model.get("max_retry", None),
-#                         )
-#                     )
-#                 elif isinstance(model, str):
-#                     usage.append(
-#                         ModelUsageArgConfigItem(
-#                             name=model,
-#                             temperature=None,
-#                             max_tokens=None,
-#                             max_retry=None,
-#                         )
-#                     )
-#                 else:
-#                     logger.error(
-#                         f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
-#                     )
-#                     raise ValueError(
-#                         f"子任务 {task_name} 的模型配置不合法,请检查配置文件。"
-#                     )
-#         elif isinstance(item, str):
-#             usage.append(
-#                 ModelUsageArgConfigItem(
-#                     name=item,
-#                     temperature=None,
-#                     max_tokens=None,
-#                     max_retry=None,
-#                 )
-#             )
-
-#         config.task_model_arg_map[task_name] = ModelUsageArgConfig(
-#             name=task_name,
-#             usage=usage,
-#         )
-
-
-# def api_ada_load_config(config_path: str) -> ModuleConfig:
-#     """从TOML配置文件加载配置"""
-#     config = ModuleConfig()
-
-#     include_configs: Dict[str, Dict[str, Any]] = {
-#         "request_conf": {
-#             "func": _request_conf,
-#             "support": ">=0.0.0",
-#             "necessary": False,
-#         },
-#         "api_providers": {"func": _api_providers, "support": ">=0.0.0"},
-#         "models": {"func": _models, "support": ">=0.0.0"},
-#         "task_model_usage": {"func": _task_model_usage, "support": ">=0.0.0"},
-#     }
-
-#     if os.path.exists(config_path):
-#         with open(config_path, "rb") as f:
-#             try:
-#                 toml_dict = tomlkit.load(f)
-#             except tomlkit.TOMLDecodeError as e:
-#                 logger.critical(
-#                     f"配置文件model_list.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}"
-#                 )
-#                 exit(1)
-
-#         # 获取配置文件版本
-#         config.INNER_VERSION = _get_config_version(toml_dict)
-
-#         # 检查版本
-#         if config.INNER_VERSION > Version(NEWEST_VER):
-#             logger.warning(
-#                 f"当前配置文件版本 {config.INNER_VERSION} 高于支持的最新版本 {NEWEST_VER},可能导致异常,建议更新依赖。"
-#             )
-
-#         # 解析配置文件
-#         # 如果在配置中找到了需要的项,调用对应项的闭包函数处理
-#         for key in include_configs:
-#             if key in toml_dict:
-#                 group_specifier_set: SpecifierSet = SpecifierSet(
-#                     include_configs[key]["support"]
-#                 )
-
-#                 # 检查配置文件版本是否在支持范围内
-#                 if config.INNER_VERSION in group_specifier_set:
-#                     # 如果版本在支持范围内,检查是否存在通知
-#                     if "notice" in include_configs[key]:
-#                         logger.warning(include_configs[key]["notice"])
-#                     # 调用闭包函数处理配置
-#                     (include_configs[key]["func"])(toml_dict, config)
-#                 else:
-#                     # 如果版本不在支持范围内,崩溃并提示用户
-#                     logger.error(
-#                         f"配置文件中的 '{key}' 字段的版本 ({config.INNER_VERSION}) 不在支持范围内。\n"
-#                         f"当前程序仅支持以下版本范围: {group_specifier_set}"
-#                     )
-#                     raise InvalidVersion(
-#                         f"当前程序仅支持以下版本范围: {group_specifier_set}"
-#                     )
-
-#             # 如果 necessary 项目存在,而且显式声明是 False,进入特殊处理
-#             elif (
-#                 "necessary" in include_configs[key]
-#                 and include_configs[key].get("necessary") is False
-#             ):
-#                 # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
-#                 if key == "keywords_reaction":
-#                     pass
-#                 else:
-#                     # 如果用户根本没有需要的配置项,提示缺少配置
-#                     logger.error(f"配置文件中缺少必需的字段: '{key}'")
-#                     raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
-
-#     logger.info(f"成功加载配置文件: {config_path}")
-
-#     return config
-
-
 def get_key_comment(toml_table, key):
     # 获取key的注释(如果有)
     if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
@@ -626,9 +364,19 @@ class APIAdapterConfig(ConfigBase):
    """API提供商列表"""
 
     def __post_init__(self):
+        # 检查API提供商名称是否重复
+        provider_names = [provider.name for provider in self.api_providers]
+        if len(provider_names) != len(set(provider_names)):
+            raise ValueError("API提供商名称存在重复,请检查配置文件。")
+
+        # 检查模型名称是否重复
+        model_names = [model.name for model in self.models]
+        if len(model_names) != len(set(model_names)):
+            raise ValueError("模型名称存在重复,请检查配置文件。")
+
         self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
         self.models_dict = {model.name: model for model in self.models}
-    
+
     def get_model_info(self, model_name: str) -> ModelInfo:
        """根据模型名称获取模型信息"""
         if not model_name:
@@ -636,7 +384,7 @@ class APIAdapterConfig(ConfigBase):
         if model_name not in self.models_dict:
             raise KeyError(f"模型 '{model_name}' 不存在")
         return self.models_dict[model_name]
-    
+
     def get_provider(self, provider_name: str) -> APIProvider:
        """根据提供商名称获取API提供商信息"""
         if not provider_name:
diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py
index 4c8fcac50..c2655fba7 100644
--- a/src/individuality/individuality.py
+++ b/src/individuality/individuality.py
@@ -4,7 +4,7 @@ import hashlib
 import time
 
 from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.person_info.person_info import get_person_info_manager
 from rich.traceback import install
@@ -23,10 +23,7 @@ class Individuality:
         self.meta_info_file_path = "data/personality/meta.json"
         self.personality_data_file_path = "data/personality/personality_data.json"
 
-        self.model = LLMRequest(
-            model=global_config.model.utils,
-            request_type="individuality.compress",
-        )
+        self.model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="individuality.compress")
 
     async def initialize(self) -> None:
        """初始化个体特征"""
         personality_side = global_config.personality.personality_side
         identity = global_config.personality.identity
-
         person_info_manager = get_person_info_manager()
         self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
         self.name = bot_nickname
@@ -85,16 +81,16 @@ class Individuality:
             bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
         else:
             bot_nickname = ""
-        
+
         # 从文件获取 short_impression
         personality, identity = self._get_personality_from_file()
-        
+
         # 确保short_impression是列表格式且有足够的元素
         if not personality or not identity:
             logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值")
             personality = "友好活泼"
             identity = "人类"
-        
+
         prompt_personality = f"{personality}\n{identity}"
 
         return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
@@ -215,7 +211,7 @@ class Individuality:
 
     def _get_personality_from_file(self) -> tuple[str, str]:
        """从文件获取personality数据
-        
+
         Returns:
             tuple: (personality, identity)
        """
@@ -226,7 +222,7 @@ class Individuality:
 
     def _save_personality_to_file(self, personality: str, identity: str):
        """保存personality数据到文件
-        
+
         Args:
             personality: 压缩后的人格描述
             identity: 压缩后的身份描述
        """
         personality_data = {
             "personality": personality,
             "identity": identity,
             "bot_nickname": self.name,
self.name, - "last_updated": int(time.time()) + "last_updated": int(time.time()), } self._save_personality_data(personality_data) @@ -269,7 +265,7 @@ class Individuality: 2. 尽量简洁,不超过30字 3. 直接输出压缩后的内容,不要解释""" - response, (_, _) = await self.model.generate_response_async( + response, _ = await self.model.generate_response_async( prompt=prompt, ) @@ -281,7 +277,7 @@ class Individuality: # 压缩失败时使用原始内容 if personality_side: personality_parts.append(personality_side) - + if personality_parts: personality_result = "。".join(personality_parts) else: @@ -308,7 +304,7 @@ class Individuality: 2. 尽量简洁,不超过30字 3. 直接输出压缩后的内容,不要解释""" - response, (_, _) = await self.model.generate_response_async( + response, _ = await self.model.generate_response_async( prompt=prompt, ) diff --git a/src/llm_models/model_manager.py b/src/llm_models/model_manager.py deleted file mode 100644 index 2db3a6d25..000000000 --- a/src/llm_models/model_manager.py +++ /dev/null @@ -1,12 +0,0 @@ -import importlib -from typing import Dict - -from src.config.config import model_config -from src.common.logger import get_logger - -from .model_client import ModelRequestHandler, BaseClient - -logger = get_logger("模型管理器") - -class ModelManager: - \ No newline at end of file diff --git a/src/llm_models/model_manager_bak.py b/src/llm_models/model_manager_bak.py deleted file mode 100644 index 36d63c72e..000000000 --- a/src/llm_models/model_manager_bak.py +++ /dev/null @@ -1,92 +0,0 @@ -import importlib -from typing import Dict - -from src.config.config import model_config -from src.config.api_ada_configs import ModuleConfig, ModelUsageArgConfig -from src.common.logger import get_logger - -from .model_client import ModelRequestHandler, BaseClient - -logger = get_logger("模型管理器") - -class ModelManager: - # TODO: 添加读写锁,防止异步刷新配置时发生数据竞争 - - def __init__( - self, - config: ModuleConfig, - ): - self.config: ModuleConfig = config - """配置信息""" - - self.api_client_map: Dict[str, BaseClient] = {} - """API客户端映射表""" - - self._request_handler_cache: Dict[str, ModelRequestHandler] = {} - """ModelRequestHandler缓存,避免重复创建""" - - for provider_name, api_provider in self.config.api_providers.items(): - # 初始化API客户端 - try: - # 根据配置动态加载实现 - client_module = importlib.import_module( - f".model_client.{api_provider.client_type}_client", __package__ - ) - client_class = getattr( - client_module, f"{api_provider.client_type.capitalize()}Client" - ) - if not issubclass(client_class, BaseClient): - raise TypeError( - f"'{client_class.__name__}' is not a subclass of 'BaseClient'" - ) - self.api_client_map[api_provider.name] = client_class( - api_provider - ) # 实例化,放入api_client_map - except ImportError as e: - logger.error(f"Failed to import client module: {e}") - raise ImportError( - f"Failed to import client module for '{provider_name}': {e}" - ) from e - - def __getitem__(self, task_name: str) -> ModelRequestHandler: - """ - 获取任务所需的模型客户端(封装) - 使用缓存机制避免重复创建ModelRequestHandler - :param task_name: 任务名称 - :return: 模型客户端 - """ - if task_name not in self.config.task_model_arg_map: - raise KeyError(f"'{task_name}' not registered in ModelManager") - - # 检查缓存中是否已存在 - if task_name in self._request_handler_cache: - logger.debug(f"🚀 [性能优化] 从缓存获取ModelRequestHandler: {task_name}") - return self._request_handler_cache[task_name] - - # 创建新的ModelRequestHandler并缓存 - logger.debug(f"🔧 [性能优化] 创建并缓存ModelRequestHandler: {task_name}") - handler = ModelRequestHandler( - task_name=task_name, - config=self.config, - api_client_map=self.api_client_map, - ) - self._request_handler_cache[task_name] = 
handler - return handler - - def __setitem__(self, task_name: str, value: ModelUsageArgConfig): - """ - 注册任务的模型使用配置 - :param task_name: 任务名称 - :param value: 模型使用配置 - """ - self.config.task_model_arg_map[task_name] = value - - def __contains__(self, task_name: str): - """ - 判断任务是否已注册 - :param task_name: 任务名称 - :return: 是否在模型列表中 - """ - return task_name in self.config.task_model_arg_map - - diff --git a/src/llm_models/usage_statistic.py b/src/llm_models/usage_statistic.py deleted file mode 100644 index 0ed1bd3ad..000000000 --- a/src/llm_models/usage_statistic.py +++ /dev/null @@ -1,169 +0,0 @@ -from datetime import datetime -from enum import Enum -from typing import Tuple - -from src.common.logger import get_logger -from src.config.api_ada_configs import ModelInfo -from src.common.database.database_model import LLMUsage - -logger = get_logger("模型使用统计") - - -class ReqType(Enum): - """ - 请求类型 - """ - - CHAT = "chat" # 对话请求 - EMBEDDING = "embedding" # 嵌入请求 - - -class UsageCallStatus(Enum): - """ - 任务调用状态 - """ - - PROCESSING = "processing" # 处理中 - SUCCESS = "success" # 成功 - FAILURE = "failure" # 失败 - CANCELED = "canceled" # 取消 - - -class ModelUsageStatistic: - """ - 模型使用统计类 - 使用SQLite+Peewee - """ - - def __init__(self): - """ - 初始化统计类 - 由于使用Peewee ORM,不需要传入数据库实例 - """ - # 确保表已经创建 - try: - from src.common.database.database import db - - db.create_tables([LLMUsage], safe=True) - except Exception as e: - logger.error(f"创建LLMUsage表失败: {e}") - - @staticmethod - def _calculate_cost(prompt_tokens: int, completion_tokens: int, model_info: ModelInfo) -> float: - """计算API调用成本 - 使用模型的pri_in和pri_out价格计算输入和输出的成本 - - Args: - prompt_tokens: 输入token数量 - completion_tokens: 输出token数量 - model_info: 模型信息 - - Returns: - float: 总成本(元) - """ - # 使用模型的pri_in和pri_out计算成本 - input_cost = (prompt_tokens / 1000000) * model_info.price_in - output_cost = (completion_tokens / 1000000) * model_info.price_out - return round(input_cost + output_cost, 6) - - def create_usage( - self, - model_name: str, - task_name: str = "N/A", - request_type: ReqType = ReqType.CHAT, - user_id: str = "system", - endpoint: str = "/chat/completions", - ) -> int | None: - """ - 创建模型使用情况记录 - - Args: - model_name: 模型名 - task_name: 任务名称 - request_type: 请求类型,默认为Chat - user_id: 用户ID,默认为system - endpoint: API端点 - - Returns: - int | None: 返回记录ID,失败返回None - """ - try: - usage_record = LLMUsage.create( - model_name=model_name, - user_id=user_id, - request_type=request_type.value, - endpoint=endpoint, - prompt_tokens=0, - completion_tokens=0, - total_tokens=0, - cost=0.0, - status=UsageCallStatus.PROCESSING.value, - timestamp=datetime.now(), - ) - - # logger.trace( - # f"创建了一条模型使用情况记录 - 模型: {model_name}, " - # f"子任务: {task_name}, 类型: {request_type.value}, " - # f"用户: {user_id}, 记录ID: {usage_record.id}" - # ) - - return usage_record.id - except Exception as e: - logger.error(f"创建模型使用情况记录失败: {str(e)}") - return None - - def update_usage( - self, - record_id: int | None, - model_info: ModelInfo, - usage_data: Tuple[int, int, int] | None = None, - stat: UsageCallStatus = UsageCallStatus.SUCCESS, - ext_msg: str | None = None, - ): - """ - 更新模型使用情况 - - Args: - record_id: 记录ID - model_info: 模型信息 - usage_data: 使用情况数据(输入token数量, 输出token数量, 总token数量) - stat: 任务调用状态 - ext_msg: 额外信息 - """ - if not record_id: - logger.error("更新模型使用情况失败: record_id不能为空") - return - - if usage_data and len(usage_data) != 3: - logger.error("更新模型使用情况失败: usage_data的长度不正确,应该为3个元素") - return - - # 提取使用情况数据 - prompt_tokens = usage_data[0] if usage_data else 0 - completion_tokens = usage_data[1] 
if usage_data else 0 - total_tokens = usage_data[2] if usage_data else 0 - - try: - # 使用Peewee更新记录 - update_query = LLMUsage.update( - status=stat.value, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - cost=self._calculate_cost(prompt_tokens, completion_tokens, model_info) if usage_data else 0.0, - ).where(LLMUsage.id == record_id) # type: ignore - - updated_count = update_query.execute() - - if updated_count == 0: - logger.warning(f"记录ID {record_id} 不存在,无法更新") - return - - logger.debug( - f"Token使用情况 - 模型: {model_info.name}, " - f"记录ID: {record_id}, " - f"任务状态: {stat.value}, 额外信息: {ext_msg or 'N/A'}, " - f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " - f"总计: {total_tokens}" - ) - except Exception as e: - logger.error(f"记录token使用情况失败: {str(e)}") diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index 352df5a43..52a6120c2 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -2,16 +2,19 @@ import base64 import io from PIL import Image +from datetime import datetime from src.common.logger import get_logger +from src.common.database.database import db # 确保 db 被导入用于 create_tables +from src.common.database.database_model import LLMUsage +from src.config.api_ada_configs import ModelInfo from .payload_content.message import Message, MessageBuilder +from .model_client.base_client import UsageRecord logger = get_logger("消息压缩工具") -def compress_messages( - messages: list[Message], img_target_size: int = 1 * 1024 * 1024 -) -> list[Message]: +def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]: """ 压缩消息列表中的图片 :param messages: 消息列表 @@ -28,14 +31,10 @@ def compress_messages( try: image = Image.open(image_data) - if image.format and ( - image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"] - ): + if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]): # 静态图像,转换为JPEG格式 reformated_image_data = io.BytesIO() - image.save( - reformated_image_data, format="JPEG", quality=95, optimize=True - ) + image.save(reformated_image_data, format="JPEG", quality=95, optimize=True) image_data = reformated_image_data.getvalue() return image_data @@ -43,9 +42,7 @@ def compress_messages( logger.error(f"图片转换格式失败: {str(e)}") return image_data - def rescale_image( - image_data: bytes, scale: float - ) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]: + def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]: """ 缩放图片 :param image_data: 图片数据 @@ -86,9 +83,7 @@ def compress_messages( else: # 静态图片,直接缩放保存 resized_image = image.resize(new_size, Image.Resampling.LANCZOS) - resized_image.save( - output_buffer, format="JPEG", quality=95, optimize=True - ) + resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True) return output_buffer.getvalue(), original_size, new_size @@ -99,9 +94,7 @@ def compress_messages( logger.error(traceback.format_exc()) return image_data, None, None - def compress_base64_image( - base64_data: str, target_size: int = 1 * 1024 * 1024 - ) -> str: + def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str: original_b64_data_size = len(base64_data) # 计算原始数据大小 image_data = base64.b64decode(base64_data) @@ -111,9 +104,7 @@ def compress_messages( base64_data = base64.b64encode(image_data).decode("utf-8") if len(base64_data) <= target_size: # 如果转换后小于目标大小,直接返回 - logger.info( - f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB" - ) + 
logger.info(f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB") return base64_data # 如果转换后仍然大于目标大小,进行尺寸压缩 @@ -139,9 +130,7 @@ def compress_messages( # 图片,进行压缩 message_builder.add_image_content( content_item[0], - compress_base64_image( - content_item[1], target_size=img_target_size - ), + compress_base64_image(content_item[1], target_size=img_target_size), ) else: message_builder.add_text_content(content_item) @@ -150,3 +139,48 @@ def compress_messages( compressed_messages.append(message) return compressed_messages + + +class LLMUsageRecorder: + """ + LLM使用情况记录器 + """ + + def __init__(self): + try: + # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 + db.create_tables([LLMUsage], safe=True) + # logger.debug("LLMUsage 表已初始化/确保存在。") + except Exception as e: + logger.error(f"创建 LLMUsage 表失败: {str(e)}") + + def record_usage_to_database( + self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str + ): + input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in + output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out + total_cost = round(input_cost + output_cost, 6) + try: + # 使用 Peewee 模型创建记录 + LLMUsage.create( + model_name=model_info.model_identifier, + user_id=user_id, + request_type=request_type, + endpoint=endpoint, + prompt_tokens=model_usage.prompt_tokens or 0, + completion_tokens=model_usage.completion_tokens or 0, + total_tokens=model_usage.total_tokens or 0, + cost=total_cost or 0.0, + status="success", + timestamp=datetime.now(), # Peewee 会处理 DateTimeField + ) + logger.debug( + f"Token使用情况 - 模型: {model_usage.model_name}, " + f"用户: {user_id}, 类型: {request_type}, " + f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, " + f"总计: {model_usage.total_tokens}" + ) + except Exception as e: + logger.error(f"记录token使用情况失败: {str(e)}") + +llm_usage_recorder = LLMUsageRecorder() \ No newline at end of file diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 4602fb751..1c2c5afde 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,34 +1,20 @@ import re import copy import asyncio -from datetime import datetime -from typing import Tuple, Union, List, Dict, Optional, Callable, Any -from src.common.logger import get_logger -import base64 -from PIL import Image -from enum import Enum -import io -from src.common.database.database import db # 确保 db 被导入用于 create_tables -from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 -from src.config.config import global_config, model_config -from src.config.api_ada_configs import APIProvider, ModelInfo -from rich.traceback import install +from enum import Enum +from rich.traceback import install +from typing import Tuple, List, Dict, Optional, Callable, Any + +from src.common.logger import get_logger +from src.config.config import model_config +from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig from .payload_content.message import MessageBuilder, Message from .payload_content.resp_format import RespFormat from .payload_content.tool_option import ToolOption, ToolCall -from .model_client.base_client import BaseClient, APIResponse, UsageRecord, client_registry -from .utils import compress_messages - -from .exceptions import ( - NetworkConnectionError, - ReqAbortException, - RespNotOkException, - RespParseException, - PayLoadTooLargeError, - RequestAbortException, - PermissionDeniedException, -) +from .model_client.base_client import BaseClient, APIResponse, 
client_registry +from .utils import compress_messages, llm_usage_recorder +from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException install(extra_lines=3) @@ -57,45 +43,15 @@ class RequestType(Enum): class LLMRequest: """LLM请求类""" - # 定义需要转换的模型列表,作为类变量避免重复 - MODELS_NEEDING_TRANSFORMATION = [ - "o1", - "o1-2024-12-17", - "o1-mini", - "o1-mini-2024-09-12", - "o1-preview", - "o1-preview-2024-09-12", - "o1-pro", - "o1-pro-2025-03-19", - "o3", - "o3-2025-04-16", - "o3-mini", - "o3-mini-2025-01-31", - "o4-mini", - "o4-mini-2025-04-16", - ] - - def __init__(self, task_name: str, request_type: str = "") -> None: - self.task_name = task_name - self.model_for_task = model_config.model_task_config.get_task(task_name) + def __init__(self, model_set: TaskConfig, request_type: str = "") -> None: + self.task_name = request_type + self.model_for_task = model_set self.request_type = request_type self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list} """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整""" self.pri_in = 0 self.pri_out = 0 - - self._init_database() - - @staticmethod - def _init_database(): - """初始化数据库集合""" - try: - # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 - db.create_tables([LLMUsage], safe=True) - # logger.debug("LLMUsage 表已初始化/确保存在。") - except Exception as e: - logger.error(f"创建 LLMUsage 表失败: {str(e)}") async def generate_response_for_image( self, @@ -104,7 +60,7 @@ class LLMRequest: image_format: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None, - ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]: + ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]: """ 为图像生成响应 Args: @@ -112,7 +68,7 @@ class LLMRequest: image_base64 (str): 图像的Base64编码字符串 image_format (str): 图像格式(如 'png', 'jpeg' 等) Returns: - + (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): 响应内容、推理内容、模型名称、工具调用列表 """ # 请求体构建 message_builder = MessageBuilder() @@ -141,25 +97,25 @@ class LLMRequest: content, extracted_reasoning = self._extract_reasoning(content) reasoning_content = extracted_reasoning if usage := response.usage: - self.pri_in = model_info.price_in - self.pri_out = model_info.price_out - self._record_usage( - model_name=model_info.name, - prompt_tokens=usage.prompt_tokens or 0, - completion_tokens=usage.completion_tokens, - total_tokens=usage.total_tokens or 0, + llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, user_id="system", request_type=self.request_type, endpoint="/chat/completions", ) - return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None + return content, ( + reasoning_content, + model_info.name, + self._convert_tool_calls(tool_calls) if tool_calls else None, + ) async def generate_response_for_voice(self): pass async def generate_response_async( self, prompt: str, temperature: Optional[float] = None, max_tokens: Optional[int] = None - ) -> Tuple[str, str, Optional[List[Dict[str, Any]]]]: + ) -> Tuple[str, Tuple[str, str, Optional[List[Dict[str, Any]]]]]: """ 异步生成响应 Args: @@ -167,7 +123,7 @@ class LLMRequest: temperature (float, optional): 温度参数 max_tokens (int, optional): 最大token数 Returns: - Tuple[str, str, Optional[List[Dict[str, Any]]]]: 响应内容、推理内容和工具调用列表 + (Tuple[str, str, str, Optional[List[Dict[str, Any]]]]): 响应内容、推理内容、模型名称、工具调用列表 """ # 请求体构建 message_builder = MessageBuilder() @@ -195,13 +151,9 @@ class LLMRequest: content, extracted_reasoning = 
self._extract_reasoning(content) reasoning_content = extracted_reasoning if usage := response.usage: - self.pri_in = model_info.price_in - self.pri_out = model_info.price_out - self._record_usage( - model_name=model_info.name, - prompt_tokens=usage.prompt_tokens or 0, - completion_tokens=usage.completion_tokens, - total_tokens=usage.total_tokens or 0, + llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, user_id="system", request_type=self.request_type, endpoint="/chat/completions", @@ -209,10 +161,19 @@ class LLMRequest: if not content: raise RuntimeError("获取LLM生成内容失败") - return content, reasoning_content, self._convert_tool_calls(tool_calls) if tool_calls else None + return content, ( + reasoning_content, + model_info.name, + self._convert_tool_calls(tool_calls) if tool_calls else None, + ) - async def get_embedding(self, embedding_input: str) -> List[float]: - """获取嵌入向量""" + async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: + """获取嵌入向量 + Args: + embedding_input (str): 获取嵌入的目标 + Returns: + (Tuple[List[float], str]): (嵌入向量,使用的模型名称) + """ # 无需构建消息体,直接使用输入文本 model_info, api_provider, client = self._select_model() @@ -227,14 +188,10 @@ class LLMRequest: embedding = response.embedding - if response.usage: - self.pri_in = model_info.price_in - self.pri_out = model_info.price_out - self._record_usage( - model_name=model_info.name, - prompt_tokens=response.usage.prompt_tokens or 0, - completion_tokens=response.usage.completion_tokens, - total_tokens=response.usage.total_tokens or 0, + if usage := response.usage: + llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, user_id="system", request_type=self.request_type, endpoint="/embeddings", @@ -243,7 +200,7 @@ class LLMRequest: if not embedding: raise RuntimeError("获取embedding失败") - return embedding + return embedding, model_info.name def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: """ @@ -305,12 +262,13 @@ class LLMRequest: # 处理异常 total_tokens, penalty = self.model_usage[model_info.name] self.model_usage[model_info.name] = (total_tokens, penalty + 1) + wait_interval, compressed_messages = self._default_exception_handler( e, self.task_name, model_name=model_info.name, remain_try=retry_remain, - messages=(message_list, compressed_messages is not None), + messages=(message_list, compressed_messages is not None) if message_list else None, ) if wait_interval == -1: @@ -321,9 +279,7 @@ class LLMRequest: finally: # 放在finally防止死循环 retry_remain -= 1 - logger.error( - f"任务 '{self.task_name}' 模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次" - ) + logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") raise RuntimeError("请求失败,已达到最大重试次数") def _default_exception_handler( @@ -481,65 +437,3 @@ class LLMRequest: content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() reasoning = match[1].strip() if match else "" return content, reasoning - - def _record_usage( - self, - model_name: str, - prompt_tokens: int, - completion_tokens: int, - total_tokens: int, - user_id: str = "system", - request_type: str | None = None, - endpoint: str = "/chat/completions", - ): - """记录模型使用情况到数据库 - Args: - prompt_tokens: 输入token数 - completion_tokens: 输出token数 - total_tokens: 总token数 - user_id: 用户ID,默认为system - request_type: 请求类型 - endpoint: API端点 - """ - # 如果 request_type 为 None,则使用实例变量中的值 - if request_type is None: - request_type = self.request_type - - try: - # 使用 Peewee 模型创建记录 - 
LLMUsage.create( - model_name=model_name, - user_id=user_id, - request_type=request_type, - endpoint=endpoint, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - cost=self._calculate_cost(prompt_tokens, completion_tokens), - status="success", - timestamp=datetime.now(), # Peewee 会处理 DateTimeField - ) - logger.debug( - f"Token使用情况 - 模型: {model_name}, " - f"用户: {user_id}, 类型: {request_type}, " - f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " - f"总计: {total_tokens}" - ) - except Exception as e: - logger.error(f"记录token使用情况失败: {str(e)}") - - def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float: - """计算API调用成本 - 使用模型的pri_in和pri_out价格计算输入和输出的成本 - - Args: - prompt_tokens: 输入token数量 - completion_tokens: 输出token数量 - - Returns: - float: 总成本(元) - """ - # 使用模型的pri_in和pri_out计算成本 - input_cost = (prompt_tokens / 1000000) * self.pri_in - output_cost = (completion_tokens / 1000000) * self.pri_out - return round(input_cost + output_cost, 6) diff --git a/src/mais4u/mai_think.py b/src/mais4u/mai_think.py index 867ba8bef..5a1f58082 100644 --- a/src/mais4u/mai_think.py +++ b/src/mais4u/mai_think.py @@ -2,13 +2,15 @@ from src.chat.message_receive.chat_stream import get_chat_manager import time from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import model_config from src.chat.message_receive.message import MessageRecvS4U from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.internal_manager import internal_manager from src.common.logger import get_logger + logger = get_logger(__name__) + def init_prompt(): Prompt( """ @@ -32,10 +34,8 @@ def init_prompt(): ) - - class MaiThinking: - def __init__(self,chat_id): + def __init__(self, chat_id): self.chat_id = chat_id self.chat_stream = get_chat_manager().get_stream(chat_id) self.platform = self.chat_stream.platform @@ -44,11 +44,11 @@ class MaiThinking: self.is_group = True else: self.is_group = False - + self.s4u_message_processor = S4UMessageProcessor() - + self.mind = "" - + self.memory_block = "" self.relation_info_block = "" self.time_block = "" @@ -59,17 +59,13 @@ class MaiThinking: self.identity = "" self.sender = "" self.target = "" - - self.thinking_model = LLMRequest( - model=global_config.model.replyer_1, - request_type="thinking", - ) + + self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer_1, request_type="thinking") async def do_think_before_response(self): pass - async def do_think_after_response(self,reponse:str): - + async def do_think_after_response(self, reponse: str): prompt = await global_prompt_manager.format_prompt( "after_response_think_prompt", mind=self.mind, @@ -85,47 +81,44 @@ class MaiThinking: sender=self.sender, target=self.target, ) - + result, _ = await self.thinking_model.generate_response_async(prompt) self.mind = result - + logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}") # logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}") logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}") - - + msg_recv = await self.build_internal_message_recv(self.mind) await self.s4u_message_processor.process_message(msg_recv) internal_manager.set_internal_state(self.mind) - - + async def do_think_when_receive_message(self): pass - - async def build_internal_message_recv(self,message_text:str): - + + async def build_internal_message_recv(self, message_text: str): 
msg_id = f"internal_{time.time()}" - + message_dict = { "message_info": { "message_id": msg_id, "time": time.time(), "user_info": { - "user_id": "internal", # 内部用户ID - "user_nickname": "内心", # 内部昵称 - "platform": self.platform, # 平台标记为 internal + "user_id": "internal", # 内部用户ID + "user_nickname": "内心", # 内部昵称 + "platform": self.platform, # 平台标记为 internal # 其他 user_info 字段按需补充 }, - "platform": self.platform, # 平台 + "platform": self.platform, # 平台 # 其他 message_info 字段按需补充 }, "message_segment": { - "type": "text", # 消息类型 - "data": message_text, # 消息内容 + "type": "text", # 消息类型 + "data": message_text, # 消息内容 # 其他 segment 字段按需补充 }, - "raw_message": message_text, # 原始消息内容 - "processed_plain_text": message_text, # 处理后的纯文本 + "raw_message": message_text, # 原始消息内容 + "processed_plain_text": message_text, # 处理后的纯文本 # 下面这些字段可选,根据 MessageRecv 需要 "is_emoji": False, "has_emoji": False, @@ -139,45 +132,36 @@ class MaiThinking: "priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级 "interest_value": 1.0, } - + if self.is_group: message_dict["message_info"]["group_info"] = { "platform": self.platform, "group_id": self.chat_stream.group_info.group_id, "group_name": self.chat_stream.group_info.group_name, } - + msg_recv = MessageRecvS4U(message_dict) msg_recv.chat_info = self.chat_info msg_recv.chat_stream = self.chat_stream msg_recv.is_internal = True - + return msg_recv - - - + class MaiThinkingManager: def __init__(self): self.mai_think_list = [] - - def get_mai_think(self,chat_id): + + def get_mai_think(self, chat_id): for mai_think in self.mai_think_list: if mai_think.chat_id == chat_id: return mai_think mai_think = MaiThinking(chat_id) self.mai_think_list.append(mai_think) return mai_think - + + mai_thinking_manager = MaiThinkingManager() - + init_prompt() - - - - - - - - diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index e7380822d..8e05a025e 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -1,14 +1,16 @@ import json import time + +from json_repair import repair_json from src.chat.message_receive.message import MessageRecv from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api -from json_repair import repair_json + from src.mais4u.s4u_config import s4u_config logger = get_logger("action") @@ -32,7 +34,7 @@ BODY_CODE = { "帅气的姿势": "010_0190", "另一个帅气的姿势": "010_0191", "手掌朝前可爱": "010_0210", - "平静,双手后放":"平静,双手后放", + "平静,双手后放": "平静,双手后放", "思考": "思考", "优雅,左手放在腰上": "优雅,左手放在腰上", "一般": "一般", @@ -94,19 +96,15 @@ class ChatAction: self.body_action_cooldown: dict[str, int] = {} print(s4u_config.models.motion) - print(global_config.model.emotion) - - self.action_model = LLMRequest( - model=global_config.model.emotion, - temperature=0.7, - request_type="motion", - ) + print(model_config.model_task_config.emotion) - self.last_change_time = 0 + self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion") + + self.last_change_time: float = 0 async def send_action_update(self): 
"""发送动作更新到前端""" - + body_code = BODY_CODE.get(self.body_action, "") await send_api.custom_to_stream( message_type="body_action", @@ -115,13 +113,11 @@ class ChatAction: storage_message=False, show_log=True, ) - - async def update_action_by_message(self, message: MessageRecv): self.regression_count = 0 - message_time = message.message_info.time + message_time: float = message.message_info.time # type: ignore message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, @@ -147,13 +143,13 @@ class ChatAction: prompt_personality = global_config.personality.personality_core indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" - + try: # 冷却池处理:过滤掉冷却中的动作 self._update_body_action_cooldown() available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] all_actions = "\n".join(available_actions) - + prompt = await global_prompt_manager.format_prompt( "change_action_prompt", chat_talking_prompt=chat_talking_prompt, @@ -163,19 +159,18 @@ class ChatAction: ) logger.info(f"prompt: {prompt}") - response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.action_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) logger.info(f"response: {response}") logger.info(f"reasoning_content: {reasoning_content}") - action_data = json.loads(repair_json(response)) - - if action_data: + if action_data := json.loads(repair_json(response)): # 记录原动作,切换后进入冷却 prev_body_action = self.body_action new_body_action = action_data.get("body_action", self.body_action) - if new_body_action != prev_body_action: - if prev_body_action: - self.body_action_cooldown[prev_body_action] = 3 + if new_body_action != prev_body_action and prev_body_action: + self.body_action_cooldown[prev_body_action] = 3 self.body_action = new_body_action self.head_action = action_data.get("head_action", self.head_action) # 发送动作更新 @@ -213,7 +208,6 @@ class ChatAction: prompt_personality = global_config.personality.personality_core indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" try: - # 冷却池处理:过滤掉冷却中的动作 self._update_body_action_cooldown() available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] @@ -228,17 +222,17 @@ class ChatAction: ) logger.info(f"prompt: {prompt}") - response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.action_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) logger.info(f"response: {response}") logger.info(f"reasoning_content: {reasoning_content}") - action_data = json.loads(repair_json(response)) - if action_data: + if action_data := json.loads(repair_json(response)): prev_body_action = self.body_action new_body_action = action_data.get("body_action", self.body_action) - if new_body_action != prev_body_action: - if prev_body_action: - self.body_action_cooldown[prev_body_action] = 6 + if new_body_action != prev_body_action and prev_body_action: + self.body_action_cooldown[prev_body_action] = 6 self.body_action = new_body_action # 发送动作更新 await self.send_action_update() @@ -306,9 +300,6 @@ class ActionManager: return new_action_state - - - init_prompt() action_manager = ActionManager() diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py index e447ae193..78df5e98a 100644 --- 
a/src/mais4u/mais4u_chat/s4u_chat.py +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -137,7 +137,7 @@ class MessageSenderContainer: await self.storage.store_message(bot_message, self.chat_stream) except Exception as e: - logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True) + logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True) finally: # CRUCIAL: Always call task_done() for any item that was successfully retrieved. diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index c936cea17..11d8c7ca5 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageRecv from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api @@ -114,18 +114,12 @@ class ChatMood: self.regression_count: int = 0 - self.mood_model = LLMRequest( - model=global_config.model.emotion, - temperature=0.7, - request_type="mood_text", - ) + self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text") self.mood_model_numerical = LLMRequest( - model=global_config.model.emotion, - temperature=0.4, - request_type="mood_numerical", + model_set=model_config.model_task_config.emotion, request_type="mood_numerical" ) - self.last_change_time = 0 + self.last_change_time: float = 0 # 发送初始情绪状态到ws端 asyncio.create_task(self.send_emotion_update(self.mood_values)) @@ -164,7 +158,7 @@ class ChatMood: async def update_mood_by_message(self, message: MessageRecv): self.regression_count = 0 - message_time = message.message_info.time + message_time: float = message.message_info.time # type: ignore message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, @@ -199,7 +193,9 @@ class ChatMood: mood_state=self.mood_state, ) logger.debug(f"text mood prompt: {prompt}") - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.mood_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) logger.info(f"text mood response: {response}") logger.debug(f"text mood reasoning_content: {reasoning_content}") return response @@ -216,8 +212,8 @@ class ChatMood: fear=self.mood_values["fear"], ) logger.debug(f"numerical mood prompt: {prompt}") - response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async( - prompt=prompt + response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async( + prompt=prompt, temperature=0.4 ) logger.info(f"numerical mood response: {response}") logger.debug(f"numerical mood reasoning_content: {reasoning_content}") @@ -276,7 +272,9 @@ class ChatMood: mood_state=self.mood_state, ) logger.debug(f"text regress prompt: {prompt}") - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await 
self.mood_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) logger.info(f"text regress response: {response}") logger.debug(f"text regress reasoning_content: {reasoning_content}") return response @@ -293,8 +291,9 @@ class ChatMood: fear=self.mood_values["fear"], ) logger.debug(f"numerical regress prompt: {prompt}") - response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async( - prompt=prompt + response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async( + prompt=prompt, + temperature=0.4, ) logger.info(f"numerical regress response: {response}") logger.debug(f"numerical regress reasoning_content: {reasoning_content}") @@ -447,6 +446,7 @@ class MoodManager: # 发送初始情绪状态到ws端 asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values)) + if ENABLE_S4U: init_prompt() mood_manager = MoodManager() diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index d748c25e5..72324d744 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -150,19 +150,18 @@ class PromptBuilder: relation_prompt = "" if global_config.relationship.enable_relationship and who_chat_in_group: relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id) - + # 将 (platform, user_id, nickname) 转换为 person_id person_ids = [] for person in who_chat_in_group: person_id = PersonInfoManager.get_person_id(person[0], person[1]) person_ids.append(person_id) - + # 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为 relation_info_list = await asyncio.gather( *[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids] ) - relation_info = "".join(relation_info_list) - if relation_info: + if relation_info := "".join(relation_info_list): relation_prompt = await global_prompt_manager.format_prompt( "relation_prompt", relation_info=relation_info ) @@ -186,9 +185,9 @@ class PromptBuilder: timestamp=time.time(), limit=300, ) - - talk_type = message.message_info.platform + ":" + str(message.chat_stream.user_info.user_id) + + talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}" core_dialogue_list = [] background_dialogue_list = [] @@ -258,19 +257,19 @@ class PromptBuilder: all_msg_seg_list.append(msg_seg_str) for msg in all_msg_seg_list: core_msg_str += msg - - + + all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), limit=20, - ) + ) all_dialogue_prompt_str = build_readable_messages( all_dialogue_prompt, timestamp_mode="normal_no_YMD", show_pic=False, ) - + return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index 339b46c33..c0ca26581 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -1,7 +1,7 @@ import os from typing import AsyncGenerator from src.mais4u.openai_client import AsyncOpenAIClient -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.message_receive.message import MessageRecvS4U from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder from src.common.logger import get_logger @@ -14,24 +14,27 @@ logger = get_logger("s4u_stream_generator") class S4UStreamGenerator: def __init__(self): - replyer_1_config = 
global_config.model.replyer_1 - provider = replyer_1_config.get("provider") - if not provider: - logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段") - raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段") + replyer_1_config = model_config.model_task_config.replyer_1 + model_to_use = replyer_1_config.model_list[0] + model_info = model_config.get_model_info(model_to_use) + if not model_info: + logger.error(f"模型 {model_to_use} 在配置中未找到") + raise ValueError(f"模型 {model_to_use} 在配置中未找到") + provider_name = model_info.api_provider + provider_info = model_config.get_provider(provider_name) + if not provider_info: + logger.error("`replyer_1` 找不到对应的Provider") + raise ValueError("`replyer_1` 找不到对应的Provider") - api_key = os.environ.get(f"{provider.upper()}_KEY") - base_url = os.environ.get(f"{provider.upper()}_BASE_URL") + api_key = provider_info.api_key + base_url = provider_info.base_url if not api_key: - logger.error(f"环境变量 {provider.upper()}_KEY 未设置") - raise ValueError(f"环境变量 {provider.upper()}_KEY 未设置") + logger.error(f"{provider_name}没有配置API KEY") + raise ValueError(f"{provider_name}没有配置API KEY") self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url) - self.model_1_name = replyer_1_config.get("name") - if not self.model_1_name: - logger.error("`replyer_1` 在配置文件中缺少 `model_name` 字段") - raise ValueError("`replyer_1` 在配置文件中缺少 `model_name` 字段") + self.model_1_name = model_to_use self.replyer_1_config = replyer_1_config self.current_model_name = "unknown model" @@ -44,10 +47,10 @@ class S4UStreamGenerator: r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符 re.UNICODE | re.DOTALL, ) - - self.chat_stream =None - - async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""): + + self.chat_stream = None + + async def build_last_internal_message(self, message: MessageRecvS4U, previous_reply_context: str = ""): # person_id = PersonInfoManager.get_person_id( # message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id # ) @@ -71,14 +74,10 @@ class S4UStreamGenerator: [这是用户发来的新消息, 你需要结合上下文,对此进行回复]: {message.processed_plain_text} """ - return True,message_txt + return True, message_txt else: message_txt = message.processed_plain_text - return False,message_txt - - - - + return False, message_txt async def generate_response( self, message: MessageRecvS4U, previous_reply_context: str = "" @@ -88,7 +87,7 @@ class S4UStreamGenerator: self.partial_response = "" message_txt = message.processed_plain_text if not message.is_internal: - interupted,message_txt_added = await self.build_last_internal_message(message,previous_reply_context) + interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context) if interupted: message_txt = message_txt_added @@ -105,7 +104,6 @@ class S4UStreamGenerator: current_client = self.client_1 self.current_model_name = self.model_1_name - extra_kwargs = {} if self.replyer_1_config.get("enable_thinking") is not None: extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking") diff --git a/src/mais4u/mais4u_chat/super_chat_manager.py b/src/mais4u/mais4u_chat/super_chat_manager.py index 528eaecca..a08d18cd0 100644 --- a/src/mais4u/mais4u_chat/super_chat_manager.py +++ b/src/mais4u/mais4u_chat/super_chat_manager.py @@ -214,51 +214,49 @@ class SuperChatManager: def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str: """构建SuperChat显示字符串""" superchats = self.get_superchats_by_chat(chat_id) - + if not superchats: return "" - + # 
限制显示数量 display_superchats = superchats[:max_count] - - lines = [] - lines.append("📢 当前有效超级弹幕:") - + + lines = ["📢 当前有效超级弹幕:"] for i, sc in enumerate(display_superchats, 1): remaining_minutes = int(sc.remaining_time() / 60) remaining_seconds = int(sc.remaining_time() % 60) - + time_display = f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒" - + line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}" if len(line) > 100: # 限制单行长度 - line = line[:97] + "..." + line = f"{line[:97]}..." line += f" (剩余{time_display})" lines.append(line) - + if len(superchats) > max_count: lines.append(f"... 还有{len(superchats) - max_count}条SuperChat") - + return "\n".join(lines) def build_superchat_summary_string(self, chat_id: str) -> str: """构建SuperChat摘要字符串""" superchats = self.get_superchats_by_chat(chat_id) - + if not superchats: return "当前没有有效的超级弹幕" lines = [] for sc in superchats: single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}" if len(single_sc_str) > 100: - single_sc_str = single_sc_str[:97] + "..." + single_sc_str = f"{single_sc_str[:97]}..." single_sc_str += f" (剩余{int(sc.remaining_time())}秒)" lines.append(single_sc_str) - + total_amount = sum(sc.price for sc in superchats) count = len(superchats) highest_amount = max(sc.price for sc in superchats) - + final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}元" if lines: final_str += "\n" + "\n".join(lines) @@ -287,7 +285,7 @@ class SuperChatManager: "lowest_amount": min(amounts) } - async def shutdown(self): + async def shutdown(self): # sourcery skip: use-contextlib-suppress """关闭管理器,清理资源""" if self._cleanup_task and not self._cleanup_task.done(): self._cleanup_task.cancel() @@ -300,6 +298,7 @@ class SuperChatManager: +# sourcery skip: assign-if-exp if ENABLE_S4U: super_chat_manager = SuperChatManager() else: diff --git a/src/mais4u/mais4u_chat/yes_or_no.py b/src/mais4u/mais4u_chat/yes_or_no.py index edc200f65..c71c160d3 100644 --- a/src/mais4u/mais4u_chat/yes_or_no.py +++ b/src/mais4u/mais4u_chat/yes_or_no.py @@ -1,19 +1,14 @@ from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import model_config from src.plugin_system.apis import send_api + logger = get_logger(__name__) -head_actions_list = [ - "不做额外动作", - "点头一次", - "点头两次", - "摇头", - "歪脑袋", - "低头望向一边" -] +head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"] -async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat_id: str = ""): + +async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""): prompt = f""" {chat_history} 以上是对方的发言: @@ -30,22 +25,14 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat 低头望向一边 请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。""" - model = LLMRequest( - model=global_config.model.emotion, - temperature=0.7, - request_type="motion", - ) - + model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion") + try: # logger.info(f"prompt: {prompt}") - response, (reasoning_content, model_name) = await model.generate_response_async(prompt=prompt) + response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7) logger.info(f"response: {response}") - - if response in head_actions_list: - head_action = response - else: - head_action = "不做额外动作" - + + head_action = response if response in head_actions_list else "不做额外动作" await 
send_api.custom_to_stream( message_type="head_action", content=head_action, @@ -53,11 +40,7 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat storage_message=False, show_log=True, ) - - - + except Exception as e: logger.error(f"yes_or_no_head error: {e}") return "不做额外动作" - - diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index eae0ea713..8daf38e65 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -3,13 +3,14 @@ import random import time from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.message_receive.message import MessageRecv +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.llm_models.utils_model import LLMRequest from src.manager.async_task_manager import AsyncTask, async_task_manager -from src.chat.message_receive.chat_stream import get_chat_manager + logger = get_logger("mood") @@ -49,7 +50,7 @@ class ChatMood: chat_manager = get_chat_manager() self.chat_stream = chat_manager.get_stream(self.chat_id) - + if not self.chat_stream: raise ValueError(f"Chat stream for chat_id {chat_id} not found") @@ -59,11 +60,7 @@ class ChatMood: self.regression_count: int = 0 - self.mood_model = LLMRequest( - model=global_config.model.emotion, - temperature=0.7, - request_type="mood", - ) + self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood") self.last_change_time: float = 0 @@ -83,12 +80,16 @@ class ChatMood: logger.debug( f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}" ) - update_probability = global_config.mood.mood_update_threshold * min(1.0, base_probability * time_multiplier * interest_multiplier) + update_probability = global_config.mood.mood_update_threshold * min( + 1.0, base_probability * time_multiplier * interest_multiplier + ) if random.random() > update_probability: return - logger.debug(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}") + logger.debug( + f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}" + ) message_time: float = message.message_info.time # type: ignore message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( @@ -124,7 +125,9 @@ class ChatMood: mood_state=self.mood_state, ) - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.mood_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) if global_config.debug.show_prompt: logger.info(f"{self.log_prefix} prompt: {prompt}") logger.info(f"{self.log_prefix} response: {response}") @@ -171,7 +174,9 @@ class ChatMood: mood_state=self.mood_state, ) - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.mood_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) if global_config.debug.show_prompt: logger.info(f"{self.log_prefix} prompt: {prompt}") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 6be0ad277..4d5fe709c 100644 --- a/src/person_info/person_info.py +++ 
b/src/person_info/person_info.py @@ -11,7 +11,7 @@ from src.common.logger import get_logger from src.common.database.database import db from src.common.database.database_model import PersonInfo from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import global_config, model_config """ @@ -54,11 +54,7 @@ person_info_default = { class PersonInfoManager: def __init__(self): self.person_name_list = {} - # TODO: API-Adapter修改标记 - self.qv_name_llm = LLMRequest( - model=global_config.model.utils, - request_type="relation.qv_name", - ) + self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") try: db.connect(reuse_if_open=True) # 设置连接池参数 @@ -199,7 +195,7 @@ class PersonInfoManager: if existing: logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建") return True - + # 尝试创建 PersonInfo.create(**p_data) return True @@ -376,7 +372,7 @@ class PersonInfoManager: "nickname": "昵称", "reason": "理由" }""" - response, (reasoning_content, model_name) = await self.qv_name_llm.generate_response_async(qv_name_prompt) + response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt) # logger.info(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") result = self._extract_json_from_text(response) @@ -592,7 +588,7 @@ class PersonInfoManager: record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) if record: return record, False # 记录存在,未创建 - + # 记录不存在,尝试创建 try: PersonInfo.create(**init_data) @@ -622,7 +618,7 @@ class PersonInfoManager: "points": [], "forgotten_points": [], } - + # 序列化JSON字段 for key in JSON_SERIALIZED_FIELDS: if key in initial_data: @@ -630,12 +626,12 @@ class PersonInfoManager: initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False) elif initial_data[key] is None: initial_data[key] = json.dumps([], ensure_ascii=False) - + model_fields = PersonInfo._meta.fields.keys() # type: ignore filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data) - + if was_created: logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。") logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 99f3be303..267ed96f9 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -7,7 +7,7 @@ from typing import List, Dict, Any from json_repair import repair_json from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager @@ -73,14 +73,12 @@ class RelationshipFetcher: # LLM模型配置 self.llm_model = LLMRequest( - model=global_config.model.utils_small, - request_type="relation.fetcher", + model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher" ) # 小模型用于即时信息提取 self.instant_llm_model = LLMRequest( - model=global_config.model.utils_small, - request_type="relation.fetch", + model_set=model_config.model_task_config.utils_small, request_type="relation.fetch" ) name = get_chat_manager().get_stream_name(self.chat_id) @@ -96,7 +94,7 @@ class RelationshipFetcher: if not 
self.info_fetched_cache[person_id]: del self.info_fetched_cache[person_id] - async def build_relation_info(self, person_id, points_num = 3): + async def build_relation_info(self, person_id, points_num=3): # 清理过期的信息缓存 self._cleanup_expired_cache() @@ -361,7 +359,6 @@ class RelationshipFetcher: logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}") logger.error(traceback.format_exc()) - async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str): # sourcery skip: use-next """将提取到的信息保存到 person_info 的 info_list 字段中 diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 6c2693572..9d7a48b97 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -3,7 +3,7 @@ from .person_info import PersonInfoManager, get_person_info_manager import time import random from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.utils.chat_message_builder import build_readable_messages import json from json_repair import repair_json @@ -20,9 +20,8 @@ logger = get_logger("relation") class RelationshipManager: def __init__(self): self.relationship_llm = LLMRequest( - model=global_config.model.utils, - request_type="relationship", # 用于动作规划 - ) + model_set=model_config.model_task_config.utils, request_type="relationship" + ) # 用于动作规划 @staticmethod async def is_known_some_one(platform, user_id): @@ -181,18 +180,14 @@ class RelationshipManager: try: points = repair_json(points) points_data = json.loads(points) - + # 只处理正确的格式,错误格式直接跳过 if points_data == "none" or not points_data: points_list = [] elif isinstance(points_data, str) and points_data.lower() == "none": points_list = [] elif isinstance(points_data, list): - # 正确格式:数组格式 [{"point": "...", "weight": 10}, ...] 
- if not points_data: # 空数组 - points_list = [] - else: - points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data] + points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data] else: # 错误格式,直接跳过不解析 logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(points_data)}, 内容: {points_data}") diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index f8752ac4e..2b7732f08 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -12,6 +12,7 @@ import traceback from typing import Tuple, Any, Dict, List, Optional from rich.traceback import install from src.common.logger import get_logger +from src.config.api_ada_configs import TaskConfig from src.chat.replyer.default_generator import DefaultReplyer from src.chat.message_receive.chat_stream import ChatStream from src.chat.utils.utils import process_llm_response @@ -31,7 +32,7 @@ logger = get_logger("generator_api") def get_replyer( chat_stream: Optional[ChatStream] = None, chat_id: Optional[str] = None, - model_configs: Optional[List[Dict[str, Any]]] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, request_type: str = "replyer", ) -> Optional[DefaultReplyer]: """获取回复器对象 @@ -58,7 +59,7 @@ def get_replyer( return replyer_manager.get_replyer( chat_stream=chat_stream, chat_id=chat_id, - model_configs=model_configs, + model_set_with_weight=model_set_with_weight, request_type=request_type, ) except Exception as e: @@ -83,7 +84,7 @@ async def generate_reply( enable_splitter: bool = True, enable_chinese_typo: bool = True, return_prompt: bool = False, - model_configs: Optional[List[Dict[str, Any]]] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, request_type: str = "generator_api", ) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: """生成回复 @@ -106,7 +107,7 @@ async def generate_reply( """ try: # 获取回复器 - replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type) + replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") return False, [], None @@ -154,7 +155,7 @@ async def rewrite_reply( chat_id: Optional[str] = None, enable_splitter: bool = True, enable_chinese_typo: bool = True, - model_configs: Optional[List[Dict[str, Any]]] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, raw_reply: str = "", reason: str = "", reply_to: str = "", @@ -179,7 +180,7 @@ async def rewrite_reply( """ try: # 获取回复器 - replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs) + replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") return False, [], None @@ -245,17 +246,17 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese async def generate_response_custom( chat_stream: Optional[ChatStream] = None, chat_id: Optional[str] = None, - model_configs: Optional[List[Dict[str, Any]]] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, prompt: str = "", ) -> Optional[str]: - replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs) + replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") return None try: 
logger.debug("[GeneratorAPI] 开始生成自定义回复") - response = await replyer.llm_generate_content(prompt) + response, _, _, _ = await replyer.llm_generate_content(prompt) if response: logger.debug("[GeneratorAPI] 自定义回复生成成功") return response diff --git a/src/plugin_system/apis/llm_api.py b/src/plugin_system/apis/llm_api.py index 4e9d884fa..eaf48556b 100644 --- a/src/plugin_system/apis/llm_api.py +++ b/src/plugin_system/apis/llm_api.py @@ -7,10 +7,11 @@ success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config) """ -from typing import Tuple, Dict, Any +from typing import Tuple, Dict from src.common.logger import get_logger from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import global_config, model_config +from src.config.api_ada_configs import TaskConfig logger = get_logger("llm_api") @@ -19,9 +20,7 @@ logger = get_logger("llm_api") # ============================================================================= - - -def get_available_models() -> Dict[str, Any]: +def get_available_models() -> Dict[str, TaskConfig]: """获取所有可用的模型配置 Returns: @@ -33,14 +32,14 @@ def get_available_models() -> Dict[str, Any]: return {} # 自动获取所有属性并转换为字典形式 - rets = {} - models = global_config.model + models = model_config.model_task_config attrs = dir(models) + rets: Dict[str, TaskConfig] = {} for attr in attrs: if not attr.startswith("__"): try: value = getattr(models, attr) - if not callable(value): # 排除方法 + if not callable(value) and isinstance(value, TaskConfig): rets[attr] = value except Exception as e: logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}") @@ -53,8 +52,8 @@ def get_available_models() -> Dict[str, Any]: async def generate_with_model( - prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs -) -> Tuple[bool, str]: + prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs +) -> Tuple[bool, str, str, str]: """使用指定模型生成内容 Args: @@ -67,17 +66,16 @@ async def generate_with_model( Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) """ try: - model_name = model_config.get("name") - logger.info(f"[LLMAPI] 使用模型 {model_name} 生成内容") + model_name_list = model_config.model_list + logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容") logger.debug(f"[LLMAPI] 完整提示词: {prompt}") - llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs) + llm_request = LLMRequest(model_set=model_config, request_type=request_type, **kwargs) - # TODO: 复活这个_ - response, _ = await llm_request.generate_response_async(prompt) - return True, response + response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt) + return True, response, reasoning_content, model_name except Exception as e: error_msg = f"生成内容时出错: {str(e)}" logger.error(f"[LLMAPI] {error_msg}") - return False, error_msg + return False, error_msg, "", "" diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 46b3bddd7..10fbd804e 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -335,7 +335,7 @@ async def command_to_stream( async def custom_to_stream( message_type: str, - content: str, + content: str | dict, stream_id: str, display_message: str = "", typing: bool = False, diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index d7b86b8d6..a220161db 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -4,7 
+4,7 @@ from typing import List, Dict, Tuple, Optional, Any from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.json_utils import process_llm_tool_calls from src.chat.message_receive.chat_stream import get_chat_manager @@ -52,10 +52,7 @@ class ToolExecutor: self.chat_stream = get_chat_manager().get_stream(self.chat_id) self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" - self.llm_model = LLMRequest( - model=global_config.model.tool_use, - request_type="tool_executor", - ) + self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") # 缓存配置 self.enable_cache = enable_cache @@ -137,7 +134,7 @@ class ToolExecutor: return tool_results, used_tools, prompt else: return tool_results, [], "" - + def _get_tool_definitions(self) -> List[Dict[str, Any]]: all_tools = get_llm_available_tool_definitions() user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index 257686b18..790f2096e 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -58,6 +58,7 @@ class EmojiAction(BaseAction): associated_types = ["emoji"] async def execute(self) -> Tuple[bool, str]: + # sourcery skip: assign-if-exp, introduce-default-else, swap-if-else-branches, use-named-expression """执行表情动作""" logger.info(f"{self.log_prefix} 决定发送表情") @@ -120,7 +121,7 @@ class EmojiAction(BaseAction): logger.error(f"{self.log_prefix} 未找到'utils_small'模型配置,无法调用LLM") return False, "未找到'utils_small'模型配置" - success, chosen_emotion = await llm_api.generate_with_model( + success, chosen_emotion, _, _ = await llm_api.generate_with_model( prompt, model_config=chat_model_config, request_type="emoji" )
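The `llm_api` change is the one most likely to break third-party plugins: `generate_with_model` now returns `(success, content, reasoning_content, model_name)` instead of `(success, content)`, yielding `(False, error_msg, "", "")` on failure, and `get_available_models` filters its result down to genuine `TaskConfig` entries. The `EmojiAction` hunk above shows the required call-site fix: widen the unpack. A minimal sketch of an updated caller, assuming the project environment; the `"utils_small"` key mirrors the lookup `EmojiAction` performs, and the prompt handling is illustrative:

```python
from src.plugin_system.apis import llm_api

async def pick_label(prompt: str) -> str | None:
    models = llm_api.get_available_models()  # now Dict[str, TaskConfig]
    cfg = models.get("utils_small")
    if not cfg:
        return None  # same early exit EmojiAction takes when the key is missing
    # Two-value unpacks written against the old signature now raise
    # "too many values to unpack"; discard the extra fields instead.
    success, content, _reasoning, _model_name = await llm_api.generate_with_model(
        prompt, model_config=cfg, request_type="emoji"
    )
    return content.strip() if success else None
```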