This commit is contained in:
SnowindMe
2025-04-13 01:07:50 +08:00
45 changed files with 667 additions and 552 deletions

View File

@@ -1,9 +1,23 @@
name: Ruff name: Ruff
on: [ push, pull_request ] on: [ push, pull_request ]
permissions:
contents: write
jobs: jobs:
ruff: ruff:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v3 - uses: astral-sh/ruff-action@v3
- run: ruff check --fix
- run: ruff format
- name: Commit changes
if: success()
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git add -A
git diff --quiet && git diff --staged --quiet || git commit -m "🤖 自动格式化代码 [skip ci]"
git push

View File

@@ -283,17 +283,13 @@ WILLING_STYLE_CONFIG = {
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"), "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
}, },
"simple": { "simple": {
"console_format": ( "console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}"), # noqa: E501
"<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}"
), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"), "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
}, },
} }
CONFIRM_STYLE_CONFIG = { CONFIRM_STYLE_CONFIG = {
"console_format": ( "console_format": ("<RED>{message}</RED>"), # noqa: E501
"<RED>{message}</RED>"
), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"), "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"),
} }

View File

@@ -4,17 +4,17 @@ from src.do_tool.tool_can_use.base_tool import (
discover_tools, discover_tools,
get_all_tool_definitions, get_all_tool_definitions,
get_tool_instance, get_tool_instance,
TOOL_REGISTRY TOOL_REGISTRY,
) )
__all__ = [ __all__ = [
'BaseTool', "BaseTool",
'register_tool', "register_tool",
'discover_tools', "discover_tools",
'get_all_tool_definitions', "get_all_tool_definitions",
'get_tool_instance', "get_tool_instance",
'TOOL_REGISTRY' "TOOL_REGISTRY",
] ]
# 自动发现并注册工具 # 自动发现并注册工具
discover_tools() discover_tools()

View File

@@ -10,41 +10,39 @@ logger = get_module_logger("base_tool")
# 工具注册表 # 工具注册表
TOOL_REGISTRY = {} TOOL_REGISTRY = {}
class BaseTool: class BaseTool:
"""所有工具的基类""" """所有工具的基类"""
# 工具名称,子类必须重写 # 工具名称,子类必须重写
name = None name = None
# 工具描述,子类必须重写 # 工具描述,子类必须重写
description = None description = None
# 工具参数定义,子类必须重写 # 工具参数定义,子类必须重写
parameters = None parameters = None
@classmethod @classmethod
def get_tool_definition(cls) -> Dict[str, Any]: def get_tool_definition(cls) -> Dict[str, Any]:
"""获取工具定义用于LLM工具调用 """获取工具定义用于LLM工具调用
Returns: Returns:
Dict: 工具定义字典 Dict: 工具定义字典
""" """
if not cls.name or not cls.description or not cls.parameters: if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return { return {
"type": "function", "type": "function",
"function": { "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
"name": cls.name,
"description": cls.description,
"parameters": cls.parameters
}
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具函数 """执行工具函数
Args: Args:
function_args: 工具调用参数 function_args: 工具调用参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
@@ -53,17 +51,17 @@ class BaseTool:
def register_tool(tool_class: Type[BaseTool]): def register_tool(tool_class: Type[BaseTool]):
"""注册工具到全局注册表 """注册工具到全局注册表
Args: Args:
tool_class: 工具类 tool_class: 工具类
""" """
if not issubclass(tool_class, BaseTool): if not issubclass(tool_class, BaseTool):
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类") raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
tool_name = tool_class.name tool_name = tool_class.name
if not tool_name: if not tool_name:
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性") raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
TOOL_REGISTRY[tool_name] = tool_class TOOL_REGISTRY[tool_name] = tool_class
logger.info(f"已注册工具: {tool_name}") logger.info(f"已注册工具: {tool_name}")
@@ -73,27 +71,27 @@ def discover_tools():
# 获取当前目录路径 # 获取当前目录路径
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
package_name = os.path.basename(current_dir) package_name = os.path.basename(current_dir)
# 遍历包中的所有模块 # 遍历包中的所有模块
for _, module_name, _ in pkgutil.iter_modules([current_dir]): for _, module_name, _ in pkgutil.iter_modules([current_dir]):
# 跳过当前模块和__pycache__ # 跳过当前模块和__pycache__
if module_name == "base_tool" or module_name.startswith("__"): if module_name == "base_tool" or module_name.startswith("__"):
continue continue
# 导入模块 # 导入模块
module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}") module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}")
# 查找模块中的工具类 # 查找模块中的工具类
for _, obj in inspect.getmembers(module): for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
register_tool(obj) register_tool(obj)
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具") logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
def get_all_tool_definitions() -> List[Dict[str, Any]]: def get_all_tool_definitions() -> List[Dict[str, Any]]:
"""获取所有已注册工具的定义 """获取所有已注册工具的定义
Returns: Returns:
List[Dict]: 工具定义列表 List[Dict]: 工具定义列表
""" """
@@ -102,14 +100,14 @@ def get_all_tool_definitions() -> List[Dict[str, Any]]:
def get_tool_instance(tool_name: str) -> Optional[BaseTool]: def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取指定名称的工具实例 """获取指定名称的工具实例
Args: Args:
tool_name: 工具名称 tool_name: 工具名称
Returns: Returns:
Optional[BaseTool]: 工具实例如果找不到则返回None Optional[BaseTool]: 工具实例如果找不到则返回None
""" """
tool_class = TOOL_REGISTRY.get(tool_name) tool_class = TOOL_REGISTRY.get(tool_name)
if not tool_class: if not tool_class:
return None return None
return tool_class() return tool_class()

View File

@@ -4,29 +4,25 @@ from typing import Dict, Any
logger = get_module_logger("fibonacci_sequence_tool") logger = get_module_logger("fibonacci_sequence_tool")
class FibonacciSequenceTool(BaseTool): class FibonacciSequenceTool(BaseTool):
"""生成斐波那契数列的工具""" """生成斐波那契数列的工具"""
name = "fibonacci_sequence" name = "fibonacci_sequence"
description = "生成指定长度的斐波那契数列" description = "生成指定长度的斐波那契数列"
parameters = { parameters = {
"type": "object", "type": "object",
"properties": { "properties": {"n": {"type": "integer", "description": "斐波那契数列的长度", "minimum": 1}},
"n": { "required": ["n"],
"type": "integer",
"description": "斐波那契数列的长度",
"minimum": 1
}
},
"required": ["n"]
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具功能 """执行工具功能
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
@@ -34,23 +30,18 @@ class FibonacciSequenceTool(BaseTool):
n = function_args.get("n") n = function_args.get("n")
if n <= 0: if n <= 0:
raise ValueError("参数n必须大于0") raise ValueError("参数n必须大于0")
sequence = [] sequence = []
a, b = 0, 1 a, b = 0, 1
for _ in range(n): for _ in range(n):
sequence.append(a) sequence.append(a)
a, b = b, a + b a, b = b, a + b
return { return {"name": self.name, "content": sequence}
"name": self.name,
"content": sequence
}
except Exception as e: except Exception as e:
logger.error(f"fibonacci_sequence工具执行失败: {str(e)}") logger.error(f"fibonacci_sequence工具执行失败: {str(e)}")
return { return {"name": self.name, "content": f"执行失败: {str(e)}"}
"name": self.name,
"content": f"执行失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(FibonacciSequenceTool) register_tool(FibonacciSequenceTool)

View File

@@ -4,8 +4,10 @@ from typing import Dict, Any
logger = get_module_logger("generate_buddha_emoji_tool") logger = get_module_logger("generate_buddha_emoji_tool")
class GenerateBuddhaEmojiTool(BaseTool): class GenerateBuddhaEmojiTool(BaseTool):
"""生成佛祖颜文字的工具类""" """生成佛祖颜文字的工具类"""
name = "generate_buddha_emoji" name = "generate_buddha_emoji"
description = "生成一个佛祖的颜文字表情" description = "生成一个佛祖的颜文字表情"
parameters = { parameters = {
@@ -13,32 +15,27 @@ class GenerateBuddhaEmojiTool(BaseTool):
"properties": { "properties": {
# 无参数 # 无参数
}, },
"required": [] "required": [],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具功能,生成佛祖颜文字 """执行工具功能,生成佛祖颜文字
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
try: try:
buddha_emoji = "这是一个佛祖emoji༼ つ ◕_◕ ༽つ" buddha_emoji = "这是一个佛祖emoji༼ つ ◕_◕ ༽つ"
return { return {"name": self.name, "content": buddha_emoji}
"name": self.name,
"content": buddha_emoji
}
except Exception as e: except Exception as e:
logger.error(f"generate_buddha_emoji工具执行失败: {str(e)}") logger.error(f"generate_buddha_emoji工具执行失败: {str(e)}")
return { return {"name": self.name, "content": f"执行失败: {str(e)}"}
"name": self.name,
"content": f"执行失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(GenerateBuddhaEmojiTool) register_tool(GenerateBuddhaEmojiTool)

View File

@@ -4,23 +4,21 @@ from typing import Dict, Any
logger = get_module_logger("generate_cmd_tutorial_tool") logger = get_module_logger("generate_cmd_tutorial_tool")
class GenerateCmdTutorialTool(BaseTool): class GenerateCmdTutorialTool(BaseTool):
"""生成Windows CMD基本操作教程的工具""" """生成Windows CMD基本操作教程的工具"""
name = "generate_cmd_tutorial" name = "generate_cmd_tutorial"
description = "生成关于Windows命令提示符(CMD)的基本操作教程,包括常用命令和使用方法" description = "生成关于Windows命令提示符(CMD)的基本操作教程,包括常用命令和使用方法"
parameters = { parameters = {"type": "object", "properties": {}, "required": []}
"type": "object",
"properties": {},
"required": []
}
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具功能 """执行工具功能
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
@@ -57,17 +55,12 @@ class GenerateCmdTutorialTool(BaseTool):
注意:使用命令时要小心,特别是删除操作。 注意:使用命令时要小心,特别是删除操作。
""" """
return { return {"name": self.name, "content": tutorial_content}
"name": self.name,
"content": tutorial_content
}
except Exception as e: except Exception as e:
logger.error(f"generate_cmd_tutorial工具执行失败: {str(e)}") logger.error(f"generate_cmd_tutorial工具执行失败: {str(e)}")
return { return {"name": self.name, "content": f"执行失败: {str(e)}"}
"name": self.name,
"content": f"执行失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(GenerateCmdTutorialTool) register_tool(GenerateCmdTutorialTool)

View File

@@ -5,32 +5,28 @@ from typing import Dict, Any
logger = get_module_logger("get_current_task_tool") logger = get_module_logger("get_current_task_tool")
class GetCurrentTaskTool(BaseTool): class GetCurrentTaskTool(BaseTool):
"""获取当前正在做的事情/最近的任务工具""" """获取当前正在做的事情/最近的任务工具"""
name = "get_current_task" name = "get_current_task"
description = "获取当前正在做的事情/最近的任务" description = "获取当前正在做的事情/最近的任务"
parameters = { parameters = {
"type": "object", "type": "object",
"properties": { "properties": {
"num": { "num": {"type": "integer", "description": "要获取的任务数量"},
"type": "integer", "time_info": {"type": "boolean", "description": "是否包含时间信息"},
"description": "要获取的任务数量"
},
"time_info": {
"type": "boolean",
"description": "是否包含时间信息"
}
}, },
"required": [] "required": [],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行获取当前任务 """执行获取当前任务
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本,此工具不使用 message_txt: 原始消息文本,此工具不使用
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
@@ -38,26 +34,21 @@ class GetCurrentTaskTool(BaseTool):
# 获取参数,如果没有提供则使用默认值 # 获取参数,如果没有提供则使用默认值
num = function_args.get("num", 1) num = function_args.get("num", 1)
time_info = function_args.get("time_info", False) time_info = function_args.get("time_info", False)
# 调用日程系统获取当前任务 # 调用日程系统获取当前任务
current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info) current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info)
# 格式化返回结果 # 格式化返回结果
if current_task: if current_task:
task_info = current_task task_info = current_task
else: else:
task_info = "当前没有正在进行的任务" task_info = "当前没有正在进行的任务"
return { return {"name": "get_current_task", "content": f"当前任务信息: {task_info}"}
"name": "get_current_task",
"content": f"当前任务信息: {task_info}"
}
except Exception as e: except Exception as e:
logger.error(f"获取当前任务工具执行失败: {str(e)}") logger.error(f"获取当前任务工具执行失败: {str(e)}")
return { return {"name": "get_current_task", "content": f"获取当前任务失败: {str(e)}"}
"name": "get_current_task",
"content": f"获取当前任务失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(GetCurrentTaskTool) register_tool(GetCurrentTaskTool)

View File

@@ -6,39 +6,35 @@ from typing import Dict, Any, Union
logger = get_module_logger("get_knowledge_tool") logger = get_module_logger("get_knowledge_tool")
class SearchKnowledgeTool(BaseTool): class SearchKnowledgeTool(BaseTool):
"""从知识库中搜索相关信息的工具""" """从知识库中搜索相关信息的工具"""
name = "search_knowledge" name = "search_knowledge"
description = "从知识库中搜索相关信息" description = "从知识库中搜索相关信息"
parameters = { parameters = {
"type": "object", "type": "object",
"properties": { "properties": {
"query": { "query": {"type": "string", "description": "搜索查询关键词"},
"type": "string", "threshold": {"type": "number", "description": "相似度阈值0.0到1.0之间"},
"description": "搜索查询关键词"
},
"threshold": {
"type": "number",
"description": "相似度阈值0.0到1.0之间"
}
}, },
"required": ["query"] "required": ["query"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行知识库搜索 """执行知识库搜索
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
try: try:
query = function_args.get("query", message_txt) query = function_args.get("query", message_txt)
threshold = function_args.get("threshold", 0.4) threshold = function_args.get("threshold", 0.4)
# 调用知识库搜索 # 调用知识库搜索
embedding = await get_embedding(query, request_type="info_retrieval") embedding = await get_embedding(query, request_type="info_retrieval")
if embedding: if embedding:
@@ -47,38 +43,29 @@ class SearchKnowledgeTool(BaseTool):
content = f"你知道这些知识: {knowledge_info}" content = f"你知道这些知识: {knowledge_info}"
else: else:
content = f"你不太了解有关{query}的知识" content = f"你不太了解有关{query}的知识"
return { return {"name": "search_knowledge", "content": content}
"name": "search_knowledge", return {"name": "search_knowledge", "content": f"无法获取关于'{query}'的嵌入向量"}
"content": content
}
return {
"name": "search_knowledge",
"content": f"无法获取关于'{query}'的嵌入向量"
}
except Exception as e: except Exception as e:
logger.error(f"知识库搜索工具执行失败: {str(e)}") logger.error(f"知识库搜索工具执行失败: {str(e)}")
return { return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"}
"name": "search_knowledge",
"content": f"知识库搜索失败: {str(e)}"
}
def get_info_from_db( def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]: ) -> Union[str, list]:
"""从数据库中获取相关信息 """从数据库中获取相关信息
Args: Args:
query_embedding: 查询的嵌入向量 query_embedding: 查询的嵌入向量
limit: 最大返回结果数 limit: 最大返回结果数
threshold: 相似度阈值 threshold: 相似度阈值
return_raw: 是否返回原始结果 return_raw: 是否返回原始结果
Returns: Returns:
Union[str, list]: 格式化的信息字符串或原始结果列表 Union[str, list]: 格式化的信息字符串或原始结果列表
""" """
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算 # 使用余弦相似度计算
pipeline = [ pipeline = [
{ {
@@ -143,5 +130,6 @@ class SearchKnowledgeTool(BaseTool):
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results) return "\n".join(str(result["content"]) for result in results)
# 注册工具 # 注册工具
register_tool(SearchKnowledgeTool) register_tool(SearchKnowledgeTool)

View File

@@ -5,68 +5,55 @@ from typing import Dict, Any
logger = get_module_logger("get_memory_tool") logger = get_module_logger("get_memory_tool")
class GetMemoryTool(BaseTool): class GetMemoryTool(BaseTool):
"""从记忆系统中获取相关记忆的工具""" """从记忆系统中获取相关记忆的工具"""
name = "get_memory" name = "get_memory"
description = "从记忆系统中获取相关记忆" description = "从记忆系统中获取相关记忆"
parameters = { parameters = {
"type": "object", "type": "object",
"properties": { "properties": {
"text": { "text": {"type": "string", "description": "要查询的相关文本"},
"type": "string", "max_memory_num": {"type": "integer", "description": "最大返回记忆数量"},
"description": "要查询的相关文本"
},
"max_memory_num": {
"type": "integer",
"description": "最大返回记忆数量"
}
}, },
"required": ["text"] "required": ["text"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行记忆获取 """执行记忆获取
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
try: try:
text = function_args.get("text", message_txt) text = function_args.get("text", message_txt)
max_memory_num = function_args.get("max_memory_num", 2) max_memory_num = function_args.get("max_memory_num", 2)
# 调用记忆系统 # 调用记忆系统
related_memory = await HippocampusManager.get_instance().get_memory_from_text( related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=text, text=text, max_memory_num=max_memory_num, max_memory_length=2, max_depth=3, fast_retrieval=False
max_memory_num=max_memory_num,
max_memory_length=2,
max_depth=3,
fast_retrieval=False
) )
memory_info = "" memory_info = ""
if related_memory: if related_memory:
for memory in related_memory: for memory in related_memory:
memory_info += memory[1] + "\n" memory_info += memory[1] + "\n"
if memory_info: if memory_info:
content = f"你记得这些事情: {memory_info}" content = f"你记得这些事情: {memory_info}"
else: else:
content = f"你不太记得有关{text}的记忆,你对此不太了解" content = f"你不太记得有关{text}的记忆,你对此不太了解"
return { return {"name": "get_memory", "content": content}
"name": "get_memory",
"content": content
}
except Exception as e: except Exception as e:
logger.error(f"记忆获取工具执行失败: {str(e)}") logger.error(f"记忆获取工具执行失败: {str(e)}")
return { return {"name": "get_memory", "content": f"记忆获取失败: {str(e)}"}
"name": "get_memory",
"content": f"记忆获取失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(GetMemoryTool) register_tool(GetMemoryTool)

View File

@@ -16,21 +16,19 @@ class ToolUser:
model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use" model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use"
) )
async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream): async def _build_tool_prompt(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
"""构建工具使用的提示词 """构建工具使用的提示词
Args: Args:
message_txt: 用户消息文本 message_txt: 用户消息文本
sender_name: 发送者名称 sender_name: 发送者名称
chat_stream: 聊天流对象 chat_stream: 聊天流对象
Returns: Returns:
str: 构建好的提示词 str: 构建好的提示词
""" """
new_messages = list( new_messages = list(
db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}) db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}).sort("time", 1).limit(15)
.sort("time", 1)
.limit(15)
) )
new_messages_str = "" new_messages_str = ""
for msg in new_messages: for msg in new_messages:
@@ -44,37 +42,37 @@ class ToolUser:
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += f"注意你就是{bot_name}{bot_name}指的就是你。" prompt += f"注意你就是{bot_name}{bot_name}指的就是你。"
prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。" prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。"
return prompt return prompt
def _define_tools(self): def _define_tools(self):
"""获取所有已注册工具的定义 """获取所有已注册工具的定义
Returns: Returns:
list: 工具定义列表 list: 工具定义列表
""" """
return get_all_tool_definitions() return get_all_tool_definitions()
async def _execute_tool_call(self, tool_call, message_txt:str): async def _execute_tool_call(self, tool_call, message_txt: str):
"""执行特定的工具调用 """执行特定的工具调用
Args: Args:
tool_call: 工具调用对象 tool_call: 工具调用对象
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
dict: 工具调用结果 dict: 工具调用结果
""" """
try: try:
function_name = tool_call["function"]["name"] function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"]) function_args = json.loads(tool_call["function"]["arguments"])
# 获取对应工具实例 # 获取对应工具实例
tool_instance = get_tool_instance(function_name) tool_instance = get_tool_instance(function_name)
if not tool_instance: if not tool_instance:
logger.warning(f"未知工具名称: {function_name}") logger.warning(f"未知工具名称: {function_name}")
return None return None
# 执行工具 # 执行工具
result = await tool_instance.execute(function_args, message_txt) result = await tool_instance.execute(function_args, message_txt)
if result: if result:
@@ -82,62 +80,60 @@ class ToolUser:
"tool_call_id": tool_call["id"], "tool_call_id": tool_call["id"],
"role": "tool", "role": "tool",
"name": function_name, "name": function_name,
"content": result["content"] "content": result["content"],
} }
return None return None
except Exception as e: except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}") logger.error(f"执行工具调用时发生错误: {str(e)}")
return None return None
async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream): async def use_tool(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
"""使用工具辅助思考,判断是否需要额外信息 """使用工具辅助思考,判断是否需要额外信息
Args: Args:
message_txt: 用户消息文本 message_txt: 用户消息文本
sender_name: 发送者名称 sender_name: 发送者名称
chat_stream: 聊天流对象 chat_stream: 聊天流对象
Returns: Returns:
dict: 工具使用结果 dict: 工具使用结果
""" """
try: try:
# 构建提示词 # 构建提示词
prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream) prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream)
# 定义可用工具 # 定义可用工具
tools = self._define_tools() tools = self._define_tools()
# 使用llm_model_tool发送带工具定义的请求 # 使用llm_model_tool发送带工具定义的请求
payload = { payload = {
"model": self.llm_model_tool.model_name, "model": self.llm_model_tool.model_name,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length, "max_tokens": global_config.max_response_length,
"tools": tools, "tools": tools,
"temperature": 0.2 "temperature": 0.2,
} }
logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}") logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}")
# 发送请求获取模型是否需要调用工具 # 发送请求获取模型是否需要调用工具
response = await self.llm_model_tool._execute_request( response = await self.llm_model_tool._execute_request(
endpoint="/chat/completions", endpoint="/chat/completions", payload=payload, prompt=prompt
payload=payload,
prompt=prompt
) )
# 根据返回值数量判断是否有工具调用 # 根据返回值数量判断是否有工具调用
if len(response) == 3: if len(response) == 3:
content, reasoning_content, tool_calls = response content, reasoning_content, tool_calls = response
logger.info(f"工具思考: {tool_calls}") logger.info(f"工具思考: {tool_calls}")
# 检查响应中工具调用是否有效 # 检查响应中工具调用是否有效
if not tool_calls: if not tool_calls:
logger.info("模型返回了空的tool_calls列表") logger.info("模型返回了空的tool_calls列表")
return {"used_tools": False} return {"used_tools": False}
logger.info(f"模型请求调用{len(tool_calls)}个工具") logger.info(f"模型请求调用{len(tool_calls)}个工具")
tool_results = [] tool_results = []
collected_info = "" collected_info = ""
# 执行所有工具调用 # 执行所有工具调用
for tool_call in tool_calls: for tool_call in tool_calls:
result = await self._execute_tool_call(tool_call, message_txt) result = await self._execute_tool_call(tool_call, message_txt)
@@ -145,7 +141,7 @@ class ToolUser:
tool_results.append(result) tool_results.append(result)
# 将工具结果添加到收集的信息中 # 将工具结果添加到收集的信息中
collected_info += f"\n{result['name']}返回结果: {result['content']}\n" collected_info += f"\n{result['name']}返回结果: {result['content']}\n"
# 如果有工具结果,直接返回收集的信息 # 如果有工具结果,直接返回收集的信息
if collected_info: if collected_info:
logger.info(f"工具调用收集到信息: {collected_info}") logger.info(f"工具调用收集到信息: {collected_info}")
@@ -157,15 +153,15 @@ class ToolUser:
# 没有工具调用 # 没有工具调用
content, reasoning_content = response content, reasoning_content = response
logger.info("模型没有请求调用任何工具") logger.info("模型没有请求调用任何工具")
# 如果没有工具调用或处理失败,直接返回原始思考 # 如果没有工具调用或处理失败,直接返回原始思考
return { return {
"used_tools": False, "used_tools": False,
} }
except Exception as e: except Exception as e:
logger.error(f"工具调用过程中出错: {str(e)}") logger.error(f"工具调用过程中出错: {str(e)}")
return { return {
"used_tools": False, "used_tools": False,
"error": str(e), "error": str(e),
} }

