better:优化关系prompt,回退utils的修改
This commit is contained in:
@@ -124,20 +124,14 @@ class LLMRequest:
|
|||||||
self.model_name: str = model["name"]
|
self.model_name: str = model["name"]
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
|
|
||||||
self.enable_thinking = model.get("enable_thinking", None)
|
self.enable_thinking = model.get("enable_thinking", False)
|
||||||
self.temp = model.get("temp", 0.7)
|
self.temp = model.get("temp", 0.7)
|
||||||
self.thinking_budget = model.get("thinking_budget", None)
|
self.thinking_budget = model.get("thinking_budget", 4096)
|
||||||
self.stream = model.get("stream", False)
|
self.stream = model.get("stream", False)
|
||||||
self.pri_in = model.get("pri_in", 0)
|
self.pri_in = model.get("pri_in", 0)
|
||||||
self.pri_out = model.get("pri_out", 0)
|
self.pri_out = model.get("pri_out", 0)
|
||||||
self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
|
self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
|
||||||
# print(f"max_tokens: {self.max_tokens}")
|
# print(f"max_tokens: {self.max_tokens}")
|
||||||
custom_params_str = model.get("custom_params", "{}")
|
|
||||||
try:
|
|
||||||
self.custom_params = json.loads(custom_params_str)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.error(f"Invalid JSON in custom_params for model '{self.model_name}': {custom_params_str}")
|
|
||||||
self.custom_params = {}
|
|
||||||
|
|
||||||
# 获取数据库实例
|
# 获取数据库实例
|
||||||
self._init_database()
|
self._init_database()
|
||||||
@@ -255,6 +249,28 @@ class LLMRequest:
|
|||||||
elif payload is None:
|
elif payload is None:
|
||||||
payload = await self._build_payload(prompt)
|
payload = await self._build_payload(prompt)
|
||||||
|
|
||||||
|
if stream_mode:
|
||||||
|
payload["stream"] = stream_mode
|
||||||
|
|
||||||
|
if self.temp != 0.7:
|
||||||
|
payload["temperature"] = self.temp
|
||||||
|
|
||||||
|
# 添加enable_thinking参数(如果不是默认值False)
|
||||||
|
if not self.enable_thinking:
|
||||||
|
payload["enable_thinking"] = False
|
||||||
|
|
||||||
|
if self.thinking_budget != 4096:
|
||||||
|
payload["thinking_budget"] = self.thinking_budget
|
||||||
|
|
||||||
|
if self.max_tokens:
|
||||||
|
payload["max_tokens"] = self.max_tokens
|
||||||
|
|
||||||
|
# if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
||||||
|
# payload["max_tokens"] = global_config.model.model_max_output_length
|
||||||
|
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||||||
|
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
||||||
|
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"policy": policy,
|
"policy": policy,
|
||||||
"payload": payload,
|
"payload": payload,
|
||||||
@@ -654,16 +670,18 @@ class LLMRequest:
|
|||||||
if self.temp != 0.7:
|
if self.temp != 0.7:
|
||||||
payload["temperature"] = self.temp
|
payload["temperature"] = self.temp
|
||||||
|
|
||||||
# 仅当配置文件中存在参数时,添加对应参数
|
# 添加enable_thinking参数(如果不是默认值False)
|
||||||
if self.enable_thinking is not None:
|
if not self.enable_thinking:
|
||||||
payload["enable_thinking"] = self.enable_thinking
|
payload["enable_thinking"] = False
|
||||||
|
|
||||||
if self.thinking_budget is not None:
|
if self.thinking_budget != 4096:
|
||||||
payload["thinking_budget"] = self.thinking_budget
|
payload["thinking_budget"] = self.thinking_budget
|
||||||
|
|
||||||
if self.max_tokens:
|
if self.max_tokens:
|
||||||
payload["max_tokens"] = self.max_tokens
|
payload["max_tokens"] = self.max_tokens
|
||||||
|
|
||||||
|
# if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
||||||
|
# payload["max_tokens"] = global_config.model.model_max_output_length
|
||||||
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
||||||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import List, Dict
|
|||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
import json
|
import json
|
||||||
|
import random
|
||||||
|
|
||||||
logger = get_logger("relationship_fetcher")
|
logger = get_logger("relationship_fetcher")
|
||||||
|
|
||||||
@@ -101,22 +101,69 @@ class RelationshipFetcher:
|
|||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
short_impression = await person_info_manager.get_value(person_id, "short_impression")
|
short_impression = await person_info_manager.get_value(person_id, "short_impression")
|
||||||
|
|
||||||
|
nickname_str = await person_info_manager.get_value(person_id, "nickname")
|
||||||
|
platform = await person_info_manager.get_value(person_id, "platform")
|
||||||
|
|
||||||
|
if person_name == nickname_str and not short_impression:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
current_points = await person_info_manager.get_value(person_id, "points") or []
|
||||||
|
|
||||||
|
if isinstance(current_points, str):
|
||||||
|
try:
|
||||||
|
current_points = json.loads(current_points)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"解析points JSON失败: {current_points}")
|
||||||
|
current_points = []
|
||||||
|
elif not isinstance(current_points, list):
|
||||||
|
current_points = []
|
||||||
|
|
||||||
|
# 按时间排序forgotten_points
|
||||||
|
current_points.sort(key=lambda x: x[2])
|
||||||
|
# 按权重加权随机抽取3个points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
||||||
|
if len(current_points) > 3:
|
||||||
|
# point[1] 取值范围1-10,直接作为权重
|
||||||
|
weights = [max(1, min(10, int(point[1]))) for point in current_points]
|
||||||
|
points = random.choices(current_points, weights=weights, k=3)
|
||||||
|
else:
|
||||||
|
points = current_points
|
||||||
|
|
||||||
|
# 构建points文本
|
||||||
|
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||||||
|
|
||||||
info_type = await self._build_fetch_query(person_id, target_message, chat_history)
|
info_type = await self._build_fetch_query(person_id, target_message, chat_history)
|
||||||
if info_type:
|
if info_type:
|
||||||
await self._extract_single_info(person_id, info_type, person_name)
|
await self._extract_single_info(person_id, info_type, person_name)
|
||||||
|
|
||||||
relation_info = self._organize_known_info()
|
relation_info = self._organize_known_info()
|
||||||
|
|
||||||
|
nickname_str = ""
|
||||||
|
if person_name != nickname_str:
|
||||||
|
nickname_str = f"(ta在{platform}上的昵称是{nickname_str})"
|
||||||
|
|
||||||
if short_impression and relation_info:
|
if short_impression and relation_info:
|
||||||
relation_info = f"你对{person_name}的印象是:{short_impression}。具体来说:{relation_info}"
|
if points_text:
|
||||||
|
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}。你还记得ta最近做的事:{points_text}"
|
||||||
|
else:
|
||||||
|
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}"
|
||||||
elif short_impression:
|
elif short_impression:
|
||||||
relation_info = f"你对{person_name}的印象是:{short_impression}"
|
if points_text:
|
||||||
|
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}。你还记得ta最近做的事:{points_text}"
|
||||||
|
else:
|
||||||
|
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}"
|
||||||
elif relation_info:
|
elif relation_info:
|
||||||
relation_info = f"你对{person_name}的了解:{relation_info}"
|
if points_text:
|
||||||
|
relation_info = f"你对{person_name}的了解{nickname_str}:{relation_info}。你还记得ta最近做的事:{points_text}"
|
||||||
|
else:
|
||||||
|
relation_info = f"你对{person_name}的了解{nickname_str}:{relation_info}"
|
||||||
|
elif points_text:
|
||||||
|
relation_info = f"你记得{person_name}{nickname_str}最近做的事:{points_text}"
|
||||||
else:
|
else:
|
||||||
relation_info = ""
|
relation_info = ""
|
||||||
|
|
||||||
return relation_info
|
return relation_info
|
||||||
|
|
||||||
|
|
||||||
async def _build_fetch_query(self, person_id, target_message, chat_history):
|
async def _build_fetch_query(self, person_id, target_message, chat_history):
|
||||||
nickname_str = ",".join(global_config.bot.alias_names)
|
nickname_str = ",".join(global_config.bot.alias_names)
|
||||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ class RelationshipManager:
|
|||||||
short_impression = await person_info_manager.get_value(person_id, "short_impression")
|
short_impression = await person_info_manager.get_value(person_id, "short_impression")
|
||||||
|
|
||||||
current_points = await person_info_manager.get_value(person_id, "points") or []
|
current_points = await person_info_manager.get_value(person_id, "points") or []
|
||||||
|
print(f"current_points: {current_points}")
|
||||||
if isinstance(current_points, str):
|
if isinstance(current_points, str):
|
||||||
try:
|
try:
|
||||||
current_points = json.loads(current_points)
|
current_points = json.loads(current_points)
|
||||||
|
|||||||
Reference in New Issue
Block a user