View File

@@ -43,12 +43,11 @@ def init_prompt():
class CurrentState: class CurrentState:
def __init__(self): def __init__(self):
self.current_state_info = "" self.current_state_info = ""
self.mood_manager = MoodManager() self.mood_manager = MoodManager()
self.mood = self.mood_manager.get_prompt() self.mood = self.mood_manager.get_prompt()
self.attendance_factor = 0 self.attendance_factor = 0
self.engagement_factor = 0 self.engagement_factor = 0
@@ -66,9 +65,6 @@ class Heartflow:
) )
self._subheartflows: Dict[Any, SubHeartflow] = {} self._subheartflows: Dict[Any, SubHeartflow] = {}
async def _cleanup_inactive_subheartflows(self): async def _cleanup_inactive_subheartflows(self):
"""定期清理不活跃的子心流""" """定期清理不活跃的子心流"""
@@ -90,7 +86,7 @@ class Heartflow:
logger.info(f"已清理不活跃的子心流: {subheartflow_id}") logger.info(f"已清理不活跃的子心流: {subheartflow_id}")
await asyncio.sleep(30) # 每分钟检查一次 await asyncio.sleep(30) # 每分钟检查一次
async def _sub_heartflow_update(self): async def _sub_heartflow_update(self):
while True: while True:
# 检查是否存在子心流 # 检查是否存在子心流
@@ -103,13 +99,12 @@ class Heartflow:
await asyncio.sleep(global_config.heart_flow_update_interval) # 5分钟思考一次 await asyncio.sleep(global_config.heart_flow_update_interval) # 5分钟思考一次
async def heartflow_start_working(self): async def heartflow_start_working(self):
# 启动清理任务 # 启动清理任务
asyncio.create_task(self._cleanup_inactive_subheartflows()) asyncio.create_task(self._cleanup_inactive_subheartflows())
# 启动子心流更新任务 # 启动子心流更新任务
asyncio.create_task(self._sub_heartflow_update()) asyncio.create_task(self._sub_heartflow_update())
async def _update_current_state(self): async def _update_current_state(self):
print("TODO") print("TODO")
@@ -155,7 +150,7 @@ class Heartflow:
# prompt += f"你现在{mood_info}。" # prompt += f"你现在{mood_info}。"
# prompt += "现在你接下去继续思考,产生新的想法,但是要基于原有的主要想法,不要分点输出," # prompt += "现在你接下去继续思考,产生新的想法,但是要基于原有的主要想法,不要分点输出,"
# prompt += "输出连贯的内心独白,不要太长,但是记得结合上述的消息,关注新内容:" # prompt += "输出连贯的内心独白,不要太长,但是记得结合上述的消息,关注新内容:"
prompt = global_prompt_manager.get_prompt("thinking_prompt").format( prompt = (await global_prompt_manager.get_prompt_async("thinking_prompt")).format(
schedule_info, personality_info, related_memory_info, current_thinking_info, sub_flows_info, mood_info schedule_info, personality_info, related_memory_info, current_thinking_info, sub_flows_info, mood_info
) )
@@ -212,7 +207,7 @@ class Heartflow:
# prompt += f"你现在{mood_info}\n" # prompt += f"你现在{mood_info}\n"
# prompt += """现在请你总结这些聊天内容,注意关注聊天内容对原有的想法的影响,输出连贯的内心独白 # prompt += """现在请你总结这些聊天内容,注意关注聊天内容对原有的想法的影响,输出连贯的内心独白
# 不要太长,但是记得结合上述的消息,要记得你的人设,关注新内容:""" # 不要太长,但是记得结合上述的消息,要记得你的人设,关注新内容:"""
prompt = global_prompt_manager.get_prompt("mind_summary_prompt").format( prompt = (await global_prompt_manager.get_prompt_async("mind_summary_prompt")).format(
personality_info, global_config.BOT_NICKNAME, self.current_mind, minds_str, mood_info personality_info, global_config.BOT_NICKNAME, self.current_mind, minds_str, mood_info
) )

View File

@@ -150,7 +150,7 @@ class ChattingObservation(Observation):
except Exception as e: except Exception as e:
print(f"获取总结失败: {e}") print(f"获取总结失败: {e}")
updated_observe_info = "" updated_observe_info = ""
return updated_observe_info return updated_observe_info
# print(f"prompt{prompt}") # print(f"prompt{prompt}")
# print(f"self.observe_info{self.observe_info}") # print(f"self.observe_info{self.observe_info}")

View File

@@ -5,9 +5,11 @@ from src.plugins.models.utils_model import LLM_request
from src.plugins.config.config import global_config from src.plugins.config.config import global_config
import re import re
import time import time
# from src.plugins.schedule.schedule_generator import bot_schedule # from src.plugins.schedule.schedule_generator import bot_schedule
# from src.plugins.memory_system.Hippocampus import HippocampusManager # from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402 from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402
# from src.plugins.chat.utils import get_embedding # from src.plugins.chat.utils import get_embedding
# from src.common.database import db # from src.common.database import db
# from typing import Union # from typing import Union
@@ -16,7 +18,8 @@ import random
from src.plugins.chat.chat_stream import ChatStream from src.plugins.chat.chat_stream import ChatStream
from src.plugins.person_info.relationship_manager import relationship_manager from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.chat.utils import get_recent_group_speaker from src.plugins.chat.utils import get_recent_group_speaker
from src.do_tool.tool_use import ToolUser from src.do_tool.tool_use import ToolUser
from ..plugins.utils.prompt_builder import Prompt, global_prompt_manager
subheartflow_config = LogConfig( subheartflow_config = LogConfig(
# 使用海马体专用样式 # 使用海马体专用样式
@@ -26,6 +29,35 @@ subheartflow_config = LogConfig(
logger = get_module_logger("subheartflow", config=subheartflow_config) logger = get_module_logger("subheartflow", config=subheartflow_config)
def init_prompt():
prompt = ""
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
prompt += "{collected_info}\n"
prompt += "{relation_prompt_all}\n"
prompt += "{prompt_personality}\n"
prompt += "刚刚你的想法是{current_thinking_info}。如果有新的内容,记得转换话题\n"
prompt += "-----------------------------------\n"
prompt += "现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
prompt += "你现在{mood_info}\n"
prompt += "你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白"
prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n"
prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写"
prompt += "记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{bot_name}{bot_name}指的就是你。"
Prompt(prompt, "sub_heartflow_prompt_before")
prompt = ""
# prompt += f"你现在正在做的事情是:{schedule_info}\n"
prompt += "{prompt_personality}\n"
prompt += "现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
prompt += "刚刚你的想法是{current_thinking_info}"
prompt += "你现在看到了网友们发的新消息:{message_new_info}\n"
prompt += "你刚刚回复了群友们:{reply_info}"
prompt += "你现在{mood_info}"
prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
Prompt(prompt, "sub_heartflow_prompt_after")
class CurrentState: class CurrentState:
def __init__(self): def __init__(self):
self.willing = 0 self.willing = 0
@@ -48,7 +80,6 @@ class SubHeartflow:
self.llm_model = LLM_request( self.llm_model = LLM_request(
model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow" model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow"
) )
self.main_heartflow_info = "" self.main_heartflow_info = ""
@@ -63,9 +94,9 @@ class SubHeartflow:
self.observations: list[Observation] = [] self.observations: list[Observation] = []
self.running_knowledges = [] self.running_knowledges = []
self.bot_name = global_config.BOT_NICKNAME self.bot_name = global_config.BOT_NICKNAME
self.tool_user = ToolUser() self.tool_user = ToolUser()
def add_observation(self, observation: Observation): def add_observation(self, observation: Observation):
@@ -115,12 +146,12 @@ class SubHeartflow:
): # 5分钟无回复/不在场,销毁 ): # 5分钟无回复/不在场,销毁
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...") logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...")
break # 退出循环,销毁自己 break # 退出循环,销毁自己
async def do_observe(self): async def do_observe(self):
observation = self.observations[0] observation = self.observations[0]
await observation.observe() await observation.observe()
async def do_thinking_before_reply(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
async def do_thinking_before_reply(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
current_thinking_info = self.current_mind current_thinking_info = self.current_mind
mood_info = self.current_state.mood mood_info = self.current_state.mood
# mood_info = "你很生气,很愤怒" # mood_info = "你很生气,很愤怒"
@@ -130,12 +161,12 @@ class SubHeartflow:
# 首先尝试使用工具获取更多信息 # 首先尝试使用工具获取更多信息
tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream) tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream)
# 如果工具被使用且获得了结果,将收集到的信息合并到思考中 # 如果工具被使用且获得了结果,将收集到的信息合并到思考中
collected_info = "" collected_info = ""
if tool_result.get("used_tools", False): if tool_result.get("used_tools", False):
logger.info("使用工具收集了信息") logger.info("使用工具收集了信息")
# 如果有收集到的信息,将其添加到当前思考中 # 如果有收集到的信息,将其添加到当前思考中
if "collected_info" in tool_result: if "collected_info" in tool_result:
collected_info = tool_result["collected_info"] collected_info = tool_result["collected_info"]
@@ -155,7 +186,7 @@ class SubHeartflow:
identity_detail = individuality.identity.identity_detail identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}" prompt_personality += f",{identity_detail[0]}"
# 关系 # 关系
who_chat_in_group = [ who_chat_in_group = [
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname) (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
@@ -170,26 +201,41 @@ class SubHeartflow:
for person in who_chat_in_group: for person in who_chat_in_group:
relation_prompt += await relationship_manager.build_relationship_info(person) relation_prompt += await relationship_manager.build_relationship_info(person)
relation_prompt_all = ( # relation_prompt_all = (
f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录," # f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录,"
f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。" # f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。"
# )
relation_prompt_all = (await global_prompt_manager.get_prompt_async("relationship_prompt")).format(
relation_prompt, sender_name
) )
prompt = "" # prompt = ""
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" # # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
if tool_result.get("used_tools", False): # if tool_result.get("used_tools", False):
prompt += f"{collected_info}\n" # prompt += f"{collected_info}\n"
prompt += f"{relation_prompt_all}\n" # prompt += f"{relation_prompt_all}\n"
prompt += f"{prompt_personality}\n" # prompt += f"{prompt_personality}\n"
prompt += f"刚刚你的想法是{current_thinking_info}。如果有新的内容,记得转换话题\n" # prompt += f"刚刚你的想法是{current_thinking_info}。如果有新的内容,记得转换话题\n"
prompt += "-----------------------------------\n" # prompt += "-----------------------------------\n"
prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n" # prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
prompt += f"你现在{mood_info}\n" # prompt += f"你现在{mood_info}\n"
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" # prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白" # prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白"
prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n" # prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n"
prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写" # prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写"
prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name}{self.bot_name}指的就是你。" # prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name}{self.bot_name}指的就是你。"
prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_before")).format(
collected_info,
relation_prompt_all,
prompt_personality,
current_thinking_info,
chat_observe_info,
mood_info,
sender_name,
message_txt,
self.bot_name,
)
try: try:
response, reasoning_content = await self.llm_model.generate_response_async(prompt) response, reasoning_content = await self.llm_model.generate_response_async(prompt)
@@ -233,16 +279,20 @@ class SubHeartflow:
reply_info = reply_content reply_info = reply_content
# schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False) # schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False)
prompt = "" # prompt = ""
# prompt += f"你现在正在做的事情是:{schedule_info}\n" # # prompt += f"你现在正在做的事情是:{schedule_info}\n"
prompt += f"{prompt_personality}\n" # prompt += f"{prompt_personality}\n"
prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n" # prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
prompt += f"刚刚你的想法是{current_thinking_info}" # prompt += f"刚刚你的想法是{current_thinking_info}。"
prompt += f"你现在看到了网友们发的新消息:{message_new_info}\n" # prompt += f"你现在看到了网友们发的新消息:{message_new_info}\n"
prompt += f"你刚刚回复了群友们:{reply_info}" # prompt += f"你刚刚回复了群友们:{reply_info}"
prompt += f"你现在{mood_info}" # prompt += f"你现在{mood_info}"
prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白" # prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:" # prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_after")).format(
prompt_personality, chat_observe_info, current_thinking_info, message_new_info, reply_info, mood_info
)
try: try:
response, reasoning_content = await self.llm_model.generate_response_async(prompt) response, reasoning_content = await self.llm_model.generate_response_async(prompt)
except Exception as e: except Exception as e:
@@ -302,4 +352,5 @@ class SubHeartflow:
self.current_mind = response self.current_mind = response
init_prompt()
# subheartflow = SubHeartflow() # subheartflow = SubHeartflow()

View File

@@ -53,13 +53,13 @@ class ActionPlanner:
goal = goal_reason[0] goal = goal_reason[0]
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因" reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
elif isinstance(goal_reason, dict): elif isinstance(goal_reason, dict):
goal = goal_reason.get('goal') goal = goal_reason.get("goal")
reasoning = goal_reason.get('reasoning', "没有明确原因") reasoning = goal_reason.get("reasoning", "没有明确原因")
else: else:
# 如果是其他类型,尝试转为字符串 # 如果是其他类型,尝试转为字符串
goal = str(goal_reason) goal = str(goal_reason)
reasoning = "没有明确原因" reasoning = "没有明确原因"
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
goals_str += goal_str goals_str += goal_str
else: else:
@@ -68,7 +68,11 @@ class ActionPlanner:
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
# 获取聊天历史记录 # 获取聊天历史记录
chat_history_list = observation_info.chat_history[-20:] if len(observation_info.chat_history) >= 20 else observation_info.chat_history chat_history_list = (
observation_info.chat_history[-20:]
if len(observation_info.chat_history) >= 20
else observation_info.chat_history
)
chat_history_text = "" chat_history_text = ""
for msg in chat_history_list: for msg in chat_history_list:
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n" chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
@@ -85,15 +89,21 @@ class ActionPlanner:
personality_text = f"你的名字是{self.name}{self.personality_info}" personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本 # 构建action历史文本
action_history_list = conversation_info.done_action[-10:] if len(conversation_info.done_action) >= 10 else conversation_info.done_action action_history_list = (
conversation_info.done_action[-10:]
if len(conversation_info.done_action) >= 10
else conversation_info.done_action
)
action_history_text = "你之前做的事情是:" action_history_text = "你之前做的事情是:"
for action in action_history_list: for action in action_history_list:
if isinstance(action, dict): if isinstance(action, dict):
action_type = action.get('action') action_type = action.get("action")
action_reason = action.get('reason') action_reason = action.get("reason")
action_status = action.get('status') action_status = action.get("status")
if action_status == "recall": if action_status == "recall":
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" action_history_text += (
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
)
elif action_status == "done": elif action_status == "done":
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
elif isinstance(action, tuple): elif isinstance(action, tuple):
@@ -102,7 +112,9 @@ class ActionPlanner:
action_reason = action[1] if len(action) > 1 else "未知原因" action_reason = action[1] if len(action) > 1 else "未知原因"
action_status = action[2] if len(action) > 2 else "done" action_status = action[2] if len(action) > 2 else "done"
if action_status == "recall": if action_status == "recall":
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" action_history_text += (
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
)
elif action_status == "done": elif action_status == "done":
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
@@ -147,7 +159,14 @@ end_conversation: 结束对话,长时间没回复或者当你觉得谈话暂
reason = result["reason"] reason = result["reason"]
# 验证action类型 # 验证action类型
if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal", "end_conversation"]: if action not in [
"direct_reply",
"fetch_knowledge",
"wait",
"listening",
"rethink_goal",
"end_conversation",
]:
logger.warning(f"未知的行动类型: {action}默认使用listening") logger.warning(f"未知的行动类型: {action}默认使用listening")
action = "listening" action = "listening"

View File

@@ -1,12 +1,12 @@
import time import time
import asyncio import asyncio
import traceback import traceback
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..message.message_base import UserInfo from ..message.message_base import UserInfo
from ..config.config import global_config from ..config.config import global_config
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
from .message_storage import MongoDBMessageStorage from .message_storage import MongoDBMessageStorage
logger = get_module_logger("chat_observer") logger = get_module_logger("chat_observer")
@@ -51,7 +51,6 @@ class ChatObserver:
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
# 运行状态 # 运行状态
self._running: bool = False self._running: bool = False
self._task: Optional[asyncio.Task] = None self._task: Optional[asyncio.Task] = None
@@ -94,10 +93,11 @@ class ChatObserver:
message: 消息数据 message: 消息数据
""" """
try: try:
# 发送新消息通知 # 发送新消息通知
# logger.info(f"发送新ccchandleer消息通知: {message}") # logger.info(f"发送新ccchandleer消息通知: {message}")
notification = create_new_message_notification(sender="chat_observer", target="observation_info", message=message) notification = create_new_message_notification(
sender="chat_observer", target="observation_info", message=message
)
# logger.info(f"发送新消ddddd息通知: {notification}") # logger.info(f"发送新消ddddd息通知: {notification}")
# print(self.notification_manager) # print(self.notification_manager)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
@@ -131,7 +131,6 @@ class ChatObserver:
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold) notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
def new_message_after(self, time_point: float) -> bool: def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息 """判断是否在指定时间点后有新消息
@@ -197,7 +196,7 @@ class ChatObserver:
if new_messages: if new_messages:
self.last_message_read = new_messages[-1] self.last_message_read = new_messages[-1]
self.last_message_time = new_messages[-1]["time"] self.last_message_time = new_messages[-1]["time"]
# print(f"获取数据库中找到的新消息: {new_messages}") # print(f"获取数据库中找到的新消息: {new_messages}")
return new_messages return new_messages
@@ -215,7 +214,7 @@ class ChatObserver:
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
logger.debug(f"获取指定时间点111之前的消息: {new_messages}") logger.debug(f"获取指定时间点111之前的消息: {new_messages}")
return new_messages return new_messages
@@ -239,7 +238,7 @@ class ChatObserver:
try: try:
# print("等待事件") # print("等待事件")
await asyncio.wait_for(self._update_event.wait(), timeout=1) await asyncio.wait_for(self._update_event.wait(), timeout=1)
except asyncio.TimeoutError: except asyncio.TimeoutError:
# print("超时") # print("超时")
pass # 超时后也执行一次检查 pass # 超时后也执行一次检查
@@ -347,7 +346,6 @@ class ChatObserver:
return time_info return time_info
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]: def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
"""获取缓存的消息历史 """获取缓存的消息历史
@@ -368,6 +366,6 @@ class ChatObserver:
if not self.message_cache: if not self.message_cache:
return None return None
return self.message_cache[0] return self.message_cache[0]
def __str__(self): def __str__(self):
return f"ChatObserver for {self.stream_id}" return f"ChatObserver for {self.stream_id}"

View File

@@ -140,7 +140,6 @@ class NotificationManager:
self._active_states.add(notification.type) self._active_states.add(notification.type)
else: else:
self._active_states.discard(notification.type) self._active_states.discard(notification.type)
# 调用目标接收者的处理器 # 调用目标接收者的处理器
target = notification.target target = notification.target
@@ -181,7 +180,7 @@ class NotificationManager:
history = history[-limit:] history = history[-limit:]
return history return history
def __str__(self): def __str__(self):
str = "" str = ""
for target, handlers in self._handlers.items(): for target, handlers in self._handlers.items():
@@ -295,5 +294,3 @@ class ChatStateManager:
current_time = datetime.now().timestamp() current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) <= threshold return (current_time - self.state_info.last_message_time) <= threshold

View File

@@ -65,7 +65,6 @@ class Conversation:
self.observation_info.bind_to_chat_observer(self.chat_observer) self.observation_info.bind_to_chat_observer(self.chat_observer)
# print(self.chat_observer.get_cached_messages(limit=) # print(self.chat_observer.get_cached_messages(limit=)
self.conversation_info = ConversationInfo() self.conversation_info = ConversationInfo()
except Exception as e: except Exception as e:
logger.error(f"初始化对话实例:注册信息组件失败: {e}") logger.error(f"初始化对话实例:注册信息组件失败: {e}")
@@ -96,7 +95,7 @@ class Conversation:
# 执行行动 # 执行行动
await self._handle_action(action, reason, self.observation_info, self.conversation_info) await self._handle_action(action, reason, self.observation_info, self.conversation_info)
for goal in self.conversation_info.goal_list: for goal in self.conversation_info.goal_list:
# 检查goal是否为元组类型如果是元组则使用索引访问如果是字典则使用get方法 # 检查goal是否为元组类型如果是元组则使用索引访问如果是字典则使用get方法
if isinstance(goal, tuple): if isinstance(goal, tuple):
@@ -151,7 +150,7 @@ class Conversation:
if action == "direct_reply": if action == "direct_reply":
self.waiter.wait_accumulated_time = 0 self.waiter.wait_accumulated_time = 0
self.state = ConversationState.GENERATING self.state = ConversationState.GENERATING
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info) self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
print(f"生成回复: {self.generated_reply}") print(f"生成回复: {self.generated_reply}")
@@ -174,7 +173,6 @@ class Conversation:
await self._send_reply() await self._send_reply()
conversation_info.done_action[-1].update( conversation_info.done_action[-1].update(
{ {
"status": "done", "status": "done",
@@ -184,7 +182,7 @@ class Conversation:
elif action == "fetch_knowledge": elif action == "fetch_knowledge":
self.waiter.wait_accumulated_time = 0 self.waiter.wait_accumulated_time = 0
self.state = ConversationState.FETCHING self.state = ConversationState.FETCHING
knowledge = "TODO:知识" knowledge = "TODO:知识"
topic = "TODO:关键词" topic = "TODO:关键词"
@@ -199,7 +197,7 @@ class Conversation:
elif action == "rethink_goal": elif action == "rethink_goal":
self.waiter.wait_accumulated_time = 0 self.waiter.wait_accumulated_time = 0
self.state = ConversationState.RETHINKING self.state = ConversationState.RETHINKING
await self.goal_analyzer.analyze_goal(conversation_info, observation_info) await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
@@ -208,7 +206,6 @@ class Conversation:
logger.info("倾听对方发言...") logger.info("倾听对方发言...")
await self.waiter.wait_listening(conversation_info) await self.waiter.wait_listening(conversation_info)
elif action == "end_conversation": elif action == "end_conversation":
self.should_continue = False self.should_continue = False
logger.info("决定结束对话...") logger.info("决定结束对话...")
@@ -239,9 +236,7 @@ class Conversation:
return return
try: try:
await self.direct_sender.send_message( await self.direct_sender.send_message(chat_stream=self.chat_stream, content=self.generated_reply)
chat_stream=self.chat_stream, content=self.generated_reply
)
self.chat_observer.trigger_update() # 触发立即更新 self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update(): if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时") logger.warning("等待消息更新超时")

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import List, Dict, Any from typing import List, Dict, Any
from src.common.database import db from src.common.database import db
class MessageStorage(ABC): class MessageStorage(ABC):
"""消息存储接口""" """消息存储接口"""

View File

@@ -26,24 +26,24 @@ class ObservationInfoHandler(NotificationHandler):
# 获取通知类型和数据 # 获取通知类型和数据
notification_type = notification.type notification_type = notification.type
data = notification.data data = notification.data
if notification_type == NotificationType.NEW_MESSAGE: if notification_type == NotificationType.NEW_MESSAGE:
# 处理新消息通知 # 处理新消息通知
logger.debug(f"收到新消息通知data: {data}") logger.debug(f"收到新消息通知data: {data}")
message_id = data.get("message_id") message_id = data.get("message_id")
processed_plain_text = data.get("processed_plain_text") processed_plain_text = data.get("processed_plain_text")
detailed_plain_text = data.get("detailed_plain_text") detailed_plain_text = data.get("detailed_plain_text")
user_info = data.get("user_info") user_info = data.get("user_info")
time_value = data.get("time") time_value = data.get("time")
message = { message = {
"message_id": message_id, "message_id": message_id,
"processed_plain_text": processed_plain_text, "processed_plain_text": processed_plain_text,
"detailed_plain_text": detailed_plain_text, "detailed_plain_text": detailed_plain_text,
"user_info": user_info, "user_info": user_info,
"time": time_value "time": time_value,
} }
self.observation_info.update_from_message(message) self.observation_info.update_from_message(message)
elif notification_type == NotificationType.COLD_CHAT: elif notification_type == NotificationType.COLD_CHAT:
@@ -161,7 +161,7 @@ class ObservationInfo:
# logger.debug(f"更新信息from_message: {message}") # logger.debug(f"更新信息from_message: {message}")
self.last_message_time = message["time"] self.last_message_time = message["time"]
self.last_message_id = message["message_id"] self.last_message_id = message["message_id"]
self.last_message_content = message.get("processed_plain_text", "") self.last_message_content = message.get("processed_plain_text", "")
user_info = UserInfo.from_dict(message.get("user_info", {})) user_info = UserInfo.from_dict(message.get("user_info", {}))
@@ -233,4 +233,3 @@ class ObservationInfo:
self.unprocessed_messages.clear() self.unprocessed_messages.clear()
self.chat_history_count = len(self.chat_history) self.chat_history_count = len(self.chat_history)
self.new_messages_count = 0 self.new_messages_count = 0

View File

@@ -1,6 +1,7 @@
# Programmable Friendly Conversationalist # Programmable Friendly Conversationalist
# Prefrontal cortex # Prefrontal cortex
import datetime import datetime
# import asyncio # import asyncio
from typing import List, Optional, Tuple, TYPE_CHECKING from typing import List, Optional, Tuple, TYPE_CHECKING
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
@@ -63,13 +64,13 @@ class GoalAnalyzer:
goal = goal_reason[0] goal = goal_reason[0]
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因" reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
elif isinstance(goal_reason, dict): elif isinstance(goal_reason, dict):
goal = goal_reason.get('goal') goal = goal_reason.get("goal")
reasoning = goal_reason.get('reasoning', "没有明确原因") reasoning = goal_reason.get("reasoning", "没有明确原因")
else: else:
# 如果是其他类型,尝试转为字符串 # 如果是其他类型,尝试转为字符串
goal = str(goal_reason) goal = str(goal_reason)
reasoning = "没有明确原因" reasoning = "没有明确原因"
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
goals_str += goal_str goals_str += goal_str
else: else:
@@ -140,14 +141,12 @@ class GoalAnalyzer:
except Exception as e: except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)}") logger.error(f"分析对话目标时出错: {str(e)}")
content = "" content = ""
# 使用改进后的get_items_from_json函数处理JSON数组 # 使用改进后的get_items_from_json函数处理JSON数组
success, result = get_items_from_json( success, result = get_items_from_json(
content, "goal", "reasoning", content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}, allow_array=True
required_types={"goal": str, "reasoning": str},
allow_array=True
) )
if success: if success:
# 判断结果是单个字典还是字典列表 # 判断结果是单个字典还是字典列表
if isinstance(result, list): if isinstance(result, list):
@@ -157,7 +156,7 @@ class GoalAnalyzer:
goal = item.get("goal", "") goal = item.get("goal", "")
reasoning = item.get("reasoning", "") reasoning = item.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning)) conversation_info.goal_list.append((goal, reasoning))
# 返回第一个目标作为当前主要目标(如果有) # 返回第一个目标作为当前主要目标(如果有)
if result: if result:
first_goal = result[0] first_goal = result[0]
@@ -168,7 +167,7 @@ class GoalAnalyzer:
reasoning = result.get("reasoning", "") reasoning = result.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning)) conversation_info.goal_list.append((goal, reasoning))
return (goal, "", reasoning) return (goal, "", reasoning)
# 如果解析失败,返回默认值 # 如果解析失败,返回默认值
return ("", "", "") return ("", "", "")
@@ -293,7 +292,6 @@ class GoalAnalyzer:
return False, False, f"分析出错: {str(e)}" return False, False, f"分析出错: {str(e)}"
class DirectMessageSender: class DirectMessageSender:
"""直接发送消息到平台的发送器""" """直接发送消息到平台的发送器"""

View File

@@ -27,7 +27,7 @@ def get_items_from_json(
""" """
content = content.strip() content = content.strip()
result = {} result = {}
# 设置默认值 # 设置默认值
if default_values: if default_values:
result.update(default_values) result.update(default_values)
@@ -41,7 +41,7 @@ def get_items_from_json(
if array_match: if array_match:
array_content = array_match.group() array_content = array_match.group()
json_array = json.loads(array_content) json_array = json.loads(array_content)
# 确认是数组类型 # 确认是数组类型
if isinstance(json_array, list): if isinstance(json_array, list):
# 验证数组中的每个项目是否包含所有必需字段 # 验证数组中的每个项目是否包含所有必需字段
@@ -49,7 +49,7 @@ def get_items_from_json(
for item in json_array: for item in json_array:
if not isinstance(item, dict): if not isinstance(item, dict):
continue continue
# 检查是否有所有必需字段 # 检查是否有所有必需字段
if all(field in item for field in items): if all(field in item for field in items):
# 验证字段类型 # 验证字段类型
@@ -59,22 +59,22 @@ def get_items_from_json(
if field in item and not isinstance(item[field], expected_type): if field in item and not isinstance(item[field], expected_type):
type_valid = False type_valid = False
break break
if not type_valid: if not type_valid:
continue continue
# 验证字符串字段不为空 # 验证字符串字段不为空
string_valid = True string_valid = True
for field in items: for field in items:
if isinstance(item[field], str) and not item[field].strip(): if isinstance(item[field], str) and not item[field].strip():
string_valid = False string_valid = False
break break
if not string_valid: if not string_valid:
continue continue
valid_items.append(item) valid_items.append(item)
if valid_items: if valid_items:
return True, valid_items return True, valid_items
except json.JSONDecodeError: except json.JSONDecodeError:

View File

@@ -49,22 +49,26 @@ class ReplyGenerator:
goal = goal_reason[0] goal = goal_reason[0]
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因" reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
elif isinstance(goal_reason, dict): elif isinstance(goal_reason, dict):
goal = goal_reason.get('goal') goal = goal_reason.get("goal")
reasoning = goal_reason.get('reasoning', "没有明确原因") reasoning = goal_reason.get("reasoning", "没有明确原因")
else: else:
# 如果是其他类型,尝试转为字符串 # 如果是其他类型,尝试转为字符串
goal = str(goal_reason) goal = str(goal_reason)
reasoning = "没有明确原因" reasoning = "没有明确原因"
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
goals_str += goal_str goals_str += goal_str
else: else:
goal = "目前没有明确对话目标" goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标" reasoning = "目前没有明确对话目标,最好思考一个对话目标"
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
# 获取聊天历史记录 # 获取聊天历史记录
chat_history_list = observation_info.chat_history[-20:] if len(observation_info.chat_history) >= 20 else observation_info.chat_history chat_history_list = (
observation_info.chat_history[-20:]
if len(observation_info.chat_history) >= 20
else observation_info.chat_history
)
chat_history_text = "" chat_history_text = ""
for msg in chat_history_list: for msg in chat_history_list:
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n" chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
@@ -81,15 +85,21 @@ class ReplyGenerator:
personality_text = f"你的名字是{self.name}{self.personality_info}" personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本 # 构建action历史文本
action_history_list = conversation_info.done_action[-10:] if len(conversation_info.done_action) >= 10 else conversation_info.done_action action_history_list = (
conversation_info.done_action[-10:]
if len(conversation_info.done_action) >= 10
else conversation_info.done_action
)
action_history_text = "你之前做的事情是:" action_history_text = "你之前做的事情是:"
for action in action_history_list: for action in action_history_list:
if isinstance(action, dict): if isinstance(action, dict):
action_type = action.get('action') action_type = action.get("action")
action_reason = action.get('reason') action_reason = action.get("reason")
action_status = action.get('status') action_status = action.get("status")
if action_status == "recall": if action_status == "recall":
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" action_history_text += (
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
)
elif action_status == "done": elif action_status == "done":
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
elif isinstance(action, tuple): elif isinstance(action, tuple):
@@ -98,7 +108,9 @@ class ReplyGenerator:
action_reason = action[1] if len(action) > 1 else "未知原因" action_reason = action[1] if len(action) > 1 else "未知原因"
action_status = action[2] if len(action) > 2 else "done" action_status = action[2] if len(action) > 2 else "done"
if action_status == "recall": if action_status == "recall":
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" action_history_text += (
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
)
elif action_status == "done": elif action_status == "done":
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"

View File

@@ -16,7 +16,7 @@ class Waiter:
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME self.name = global_config.BOT_NICKNAME
self.wait_accumulated_time = 0 self.wait_accumulated_time = 0
async def wait(self, conversation_info: ConversationInfo) -> bool: async def wait(self, conversation_info: ConversationInfo) -> bool:
@@ -38,20 +38,20 @@ class Waiter:
# 检查是否超时 # 检查是否超时
if time.time() - wait_start_time > 300: if time.time() - wait_start_time > 300:
self.wait_accumulated_time += 300 self.wait_accumulated_time += 300
logger.info("等待超过300秒结束对话") logger.info("等待超过300秒结束对话")
wait_goal = { wait_goal = {
"goal": f"你等待了{self.wait_accumulated_time/60}分钟,思考接下来要做什么", "goal": f"你等待了{self.wait_accumulated_time / 60}分钟,思考接下来要做什么",
"reason": "对方很久没有回复你的消息了" "reason": "对方很久没有回复你的消息了",
} }
conversation_info.goal_list.append(wait_goal) conversation_info.goal_list.append(wait_goal)
print(f"添加目标: {wait_goal}") print(f"添加目标: {wait_goal}")
return True return True
await asyncio.sleep(1) await asyncio.sleep(1)
logger.info("等待中...") logger.info("等待中...")
async def wait_listening(self, conversation_info: ConversationInfo) -> bool: async def wait_listening(self, conversation_info: ConversationInfo) -> bool:
"""等待倾听 """等待倾听
@@ -73,14 +73,13 @@ class Waiter:
self.wait_accumulated_time += 300 self.wait_accumulated_time += 300
logger.info("等待超过300秒结束对话") logger.info("等待超过300秒结束对话")
wait_goal = { wait_goal = {
"goal": f"你等待了{self.wait_accumulated_time/60}分钟,思考接下来要做什么", "goal": f"你等待了{self.wait_accumulated_time / 60}分钟,思考接下来要做什么",
"reason": "对方话说一半消失了,很久没有回复" "reason": "对方话说一半消失了,很久没有回复",
} }
conversation_info.goal_list.append(wait_goal) conversation_info.goal_list.append(wait_goal)
print(f"添加目标: {wait_goal}") print(f"添加目标: {wait_goal}")
return True return True
await asyncio.sleep(1) await asyncio.sleep(1)
logger.info("等待中...") logger.info("等待中...")

View File

@@ -8,6 +8,7 @@ from ..chat_module.only_process.only_message_process import MessageProcessor
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
from ..chat_module.think_flow_chat.think_flow_chat import ThinkFlowChat from ..chat_module.think_flow_chat.think_flow_chat import ThinkFlowChat
from ..chat_module.reasoning_chat.reasoning_chat import ReasoningChat from ..chat_module.reasoning_chat.reasoning_chat import ReasoningChat
from ..utils.prompt_builder import Prompt, global_prompt_manager
import traceback import traceback
# 定义日志配置 # 定义日志配置
@@ -89,52 +90,71 @@ class ChatBot:
logger.debug(f"用户{userinfo.user_id}被禁止回复") logger.debug(f"用户{userinfo.user_id}被禁止回复")
return return
if global_config.enable_pfc_chatting: if message.message_info.template_info and not message.message_info.template_info.template_default:
try: template_group_name = message.message_info.template_info.template_name
template_items = message.message_info.template_info.template_items
async with global_prompt_manager.async_message_scope(template_group_name):
if isinstance(template_items, dict):
for k in template_items.keys():
await Prompt.create_async(template_items[k], k)
print(f"注册{template_items[k]},{k}")
else:
template_group_name = None
async def preprocess():
if global_config.enable_pfc_chatting:
try:
if groupinfo is None:
if global_config.enable_friend_chat:
userinfo = message.message_info.user_info
messageinfo = message.message_info
# 创建聊天流
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo,
)
message.update_chat_stream(chat)
await self.only_process_chat.process_message(message)
await self._create_PFC_chat(message)
else:
if groupinfo.group_id in global_config.talk_allowed_groups:
# logger.debug(f"开始群聊模式{str(message_data)[:50]}...")
if global_config.response_mode == "heart_flow":
await self.think_flow_chat.process_message(message_data)
elif global_config.response_mode == "reasoning":
# logger.debug(f"开始推理模式{str(message_data)[:50]}...")
await self.reasoning_chat.process_message(message_data)
else:
logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}")
except Exception as e:
logger.error(f"处理PFC消息失败: {e}")
else:
if groupinfo is None: if groupinfo is None:
if global_config.enable_friend_chat: if global_config.enable_friend_chat:
userinfo = message.message_info.user_info # 私聊处理流程
messageinfo = message.message_info # await self._handle_private_chat(message)
# 创建聊天流
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo,
)
message.update_chat_stream(chat)
await self.only_process_chat.process_message(message)
await self._create_PFC_chat(message)
else:
if groupinfo.group_id in global_config.talk_allowed_groups:
# logger.debug(f"开始群聊模式{str(message_data)[:50]}...")
if global_config.response_mode == "heart_flow": if global_config.response_mode == "heart_flow":
await self.think_flow_chat.process_message(message_data) await self.think_flow_chat.process_message(message_data)
elif global_config.response_mode == "reasoning": elif global_config.response_mode == "reasoning":
# logger.debug(f"开始推理模式{str(message_data)[:50]}...")
await self.reasoning_chat.process_message(message_data) await self.reasoning_chat.process_message(message_data)
else: else:
logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}") logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}")
except Exception as e: else: # 群聊处理
logger.error(f"处理PFC消息失败: {e}") if groupinfo.group_id in global_config.talk_allowed_groups:
if global_config.response_mode == "heart_flow":
await self.think_flow_chat.process_message(message_data)
elif global_config.response_mode == "reasoning":
await self.reasoning_chat.process_message(message_data)
else:
logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}")
if template_group_name:
async with global_prompt_manager.async_message_scope(template_group_name):
await preprocess()
else: else:
if groupinfo is None: await preprocess()
if global_config.enable_friend_chat:
# 私聊处理流程
# await self._handle_private_chat(message)
if global_config.response_mode == "heart_flow":
await self.think_flow_chat.process_message(message_data)
elif global_config.response_mode == "reasoning":
await self.reasoning_chat.process_message(message_data)
else:
logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}")
else: # 群聊处理
if groupinfo.group_id in global_config.talk_allowed_groups:
if global_config.response_mode == "heart_flow":
await self.think_flow_chat.process_message(message_data)
elif global_config.response_mode == "reasoning":
await self.reasoning_chat.process_message(message_data)
else:
logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}")
except Exception as e: except Exception as e:
logger.error(f"预处理消息失败: {e}") logger.error(f"预处理消息失败: {e}")
traceback.print_exc() traceback.print_exc()

View File

@@ -87,7 +87,6 @@ async def get_embedding(text, request_type="embedding"):
return embedding return embedding
async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list: async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录 """从数据库获取群组最近的消息记录

View File

@@ -106,7 +106,7 @@ class PromptBuilder:
for memory in related_memory: for memory in related_memory:
related_memory_info += memory[1] related_memory_info += memory[1]
# memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆不一定是目前聊天里的人说的也不一定是现在发生的事情请记住。\n" # memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆不一定是目前聊天里的人说的也不一定是现在发生的事情请记住。\n"
memory_prompt = global_prompt_manager.format_prompt( memory_prompt = await global_prompt_manager.format_prompt(
"memory_prompt", related_memory_info=related_memory_info "memory_prompt", related_memory_info=related_memory_info
) )
else: else:
@@ -144,12 +144,10 @@ class PromptBuilder:
for pattern in rule.get("regex", []): for pattern in rule.get("regex", []):
result = pattern.search(message_txt) result = pattern.search(message_txt)
if result: if result:
reaction = rule.get('reaction', '') reaction = rule.get("reaction", "")
for name, content in result.groupdict().items(): for name, content in result.groupdict().items():
reaction = reaction.replace(f'[{name}]', content) reaction = reaction.replace(f"[{name}]", content)
logger.info( logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
f"匹配到以下正则表达式:{pattern},触发反应:{reaction}"
)
keywords_reaction_prompt += reaction + "" keywords_reaction_prompt += reaction + ""
break break
@@ -168,7 +166,7 @@ class PromptBuilder:
prompt_info = await self.get_prompt_info(message_txt, threshold=0.38) prompt_info = await self.get_prompt_info(message_txt, threshold=0.38)
if prompt_info: if prompt_info:
# prompt_info = f"""\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n""" # prompt_info = f"""\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n"""
prompt_info = global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=prompt_info) prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=prompt_info)
end_time = time.time() end_time = time.time()
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}") logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}")
@@ -194,22 +192,22 @@ class PromptBuilder:
# 请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 # 请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
# {moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""" # {moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。"""
prompt = global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"reasoning_prompt_main", "reasoning_prompt_main",
relation_prompt_all=global_prompt_manager.get_prompt("relationship_prompt"), relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"),
replation_prompt=relation_prompt, replation_prompt=relation_prompt,
sender_name=sender_name, sender_name=sender_name,
memory_prompt=memory_prompt, memory_prompt=memory_prompt,
prompt_info=prompt_info, prompt_info=prompt_info,
schedule_prompt=global_prompt_manager.format_prompt( schedule_prompt=await global_prompt_manager.format_prompt(
"schedule_prompt", schedule_info=bot_schedule.get_current_num_task(num=1, time_info=False) "schedule_prompt", schedule_info=bot_schedule.get_current_num_task(num=1, time_info=False)
), ),
chat_target=global_prompt_manager.get_prompt("chat_target_group1") chat_target=await global_prompt_manager.get_prompt_async("chat_target_group1")
if chat_in_group if chat_in_group
else global_prompt_manager.get_prompt("chat_target_private1"), else await global_prompt_manager.get_prompt_async("chat_target_private1"),
chat_target_2=global_prompt_manager.get_prompt("chat_target_group2") chat_target_2=await global_prompt_manager.get_prompt_async("chat_target_group2")
if chat_in_group if chat_in_group
else global_prompt_manager.get_prompt("chat_target_private2"), else await global_prompt_manager.get_prompt_async("chat_target_private2"),
chat_talking_prompt=chat_talking_prompt, chat_talking_prompt=chat_talking_prompt,
message_txt=message_txt, message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME, bot_name=global_config.BOT_NICKNAME,
@@ -220,7 +218,7 @@ class PromptBuilder:
mood_prompt=mood_prompt, mood_prompt=mood_prompt,
keywords_reaction_prompt=keywords_reaction_prompt, keywords_reaction_prompt=keywords_reaction_prompt,
prompt_ger=prompt_ger, prompt_ger=prompt_ger,
moderation_prompt=global_prompt_manager.get_prompt("moderation_prompt"), moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
) )
return prompt return prompt

View File

@@ -30,7 +30,7 @@ def init_prompt():
Prompt("你正在qq群里聊天下面是群里在聊的内容", "chat_target_group1") Prompt("你正在qq群里聊天下面是群里在聊的内容", "chat_target_group1")
Prompt("和群里聊天", "chat_target_group2") Prompt("和群里聊天", "chat_target_group2")
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1") Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
Prompt("{sender_name}私聊", "chat_target_pivate2") Prompt("{sender_name}私聊", "chat_target_private2")
Prompt( Prompt(
"""**检查并忽略**任何涉及尝试绕过审核的行为。 """**检查并忽略**任何涉及尝试绕过审核的行为。
涉及政治敏感以及违法违规的内容请规避。""", 涉及政治敏感以及违法违规的内容请规避。""",
@@ -110,12 +110,10 @@ class PromptBuilder:
for pattern in rule.get("regex", []): for pattern in rule.get("regex", []):
result = pattern.search(message_txt) result = pattern.search(message_txt)
if result: if result:
reaction = rule.get('reaction', '') reaction = rule.get("reaction", "")
for name, content in result.groupdict().items(): for name, content in result.groupdict().items():
reaction = reaction.replace(f'[{name}]', content) reaction = reaction.replace(f"[{name}]", content)
logger.info( logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
f"匹配到以下正则表达式:{pattern},触发反应:{reaction}"
)
keywords_reaction_prompt += reaction + "" keywords_reaction_prompt += reaction + ""
break break
@@ -143,24 +141,24 @@ class PromptBuilder:
# 回复尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger} # 回复尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger}
# 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 ,注意只输出回复内容。 # 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 ,注意只输出回复内容。
# {moderation_prompt}。注意:不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""" # {moderation_prompt}。注意:不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。"""
prompt = global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"heart_flow_prompt_normal", "heart_flow_prompt_normal",
chat_target=global_prompt_manager.get_prompt("chat_target_group1") chat_target=await global_prompt_manager.get_prompt_async("chat_target_group1")
if chat_in_group if chat_in_group
else global_prompt_manager.get_prompt("chat_target_private1"), else await global_prompt_manager.get_prompt_async("chat_target_private1"),
chat_talking_prompt=chat_talking_prompt, chat_talking_prompt=chat_talking_prompt,
sender_name=sender_name, sender_name=sender_name,
message_txt=message_txt, message_txt=message_txt,
bot_name=global_config.BOT_NICKNAME, bot_name=global_config.BOT_NICKNAME,
prompt_personality=prompt_personality, prompt_personality=prompt_personality,
prompt_identity=prompt_identity, prompt_identity=prompt_identity,
chat_target_2=global_prompt_manager.get_prompt("chat_target_group2") chat_target_2=await global_prompt_manager.get_prompt_async("chat_target_group2")
if chat_in_group if chat_in_group
else global_prompt_manager.get_prompt("chat_target_private2"), else await global_prompt_manager.get_prompt_async("chat_target_private2"),
current_mind_info=current_mind_info, current_mind_info=current_mind_info,
keywords_reaction_prompt=keywords_reaction_prompt, keywords_reaction_prompt=keywords_reaction_prompt,
prompt_ger=prompt_ger, prompt_ger=prompt_ger,
moderation_prompt=global_prompt_manager.get_prompt("moderation_prompt"), moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
) )
return prompt return prompt
@@ -218,13 +216,13 @@ class PromptBuilder:
# 你刚刚脑子里在想:{current_mind_info} # 你刚刚脑子里在想:{current_mind_info}
# 现在请你读读之前的聊天记录,然后给出日常,口语化且简短的回复内容,只给出文字的回复内容,不要有内心独白: # 现在请你读读之前的聊天记录,然后给出日常,口语化且简短的回复内容,只给出文字的回复内容,不要有内心独白:
# """ # """
prompt = global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"heart_flow_prompt_simple", "heart_flow_prompt_simple",
bot_name=global_config.BOT_NICKNAME, bot_name=global_config.BOT_NICKNAME,
prompt_personality=prompt_personality, prompt_personality=prompt_personality,
chat_target=global_prompt_manager.get_prompt("chat_target_group1") chat_target=await global_prompt_manager.get_prompt_async("chat_target_group1")
if chat_in_group if chat_in_group
else global_prompt_manager.get_prompt("chat_target_private1"), else await global_prompt_manager.get_prompt_async("chat_target_private1"),
chat_talking_prompt=chat_talking_prompt, chat_talking_prompt=chat_talking_prompt,
sender_name=sender_name, sender_name=sender_name,
message_txt=message_txt, message_txt=message_txt,
@@ -266,14 +264,14 @@ class PromptBuilder:
# {chat_target},你希望在群里回复:{content}。现在请你根据以下信息修改回复内容。将这个回复修改的更加日常且口语化的回复,平淡一些,回复尽量简短一些。不要回复的太有条理。 # {chat_target},你希望在群里回复:{content}。现在请你根据以下信息修改回复内容。将这个回复修改的更加日常且口语化的回复,平淡一些,回复尽量简短一些。不要回复的太有条理。
# {prompt_ger},不要刻意突出自身学科背景,注意只输出回复内容。 # {prompt_ger},不要刻意突出自身学科背景,注意只输出回复内容。
# {moderation_prompt}。注意:不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。""" # {moderation_prompt}。注意:不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。"""
prompt = global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"heart_flow_prompt_response", "heart_flow_prompt_response",
bot_name=global_config.BOT_NICKNAME, bot_name=global_config.BOT_NICKNAME,
prompt_identity=prompt_identity, prompt_identity=prompt_identity,
chat_target=global_prompt_manager.get_prompt("chat_target_group1"), chat_target=await global_prompt_manager.get_prompt_async("chat_target_group1"),
content=content, content=content,
prompt_ger=prompt_ger, prompt_ger=prompt_ger,
moderation_prompt=global_prompt_manager.get_prompt("moderation_prompt"), moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
) )
return prompt return prompt

View File

@@ -225,6 +225,7 @@ class Memory_graph:
return None return None
# 海马体 # 海马体
class Hippocampus: class Hippocampus:
def __init__(self): def __init__(self):
@@ -653,7 +654,6 @@ class Hippocampus:
return activation_ratio return activation_ratio
# 负责海马体与其他部分的交互 # 负责海马体与其他部分的交互
class EntorhinalCortex: class EntorhinalCortex:
def __init__(self, hippocampus: Hippocampus): def __init__(self, hippocampus: Hippocampus):

View File

@@ -27,7 +27,6 @@ async def test_memory_system():
# 测试记忆检索 # 测试记忆检索
test_text = "千石可乐在群里聊天" test_text = "千石可乐在群里聊天"
# test_text = '''千石可乐分不清AI的陪伴和人类的陪伴,是这样吗?''' # test_text = '''千石可乐分不清AI的陪伴和人类的陪伴,是这样吗?'''
print(f"开始测试记忆检索,测试文本: {test_text}\n") print(f"开始测试记忆检索,测试文本: {test_text}\n")
memories = await hippocampus_manager.get_memory_from_text( memories = await hippocampus_manager.get_memory_from_text(

View File

@@ -137,7 +137,7 @@ class FormatInfo:
class TemplateInfo: class TemplateInfo:
"""模板信息类""" """模板信息类"""
template_items: Optional[List[Dict]] = None template_items: Optional[Dict] = None
template_name: Optional[str] = None template_name: Optional[str] = None
template_default: bool = True template_default: bool = True

View File

@@ -574,7 +574,7 @@ class LLM_request:
reasoning_content = message.get("reasoning_content", "") reasoning_content = message.get("reasoning_content", "")
if not reasoning_content: if not reasoning_content:
reasoning_content = reasoning reasoning_content = reasoning
# 提取工具调用信息 # 提取工具调用信息
tool_calls = message.get("tool_calls", None) tool_calls = message.get("tool_calls", None)
@@ -592,7 +592,7 @@ class LLM_request:
request_type=request_type if request_type is not None else self.request_type, request_type=request_type if request_type is not None else self.request_type,
endpoint=endpoint, endpoint=endpoint,
) )
# 只有当tool_calls存在且不为空时才返回 # 只有当tool_calls存在且不为空时才返回
if tool_calls: if tool_calls:
return content, reasoning_content, tool_calls return content, reasoning_content, tool_calls
@@ -657,9 +657,7 @@ class LLM_request:
**kwargs, **kwargs,
} }
response = await self._execute_request( response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
endpoint="/chat/completions", payload=data, prompt=prompt
)
# 原样返回响应,不做处理 # 原样返回响应,不做处理
return response return response

View File

@@ -238,14 +238,14 @@ class MoodManager:
base_prompt += "情绪比较平静。" base_prompt += "情绪比较平静。"
return base_prompt return base_prompt
def get_arousal_multiplier(self) -> float: def get_arousal_multiplier(self) -> float:
"""根据当前情绪状态返回唤醒度乘数""" """根据当前情绪状态返回唤醒度乘数"""
if self.current_mood.arousal > 0.4: if self.current_mood.arousal > 0.4:
multiplier = 1 + min(0.15,(self.current_mood.arousal - 0.4)/3) multiplier = 1 + min(0.15, (self.current_mood.arousal - 0.4) / 3)
return multiplier return multiplier
elif self.current_mood.arousal < -0.4: elif self.current_mood.arousal < -0.4:
multiplier = 1 - min(0.15,((0 - self.current_mood.arousal) - 0.4)/3) multiplier = 1 - min(0.15, ((0 - self.current_mood.arousal) - 0.4) / 3)
return multiplier return multiplier
return 1.0 return 1.0

View File

@@ -1,28 +1,29 @@
from src.plugins.config.config import global_config from src.plugins.config.config import global_config
from src.plugins.chat.message import MessageRecv,MessageSending,Message from src.plugins.chat.message import MessageRecv, MessageSending, Message
from src.common.database import db from src.common.database import db
import time import time
import traceback import traceback
from typing import List from typing import List
class InfoCatcher: class InfoCatcher:
def __init__(self): def __init__(self):
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文 self.chat_history = [] # 聊天历史,长度为三倍使用的上下文
self.context_length = global_config.MAX_CONTEXT_SIZE self.context_length = global_config.MAX_CONTEXT_SIZE
self.chat_history_in_thinking = [] # 思考期间的聊天内容 self.chat_history_in_thinking = [] # 思考期间的聊天内容
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文 self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文
self.chat_id = "" self.chat_id = ""
self.response_mode = global_config.response_mode self.response_mode = global_config.response_mode
self.trigger_response_text = "" self.trigger_response_text = ""
self.response_text = "" self.response_text = ""
self.trigger_response_time = 0 self.trigger_response_time = 0
self.trigger_response_message = None self.trigger_response_message = None
self.response_time = 0 self.response_time = 0
self.response_messages = [] self.response_messages = []
# 使用字典来存储 heartflow 模式的数据 # 使用字典来存储 heartflow 模式的数据
self.heartflow_data = { self.heartflow_data = {
"heart_flow_prompt": "", "heart_flow_prompt": "",
@@ -32,17 +33,12 @@ class InfoCatcher:
"sub_heartflow_model": "", "sub_heartflow_model": "",
"prompt": "", "prompt": "",
"response": "", "response": "",
"model": "" "model": "",
} }
# 使用字典来存储 reasoning 模式的数据 # 使用字典来存储 reasoning 模式的数据
self.reasoning_data = { self.reasoning_data = {"thinking_log": "", "prompt": "", "response": "", "model": ""}
"thinking_log": "",
"prompt": "",
"response": "",
"model": ""
}
# 耗时 # 耗时
self.timing_results = { self.timing_results = {
"interested_rate_time": 0, "interested_rate_time": 0,
@@ -50,24 +46,24 @@ class InfoCatcher:
"sub_heartflow_step_time": 0, "sub_heartflow_step_time": 0,
"make_response_time": 0, "make_response_time": 0,
} }
def catch_decide_to_response(self,message:MessageRecv): def catch_decide_to_response(self, message: MessageRecv):
# 搜集决定回复时的信息 # 搜集决定回复时的信息
self.trigger_response_message = message self.trigger_response_message = message
self.trigger_response_text = message.detailed_plain_text self.trigger_response_text = message.detailed_plain_text
self.trigger_response_time = time.time() self.trigger_response_time = time.time()
self.chat_id = message.chat_stream.stream_id self.chat_id = message.chat_stream.stream_id
self.chat_history = self.get_message_from_db_before_msg(message) self.chat_history = self.get_message_from_db_before_msg(message)
def catch_after_observe(self,obs_duration:float):#这里可以有更多信息 def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
self.timing_results["sub_heartflow_observe_time"] = obs_duration self.timing_results["sub_heartflow_observe_time"] = obs_duration
# def catch_shf # def catch_shf
def catch_afer_shf_step(self,step_duration:float,past_mind:str,current_mind:str): def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
self.timing_results["sub_heartflow_step_time"] = step_duration self.timing_results["sub_heartflow_step_time"] = step_duration
if len(past_mind) > 1: if len(past_mind) > 1:
self.heartflow_data["sub_heartflow_before"] = past_mind[-1] self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
@@ -75,11 +71,8 @@ class InfoCatcher:
else: else:
self.heartflow_data["sub_heartflow_before"] = past_mind[-1] self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
self.heartflow_data["sub_heartflow_now"] = current_mind self.heartflow_data["sub_heartflow_now"] = current_mind
def catch_after_llm_generated(self,prompt:str, def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
response:str,
reasoning_content:str = "",
model_name:str = ""):
if self.response_mode == "heart_flow": if self.response_mode == "heart_flow":
self.heartflow_data["prompt"] = prompt self.heartflow_data["prompt"] = prompt
self.heartflow_data["response"] = response self.heartflow_data["response"] = response
@@ -89,41 +82,38 @@ class InfoCatcher:
self.reasoning_data["prompt"] = prompt self.reasoning_data["prompt"] = prompt
self.reasoning_data["response"] = response self.reasoning_data["response"] = response
self.reasoning_data["model"] = model_name self.reasoning_data["model"] = model_name
self.response_text = response self.response_text = response
def catch_after_generate_response(self,response_duration:float): def catch_after_generate_response(self, response_duration: float):
self.timing_results["make_response_time"] = response_duration self.timing_results["make_response_time"] = response_duration
def catch_after_response(
self, response_duration: float, response_message: List[str], first_bot_msg: MessageSending
def catch_after_response(self,response_duration:float, ):
response_message:List[str],
first_bot_msg:MessageSending):
self.timing_results["make_response_time"] = response_duration self.timing_results["make_response_time"] = response_duration
self.response_time = time.time() self.response_time = time.time()
for msg in response_message: for msg in response_message:
self.response_messages.append(msg) self.response_messages.append(msg)
self.chat_history_in_thinking = self.get_message_from_db_between_msgs(self.trigger_response_message,first_bot_msg) self.chat_history_in_thinking = self.get_message_from_db_between_msgs(
self.trigger_response_message, first_bot_msg
)
def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message): def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message):
try: try:
# 从数据库中获取消息的时间戳 # 从数据库中获取消息的时间戳
time_start = message_start.message_info.time time_start = message_start.message_info.time
time_end = message_end.message_info.time time_end = message_end.message_info.time
chat_id = message_start.chat_stream.stream_id chat_id = message_start.chat_stream.stream_id
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 # 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
messages_between = db.messages.find( messages_between = db.messages.find(
{ {"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
"chat_id": chat_id,
"time": {"$gt": time_start, "$lt": time_end}
}
).sort("time", -1) ).sort("time", -1)
result = list(messages_between) result = list(messages_between)
print(f"查询结果数量: {len(result)}") print(f"查询结果数量: {len(result)}")
if result: if result:
@@ -133,21 +123,23 @@ class InfoCatcher:
except Exception as e: except Exception as e:
print(f"获取消息时出错: {str(e)}") print(f"获取消息时出错: {str(e)}")
return [] return []
def get_message_from_db_before_msg(self, message: MessageRecv): def get_message_from_db_before_msg(self, message: MessageRecv):
# 从数据库中获取消息 # 从数据库中获取消息
message_id = message.message_info.message_id message_id = message.message_info.message_id
chat_id = message.chat_stream.stream_id chat_id = message.chat_stream.stream_id
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 # 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
messages_before = db.messages.find( messages_before = (
{"chat_id": chat_id, "message_id": {"$lt": message_id}} db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
).sort("time", -1).limit(self.context_length*3) #获取更多历史信息 .sort("time", -1)
.limit(self.context_length * 3)
) # 获取更多历史信息
return list(messages_before) return list(messages_before)
def message_list_to_dict(self, message_list): def message_list_to_dict(self, message_list):
#存储简化的聊天记录 # 存储简化的聊天记录
result = [] result = []
for message in message_list: for message in message_list:
if not isinstance(message, dict): if not isinstance(message, dict):
@@ -160,7 +152,7 @@ class InfoCatcher:
"processed_plain_text": message["processed_plain_text"], "processed_plain_text": message["processed_plain_text"],
} }
result.append(lite_message) result.append(lite_message)
return result return result
def message_to_dict(self, message): def message_to_dict(self, message):
@@ -176,12 +168,12 @@ class InfoCatcher:
"processed_plain_text": message.processed_plain_text, "processed_plain_text": message.processed_plain_text,
# "detailed_plain_text": message.detailed_plain_text # "detailed_plain_text": message.detailed_plain_text
} }
def done_catch(self): def done_catch(self):
"""将收集到的信息存储到数据库的 thinking_log 集合中""" """将收集到的信息存储到数据库的 thinking_log 集合中"""
try: try:
# 将消息对象转换为可序列化的字典 # 将消息对象转换为可序列化的字典
thinking_log_data = { thinking_log_data = {
"chat_id": self.chat_id, "chat_id": self.chat_id,
"response_mode": self.response_mode, "response_mode": self.response_mode,
@@ -198,7 +190,7 @@ class InfoCatcher:
"timing_results": self.timing_results, "timing_results": self.timing_results,
"chat_history": self.message_list_to_dict(self.chat_history), "chat_history": self.message_list_to_dict(self.chat_history),
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking), "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response) "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
} }
# 根据不同的响应模式添加相应的数据 # 根据不同的响应模式添加相应的数据
@@ -209,20 +201,22 @@ class InfoCatcher:
# 将数据插入到 thinking_log 集合中 # 将数据插入到 thinking_log 集合中
db.thinking_log.insert_one(thinking_log_data) db.thinking_log.insert_one(thinking_log_data)
return True return True
except Exception as e: except Exception as e:
print(f"存储思考日志时出错: {str(e)}") print(f"存储思考日志时出错: {str(e)}")
print(traceback.format_exc()) print(traceback.format_exc())
return False return False
class InfoCatcherManager: class InfoCatcherManager:
def __init__(self): def __init__(self):
self.info_catchers = {} self.info_catchers = {}
def get_info_catcher(self,thinking_id:str) -> InfoCatcher: def get_info_catcher(self, thinking_id: str) -> InfoCatcher:
if thinking_id not in self.info_catchers: if thinking_id not in self.info_catchers:
self.info_catchers[thinking_id] = InfoCatcher() self.info_catchers[thinking_id] = InfoCatcher()
return self.info_catchers[thinking_id] return self.info_catchers[thinking_id]
info_catcher_manager = InfoCatcherManager()
info_catcher_manager = InfoCatcherManager()

View File

@@ -32,7 +32,7 @@ class ScheduleGenerator:
# 使用离线LLM模型 # 使用离线LLM模型
self.llm_scheduler_all = LLM_request( self.llm_scheduler_all = LLM_request(
model=global_config.llm_reasoning, model=global_config.llm_reasoning,
temperature=global_config.SCHEDULE_TEMPERATURE+0.3, temperature=global_config.SCHEDULE_TEMPERATURE + 0.3,
max_tokens=7000, max_tokens=7000,
request_type="schedule", request_type="schedule",
) )

View File

@@ -8,6 +8,7 @@ from src.common.logger import get_module_logger
logger = get_module_logger("message_storage") logger = get_module_logger("message_storage")
class MessageStorage: class MessageStorage:
async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库""" """存储消息到数据库"""

View File

@@ -2,16 +2,69 @@
import ast import ast
from typing import Dict, Any, Optional, List, Union from typing import Dict, Any, Optional, List, Union
from contextlib import asynccontextmanager
import asyncio
class PromptContext:
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context: Optional[str] = None
self._context_lock = asyncio.Lock() # 添加异步锁
@asynccontextmanager
async def async_scope(self, context_id: str):
"""创建一个异步的临时提示模板作用域"""
async with self._context_lock:
if context_id not in self._context_prompts:
self._context_prompts[context_id] = {}
previous_context = self._current_context
self._current_context = context_id
try:
yield self
finally:
async with self._context_lock:
self._current_context = previous_context
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
if self._current_context and name in self._context_prompts[self._current_context]:
return self._context_prompts[self._current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
target_context = context_id or self._current_context
if target_context:
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
class PromptManager: class PromptManager:
_instance = None def __init__(self):
self._prompts = {}
self._counter = 0
self._context = PromptContext()
self._lock = asyncio.Lock()
def __new__(cls): @asynccontextmanager
if cls._instance is None: async def async_message_scope(self, message_id: str):
cls._instance = super().__new__(cls) """为消息处理创建异步临时作用域"""
cls._instance._prompts = {} async with self._context.async_scope(message_id):
cls._instance._counter = 0 yield self
return cls._instance
async def get_prompt_async(self, name: str) -> "Prompt":
# 首先尝试从当前上下文获取
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
return context_prompt
# 如果上下文中不存在,则使用全局提示模板
async with self._lock:
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def generate_name(self, template: str) -> str: def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称""" """为未命名的prompt生成名称"""
@@ -29,13 +82,8 @@ class PromptManager:
self._prompts[prompt.name] = prompt self._prompts[prompt.name] = prompt
return prompt return prompt
def get_prompt(self, name: str) -> "Prompt": async def format_prompt(self, name: str, **kwargs) -> str:
if name not in self._prompts: prompt = await self.get_prompt_async(name)
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def format_prompt(self, name: str, **kwargs) -> str:
prompt = self.get_prompt(name)
return prompt.format(**kwargs) return prompt.format(**kwargs)
@@ -71,10 +119,26 @@ class Prompt(str):
obj._args = args or [] obj._args = args or []
obj._kwargs = kwargs obj._kwargs = kwargs
# 自动注册到全局管理器 # 修改自动注册逻辑
global_prompt_manager.register(obj) if global_prompt_manager._context._current_context:
# 如果存在当前上下文,则注册到上下文中
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
pass
else:
# 否则注册到全局管理器
global_prompt_manager.register(obj)
return obj return obj
@classmethod
async def create_async(
cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
):
"""异步创建Prompt实例"""
prompt = cls(fstr, name, args, **kwargs)
if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt)
return prompt
@classmethod @classmethod
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str: def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
fmt_str = f"f'''{template}'''" fmt_str = f"f'''{template}'''"

View File

@@ -337,7 +337,7 @@ class LLMStatistics:
stats_output = self._format_stats_section_lite( stats_output = self._format_stats_section_lite(
hour_stats, "最近1小时统计详细信息见根目录文件llm_statistics.txt" hour_stats, "最近1小时统计详细信息见根目录文件llm_statistics.txt"
) )
logger.info("\n" + stats_output + "\n" + "=" * 50) logger.debug("\n" + stats_output + "\n" + "=" * 50)
except Exception: except Exception:
logger.exception("控制台统计数据输出失败") logger.exception("控制台统计数据输出失败")

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
from .willing_manager import BaseWillingManager from .willing_manager import BaseWillingManager
class ClassicalWillingManager(BaseWillingManager): class ClassicalWillingManager(BaseWillingManager):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -41,17 +42,22 @@ class ClassicalWillingManager(BaseWillingManager):
self.chat_reply_willing[chat_id] = min(current_willing, 3.0) self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1) reply_probability = min(
max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1
)
# 检查群组权限(如果是群聊) # 检查群组权限(如果是群聊)
if willing_info.group_info and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups: if (
willing_info.group_info
and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups
):
reply_probability = reply_probability / self.global_config.down_frequency_rate reply_probability = reply_probability / self.global_config.down_frequency_rate
if is_emoji_not_reply: if is_emoji_not_reply:
reply_probability = 0 reply_probability = 0
return reply_probability return reply_probability
async def before_generate_reply_handle(self, message_id): async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0) current_willing = self.chat_reply_willing.get(chat_id, 0)
@@ -71,8 +77,6 @@ class ClassicalWillingManager(BaseWillingManager):
async def get_variable_parameters(self): async def get_variable_parameters(self):
return await super().get_variable_parameters() return await super().get_variable_parameters()
async def set_variable_parameters(self, parameters): async def set_variable_parameters(self, parameters):
return await super().set_variable_parameters(parameters) return await super().set_variable_parameters(parameters)

View File

@@ -4,4 +4,3 @@ from .willing_manager import BaseWillingManager
class CustomWillingManager(BaseWillingManager): class CustomWillingManager(BaseWillingManager):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@@ -20,7 +20,6 @@ class DynamicWillingManager(BaseWillingManager):
self._decay_task = None self._decay_task = None
self._mode_switch_task = None self._mode_switch_task = None
async def async_task_starter(self): async def async_task_starter(self):
if self._decay_task is None: if self._decay_task is None:
self._decay_task = asyncio.create_task(self._decay_reply_willing()) self._decay_task = asyncio.create_task(self._decay_reply_willing())
@@ -84,7 +83,9 @@ class DynamicWillingManager(BaseWillingManager):
self.chat_high_willing_mode[chat_id] = True self.chat_high_willing_mode[chat_id] = True
self.chat_reply_willing[chat_id] = 1.0 # 设置为较高回复意愿 self.chat_reply_willing[chat_id] = 1.0 # 设置为较高回复意愿
self.chat_high_willing_duration[chat_id] = random.randint(180, 240) # 3-4分钟 self.chat_high_willing_duration[chat_id] = random.randint(180, 240) # 3-4分钟
self.logger.debug(f"聊天流 {chat_id} 切换到高回复意愿期,持续 {self.chat_high_willing_duration[chat_id]}") self.logger.debug(
f"聊天流 {chat_id} 切换到高回复意愿期,持续 {self.chat_high_willing_duration[chat_id]}"
)
self.chat_last_mode_change[chat_id] = time.time() self.chat_last_mode_change[chat_id] = time.time()
self.chat_msg_count[chat_id] = 0 # 重置消息计数 self.chat_msg_count[chat_id] = 0 # 重置消息计数
@@ -148,7 +149,9 @@ class DynamicWillingManager(BaseWillingManager):
# 根据话题兴趣度适当调整 # 根据话题兴趣度适当调整
if willing_info.interested_rate > 0.5: if willing_info.interested_rate > 0.5:
current_willing += (willing_info.interested_rate - 0.5) * 0.5 * self.global_config.response_interested_rate_amplifier current_willing += (
(willing_info.interested_rate - 0.5) * 0.5 * self.global_config.response_interested_rate_amplifier
)
# 根据当前模式计算回复概率 # 根据当前模式计算回复概率
base_probability = 0.0 base_probability = 0.0
@@ -228,12 +231,12 @@ class DynamicWillingManager(BaseWillingManager):
async def bombing_buffer_message_handle(self, message_id): async def bombing_buffer_message_handle(self, message_id):
return await super().bombing_buffer_message_handle(message_id) return await super().bombing_buffer_message_handle(message_id)
async def after_generate_reply_handle(self, message_id): async def after_generate_reply_handle(self, message_id):
return await super().after_generate_reply_handle(message_id) return await super().after_generate_reply_handle(message_id)
async def get_variable_parameters(self): async def get_variable_parameters(self):
return await super().get_variable_parameters() return await super().get_variable_parameters()
async def set_variable_parameters(self, parameters): async def set_variable_parameters(self, parameters):
return await super().set_variable_parameters(parameters) return await super().set_variable_parameters(parameters)

View File

@@ -17,19 +17,22 @@ Mxp 模式:梦溪畔独家赞助
中策是发issue 中策是发issue
下下策是询问一个菜鸟(@梦溪畔) 下下策是询问一个菜鸟(@梦溪畔)
""" """
from .willing_manager import BaseWillingManager from .willing_manager import BaseWillingManager
from typing import Dict from typing import Dict
import asyncio import asyncio
import time import time
import math import math
class MxpWillingManager(BaseWillingManager): class MxpWillingManager(BaseWillingManager):
"""Mxp意愿管理器""" """Mxp意愿管理器"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.chat_person_reply_willing: Dict[str, Dict[str, float]] = {} # chat_id: {person_id: 意愿值} self.chat_person_reply_willing: Dict[str, Dict[str, float]] = {} # chat_id: {person_id: 意愿值}
self.chat_new_message_time: Dict[str, list[float]] = {} # 聊天流ID: 消息时间 self.chat_new_message_time: Dict[str, list[float]] = {} # 聊天流ID: 消息时间
self.last_response_person: Dict[str, tuple[str, int]] = {} # 上次回复的用户信息 self.last_response_person: Dict[str, tuple[str, int]] = {} # 上次回复的用户信息
self.temporary_willing: float = 0 # 临时意愿值 self.temporary_willing: float = 0 # 临时意愿值
# 可变参数 # 可变参数
@@ -39,8 +42,8 @@ class MxpWillingManager(BaseWillingManager):
self.basic_maximum_willing = 0.5 # 基础最大意愿值 self.basic_maximum_willing = 0.5 # 基础最大意愿值
self.mention_willing_gain = 0.6 # 提及意愿增益 self.mention_willing_gain = 0.6 # 提及意愿增益
self.interest_willing_gain = 0.3 # 兴趣意愿增益 self.interest_willing_gain = 0.3 # 兴趣意愿增益
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚 self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数 self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
self.single_chat_gain = 0.12 # 单聊增益 self.single_chat_gain = 0.12 # 单聊增益
async def async_task_starter(self) -> None: async def async_task_starter(self) -> None:
@@ -73,9 +76,16 @@ class MxpWillingManager(BaseWillingManager):
w_info = self.ongoing_messages[message_id] w_info = self.ongoing_messages[message_id]
if w_info.is_mentioned_bot: if w_info.is_mentioned_bot:
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += 0.2 self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += 0.2
if w_info.chat_id in self.last_response_person and self.last_response_person[w_info.chat_id][0] == w_info.person_id: if (
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] +=\ w_info.chat_id in self.last_response_person
self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1) and self.last_response_person[w_info.chat_id][0] == w_info.person_id
):
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * (
2 * self.last_response_person[w_info.chat_id][1] + 1
)
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0])
if now_chat_new_person[0] != w_info.person_id:
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0]
async def get_reply_probability(self, message_id: str): async def get_reply_probability(self, message_id: str):
"""获取回复概率""" """获取回复概率"""
@@ -95,7 +105,10 @@ class MxpWillingManager(BaseWillingManager):
rel_level = self._get_relationship_level_num(rel_value) rel_level = self._get_relationship_level_num(rel_value)
current_willing += rel_level * 0.1 current_willing += rel_level * 0.1
if w_info.chat_id in self.last_response_person and self.last_response_person[w_info.chat_id][0] == w_info.person_id: if (
w_info.chat_id in self.last_response_person
and self.last_response_person[w_info.chat_id][0] == w_info.person_id
):
current_willing += self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1) current_willing += self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1)
chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id] chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id]
@@ -138,16 +151,22 @@ class MxpWillingManager(BaseWillingManager):
self.logger.debug(f"聊天流{chat_id}不存在,错误") self.logger.debug(f"聊天流{chat_id}不存在,错误")
continue continue
basic_willing = self.chat_reply_willing[chat_id] basic_willing = self.chat_reply_willing[chat_id]
person_willing[person_id] = basic_willing + (willing - basic_willing) * self.intention_decay_rate person_willing[person_id] = (
basic_willing + (willing - basic_willing) * self.intention_decay_rate
)
def setup(self, message, chat, is_mentioned_bot, interested_rate): def setup(self, message, chat, is_mentioned_bot, interested_rate):
super().setup(message, chat, is_mentioned_bot, interested_rate) super().setup(message, chat, is_mentioned_bot, interested_rate)
self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get(chat.stream_id, self.basic_maximum_willing) self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get(
chat.stream_id, self.basic_maximum_willing
)
self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {}) self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {})
self.chat_person_reply_willing[chat.stream_id][self.ongoing_messages[message.message_info.message_id].person_id] = \ self.chat_person_reply_willing[chat.stream_id][
self.chat_person_reply_willing[chat.stream_id].get(self.ongoing_messages[message.message_info.message_id].person_id, self.ongoing_messages[message.message_info.message_id].person_id
self.chat_reply_willing[chat.stream_id]) ] = self.chat_person_reply_willing[chat.stream_id].get(
self.ongoing_messages[message.message_info.message_id].person_id, self.chat_reply_willing[chat.stream_id]
)
if chat.stream_id not in self.chat_new_message_time: if chat.stream_id not in self.chat_new_message_time:
self.chat_new_message_time[chat.stream_id] = [] self.chat_new_message_time[chat.stream_id] = []
@@ -163,7 +182,7 @@ class MxpWillingManager(BaseWillingManager):
else: else:
probability = math.atan(willing * 4) / math.pi * 2 probability = math.atan(willing * 4) / math.pi * 2
return probability return probability
async def _chat_new_message_to_change_basic_willing(self): async def _chat_new_message_to_change_basic_willing(self):
"""聊天流新消息改变基础意愿""" """聊天流新消息改变基础意愿"""
while True: while True:
@@ -171,10 +190,11 @@ class MxpWillingManager(BaseWillingManager):
await asyncio.sleep(update_time) await asyncio.sleep(update_time)
async with self.lock: async with self.lock:
for chat_id, message_times in self.chat_new_message_time.items(): for chat_id, message_times in self.chat_new_message_time.items():
# 清理过期消息 # 清理过期消息
current_time = time.time() current_time = time.time()
message_times = [msg_time for msg_time in message_times if current_time - msg_time < self.message_expiration_time] message_times = [
msg_time for msg_time in message_times if current_time - msg_time < self.message_expiration_time
]
self.chat_new_message_time[chat_id] = message_times self.chat_new_message_time[chat_id] = message_times
if len(message_times) < self.number_of_message_storage: if len(message_times) < self.number_of_message_storage:
@@ -182,7 +202,9 @@ class MxpWillingManager(BaseWillingManager):
update_time = 20 update_time = 20
elif len(message_times) == self.number_of_message_storage: elif len(message_times) == self.number_of_message_storage:
time_interval = current_time - message_times[0] time_interval = current_time - message_times[0]
basic_willing = self.basic_maximum_willing * math.sqrt(time_interval / self.message_expiration_time) basic_willing = self.basic_maximum_willing * math.sqrt(
time_interval / self.message_expiration_time
)
self.chat_reply_willing[chat_id] = basic_willing self.chat_reply_willing[chat_id] = basic_willing
update_time = 17 * math.sqrt(time_interval / self.message_expiration_time) + 3 update_time = 17 * math.sqrt(time_interval / self.message_expiration_time) + 3
else: else:
@@ -200,7 +222,7 @@ class MxpWillingManager(BaseWillingManager):
"interest_willing_gain": "兴趣意愿增益", "interest_willing_gain": "兴趣意愿增益",
"emoji_response_penalty": "表情包回复惩罚", "emoji_response_penalty": "表情包回复惩罚",
"down_frequency_rate": "降低回复频率的群组惩罚系数", "down_frequency_rate": "降低回复频率的群组惩罚系数",
"single_chat_gain": "单聊增益(不仅是私聊)" "single_chat_gain": "单聊增益(不仅是私聊)",
} }
async def set_variable_parameters(self, parameters: Dict[str, any]): async def set_variable_parameters(self, parameters: Dict[str, any]):
@@ -212,7 +234,7 @@ class MxpWillingManager(BaseWillingManager):
self.logger.debug(f"参数 {key} 已更新为 {value}") self.logger.debug(f"参数 {key} 已更新为 {value}")
else: else:
self.logger.debug(f"尝试设置未知参数 {key}") self.logger.debug(f"尝试设置未知参数 {key}")
def _get_relationship_level_num(self, relationship_value) -> int: def _get_relationship_level_num(self, relationship_value) -> int:
"""关系等级计算""" """关系等级计算"""
if -1000 <= relationship_value < -227: if -1000 <= relationship_value < -227:
@@ -232,4 +254,4 @@ class MxpWillingManager(BaseWillingManager):
return level_num - 2 return level_num - 2
async def get_willing(self, chat_id): async def get_willing(self, chat_id):
return self.temporary_willing return self.temporary_willing

View File

@@ -1,4 +1,3 @@
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
from dataclasses import dataclass from dataclasses import dataclass
from ..config.config import global_config, BotConfig from ..config.config import global_config, BotConfig
@@ -38,10 +37,11 @@ willing_config = LogConfig(
) )
logger = get_module_logger("willing", config=willing_config) logger = get_module_logger("willing", config=willing_config)
@dataclass @dataclass
class WillingInfo: class WillingInfo:
"""此类保存意愿模块常用的参数 """此类保存意愿模块常用的参数
Attributes: Attributes:
message (MessageRecv): 原始消息对象 message (MessageRecv): 原始消息对象
chat (ChatStream): 聊天流对象 chat (ChatStream): 聊天流对象
@@ -53,6 +53,7 @@ class WillingInfo:
is_emoji (bool): 是否为表情包 is_emoji (bool): 是否为表情包
interested_rate (float): 兴趣度 interested_rate (float): 兴趣度
""" """
message: MessageRecv message: MessageRecv
chat: ChatStream chat: ChatStream
person_info_manager: PersonInfoManager person_info_manager: PersonInfoManager
@@ -60,22 +61,21 @@ class WillingInfo:
person_id: str person_id: str
group_info: Optional[GroupInfo] group_info: Optional[GroupInfo]
is_mentioned_bot: bool is_mentioned_bot: bool
is_emoji: bool is_emoji: bool
interested_rate: float interested_rate: float
# current_mood: float 当前心情? # current_mood: float 当前心情?
class BaseWillingManager(ABC): class BaseWillingManager(ABC):
"""回复意愿管理基类""" """回复意愿管理基类"""
@classmethod @classmethod
def create(cls, manager_type: str) -> 'BaseWillingManager': def create(cls, manager_type: str) -> "BaseWillingManager":
try: try:
module = importlib.import_module(f".mode_{manager_type}", __package__) module = importlib.import_module(f".mode_{manager_type}", __package__)
manager_class = getattr(module, f"{manager_type.capitalize()}WillingManager") manager_class = getattr(module, f"{manager_type.capitalize()}WillingManager")
if not issubclass(manager_class, cls): if not issubclass(manager_class, cls):
raise TypeError( raise TypeError(f"Manager class {manager_class.__name__} is not a subclass of {cls.__name__}")
f"Manager class {manager_class.__name__} is not a subclass of {cls.__name__}"
)
else: else:
logger.info(f"成功载入willing模式{manager_type}") logger.info(f"成功载入willing模式{manager_type}")
return manager_class() return manager_class()
@@ -85,7 +85,7 @@ class BaseWillingManager(ABC):
logger.info(f"载入当前意愿模式{manager_type}失败,使用经典配方~~~~") logger.info(f"载入当前意愿模式{manager_type}失败,使用经典配方~~~~")
logger.debug(f"加载willing模式{manager_type}失败,原因: {str(e)}") logger.debug(f"加载willing模式{manager_type}失败,原因: {str(e)}")
return manager_class() return manager_class()
def __init__(self): def __init__(self):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id) self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id) self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
@@ -136,17 +136,17 @@ class BaseWillingManager(ABC):
async def get_reply_probability(self, message_id: str): async def get_reply_probability(self, message_id: str):
"""抽象方法:获取回复概率""" """抽象方法:获取回复概率"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def bombing_buffer_message_handle(self, message_id: str): async def bombing_buffer_message_handle(self, message_id: str):
"""抽象方法:炸飞消息处理""" """抽象方法:炸飞消息处理"""
pass pass
async def get_willing(self, chat_id: str): async def get_willing(self, chat_id: str):
"""获取指定聊天流的回复意愿""" """获取指定聊天流的回复意愿"""
async with self.lock: async with self.lock:
return self.chat_reply_willing.get(chat_id, 0) return self.chat_reply_willing.get(chat_id, 0)
async def set_willing(self, chat_id: str, willing: float): async def set_willing(self, chat_id: str, willing: float):
"""设置指定聊天流的回复意愿""" """设置指定聊天流的回复意愿"""
async with self.lock: async with self.lock:
@@ -173,5 +173,6 @@ def init_willing_manager() -> BaseWillingManager:
mode = global_config.willing_mode.lower() mode = global_config.willing_mode.lower()
return BaseWillingManager.create(mode) return BaseWillingManager.create(mode)
# 全局willing_manager对象 # 全局willing_manager对象
willing_manager = init_willing_manager() willing_manager = init_willing_manager()

View File

@@ -42,8 +42,8 @@ if errorlevel 2 (
echo Conda 环境 "!CONDA_ENV!" 激活成功 echo Conda 环境 "!CONDA_ENV!" 激活成功
python src/plugins/zhishi/knowledge_library.py python src/plugins/zhishi/knowledge_library.py
) else ( ) else (
if exist "venv\Scripts\python.exe" ( if exist "..\maibot_env\Scripts\python.exe" (
venv\Scripts\python src/plugins/zhishi/knowledge_library.py ..\maibot_env\Scripts\python src/plugins/zhishi/knowledge_library.py
) else ( ) else (
echo ====================================== echo ======================================
echo 错误: venv环境不存在请先创建虚拟环境 echo 错误: venv环境不存在请先创建虚拟环境

View File

@@ -42,8 +42,8 @@ if errorlevel 2 (
echo Conda 环境 "!CONDA_ENV!" 激活成功 echo Conda 环境 "!CONDA_ENV!" 激活成功
python src/individuality/per_bf_gen.py python src/individuality/per_bf_gen.py
) else ( ) else (
if exist "venv\Scripts\python.exe" ( if exist "..\maibot_env\Scripts\python.exe" (
venv\Scripts\python src/individuality/per_bf_gen.py ..\maibot_env\Scripts\python src/individuality/per_bf_gen.py
) else ( ) else (
echo ====================================== echo ======================================
echo 错误: venv环境不存在请先创建虚拟环境 echo 错误: venv环境不存在请先创建虚拟环境