This commit is contained in:
SengokuCola
2025-04-13 21:14:05 +08:00
44 changed files with 554 additions and 566 deletions

View File

@@ -1,5 +1,9 @@
name: Ruff name: Ruff
on: [ push, pull_request ] on: [ push, pull_request ]
permissions:
contents: write
jobs: jobs:
ruff: ruff:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -8,4 +12,12 @@ jobs:
- uses: astral-sh/ruff-action@v3 - uses: astral-sh/ruff-action@v3
- run: ruff check --fix - run: ruff check --fix
- run: ruff format - run: ruff format
- name: Commit changes
if: success()
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git add -A
git diff --quiet && git diff --staged --quiet || git commit -m "🤖 自动格式化代码 [skip ci]"
git push

View File

@@ -283,17 +283,13 @@ WILLING_STYLE_CONFIG = {
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"), "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
}, },
"simple": { "simple": {
"console_format": ( "console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}"), # noqa: E501
"<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}"
), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"), "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
}, },
} }
CONFIRM_STYLE_CONFIG = { CONFIRM_STYLE_CONFIG = {
"console_format": ( "console_format": ("<RED>{message}</RED>"), # noqa: E501
"<RED>{message}</RED>"
), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"), "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"),
} }

View File

@@ -4,17 +4,17 @@ from src.do_tool.tool_can_use.base_tool import (
discover_tools, discover_tools,
get_all_tool_definitions, get_all_tool_definitions,
get_tool_instance, get_tool_instance,
TOOL_REGISTRY TOOL_REGISTRY,
) )
__all__ = [ __all__ = [
'BaseTool', "BaseTool",
'register_tool', "register_tool",
'discover_tools', "discover_tools",
'get_all_tool_definitions', "get_all_tool_definitions",
'get_tool_instance', "get_tool_instance",
'TOOL_REGISTRY' "TOOL_REGISTRY",
] ]
# 自动发现并注册工具 # 自动发现并注册工具
discover_tools() discover_tools()

View File

@@ -10,41 +10,39 @@ logger = get_module_logger("base_tool")
# 工具注册表 # 工具注册表
TOOL_REGISTRY = {} TOOL_REGISTRY = {}
class BaseTool: class BaseTool:
"""所有工具的基类""" """所有工具的基类"""
# 工具名称,子类必须重写 # 工具名称,子类必须重写
name = None name = None
# 工具描述,子类必须重写 # 工具描述,子类必须重写
description = None description = None
# 工具参数定义,子类必须重写 # 工具参数定义,子类必须重写
parameters = None parameters = None
@classmethod @classmethod
def get_tool_definition(cls) -> Dict[str, Any]: def get_tool_definition(cls) -> Dict[str, Any]:
"""获取工具定义用于LLM工具调用 """获取工具定义用于LLM工具调用
Returns: Returns:
Dict: 工具定义字典 Dict: 工具定义字典
""" """
if not cls.name or not cls.description or not cls.parameters: if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return { return {
"type": "function", "type": "function",
"function": { "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
"name": cls.name,
"description": cls.description,
"parameters": cls.parameters
}
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具函数 """执行工具函数
Args: Args:
function_args: 工具调用参数 function_args: 工具调用参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
@@ -53,17 +51,17 @@ class BaseTool:
def register_tool(tool_class: Type[BaseTool]): def register_tool(tool_class: Type[BaseTool]):
"""注册工具到全局注册表 """注册工具到全局注册表
Args: Args:
tool_class: 工具类 tool_class: 工具类
""" """
if not issubclass(tool_class, BaseTool): if not issubclass(tool_class, BaseTool):
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类") raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
tool_name = tool_class.name tool_name = tool_class.name
if not tool_name: if not tool_name:
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性") raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
TOOL_REGISTRY[tool_name] = tool_class TOOL_REGISTRY[tool_name] = tool_class
logger.info(f"已注册工具: {tool_name}") logger.info(f"已注册工具: {tool_name}")
@@ -73,27 +71,27 @@ def discover_tools():
# 获取当前目录路径 # 获取当前目录路径
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
package_name = os.path.basename(current_dir) package_name = os.path.basename(current_dir)
# 遍历包中的所有模块 # 遍历包中的所有模块
for _, module_name, _ in pkgutil.iter_modules([current_dir]): for _, module_name, _ in pkgutil.iter_modules([current_dir]):
# 跳过当前模块和__pycache__ # 跳过当前模块和__pycache__
if module_name == "base_tool" or module_name.startswith("__"): if module_name == "base_tool" or module_name.startswith("__"):
continue continue
# 导入模块 # 导入模块
module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}") module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}")
# 查找模块中的工具类 # 查找模块中的工具类
for _, obj in inspect.getmembers(module): for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
register_tool(obj) register_tool(obj)
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具") logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
def get_all_tool_definitions() -> List[Dict[str, Any]]: def get_all_tool_definitions() -> List[Dict[str, Any]]:
"""获取所有已注册工具的定义 """获取所有已注册工具的定义
Returns: Returns:
List[Dict]: 工具定义列表 List[Dict]: 工具定义列表
""" """
@@ -102,14 +100,14 @@ def get_all_tool_definitions() -> List[Dict[str, Any]]:
def get_tool_instance(tool_name: str) -> Optional[BaseTool]: def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取指定名称的工具实例 """获取指定名称的工具实例
Args: Args:
tool_name: 工具名称 tool_name: 工具名称
Returns: Returns:
Optional[BaseTool]: 工具实例如果找不到则返回None Optional[BaseTool]: 工具实例如果找不到则返回None
""" """
tool_class = TOOL_REGISTRY.get(tool_name) tool_class = TOOL_REGISTRY.get(tool_name)
if not tool_class: if not tool_class:
return None return None
return tool_class() return tool_class()

View File

@@ -4,29 +4,25 @@ from typing import Dict, Any
logger = get_module_logger("fibonacci_sequence_tool") logger = get_module_logger("fibonacci_sequence_tool")
class FibonacciSequenceTool(BaseTool): class FibonacciSequenceTool(BaseTool):
"""生成斐波那契数列的工具""" """生成斐波那契数列的工具"""
name = "fibonacci_sequence" name = "fibonacci_sequence"
description = "生成指定长度的斐波那契数列" description = "生成指定长度的斐波那契数列"
parameters = { parameters = {
"type": "object", "type": "object",
"properties": { "properties": {"n": {"type": "integer", "description": "斐波那契数列的长度", "minimum": 1}},
"n": { "required": ["n"],
"type": "integer",
"description": "斐波那契数列的长度",
"minimum": 1
}
},
"required": ["n"]
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具功能 """执行工具功能
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
@@ -34,23 +30,18 @@ class FibonacciSequenceTool(BaseTool):
n = function_args.get("n") n = function_args.get("n")
if n <= 0: if n <= 0:
raise ValueError("参数n必须大于0") raise ValueError("参数n必须大于0")
sequence = [] sequence = []
a, b = 0, 1 a, b = 0, 1
for _ in range(n): for _ in range(n):
sequence.append(a) sequence.append(a)
a, b = b, a + b a, b = b, a + b
return { return {"name": self.name, "content": sequence}
"name": self.name,
"content": sequence
}
except Exception as e: except Exception as e:
logger.error(f"fibonacci_sequence工具执行失败: {str(e)}") logger.error(f"fibonacci_sequence工具执行失败: {str(e)}")
return { return {"name": self.name, "content": f"执行失败: {str(e)}"}
"name": self.name,
"content": f"执行失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(FibonacciSequenceTool) register_tool(FibonacciSequenceTool)

View File

@@ -4,8 +4,10 @@ from typing import Dict, Any
logger = get_module_logger("generate_buddha_emoji_tool") logger = get_module_logger("generate_buddha_emoji_tool")
class GenerateBuddhaEmojiTool(BaseTool): class GenerateBuddhaEmojiTool(BaseTool):
"""生成佛祖颜文字的工具类""" """生成佛祖颜文字的工具类"""
name = "generate_buddha_emoji" name = "generate_buddha_emoji"
description = "生成一个佛祖的颜文字表情" description = "生成一个佛祖的颜文字表情"
parameters = { parameters = {
@@ -13,32 +15,27 @@ class GenerateBuddhaEmojiTool(BaseTool):
"properties": { "properties": {
# 无参数 # 无参数
}, },
"required": [] "required": [],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具功能,生成佛祖颜文字 """执行工具功能,生成佛祖颜文字
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
try: try:
buddha_emoji = "这是一个佛祖emoji༼ つ ◕_◕ ༽つ" buddha_emoji = "这是一个佛祖emoji༼ つ ◕_◕ ༽つ"
return { return {"name": self.name, "content": buddha_emoji}
"name": self.name,
"content": buddha_emoji
}
except Exception as e: except Exception as e:
logger.error(f"generate_buddha_emoji工具执行失败: {str(e)}") logger.error(f"generate_buddha_emoji工具执行失败: {str(e)}")
return { return {"name": self.name, "content": f"执行失败: {str(e)}"}
"name": self.name,
"content": f"执行失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(GenerateBuddhaEmojiTool) register_tool(GenerateBuddhaEmojiTool)

View File

@@ -4,23 +4,21 @@ from typing import Dict, Any
logger = get_module_logger("generate_cmd_tutorial_tool") logger = get_module_logger("generate_cmd_tutorial_tool")
class GenerateCmdTutorialTool(BaseTool): class GenerateCmdTutorialTool(BaseTool):
"""生成Windows CMD基本操作教程的工具""" """生成Windows CMD基本操作教程的工具"""
name = "generate_cmd_tutorial" name = "generate_cmd_tutorial"
description = "生成关于Windows命令提示符(CMD)的基本操作教程,包括常用命令和使用方法" description = "生成关于Windows命令提示符(CMD)的基本操作教程,包括常用命令和使用方法"
parameters = { parameters = {"type": "object", "properties": {}, "required": []}
"type": "object",
"properties": {},
"required": []
}
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行工具功能 """执行工具功能
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
@@ -57,17 +55,12 @@ class GenerateCmdTutorialTool(BaseTool):
注意:使用命令时要小心,特别是删除操作。 注意:使用命令时要小心,特别是删除操作。
""" """
return { return {"name": self.name, "content": tutorial_content}
"name": self.name,
"content": tutorial_content
}
except Exception as e: except Exception as e:
logger.error(f"generate_cmd_tutorial工具执行失败: {str(e)}") logger.error(f"generate_cmd_tutorial工具执行失败: {str(e)}")
return { return {"name": self.name, "content": f"执行失败: {str(e)}"}
"name": self.name,
"content": f"执行失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(GenerateCmdTutorialTool) register_tool(GenerateCmdTutorialTool)

View File

@@ -5,32 +5,28 @@ from typing import Dict, Any
logger = get_module_logger("get_current_task_tool") logger = get_module_logger("get_current_task_tool")
class GetCurrentTaskTool(BaseTool): class GetCurrentTaskTool(BaseTool):
"""获取当前正在做的事情/最近的任务工具""" """获取当前正在做的事情/最近的任务工具"""
name = "get_current_task" name = "get_current_task"
description = "获取当前正在做的事情/最近的任务" description = "获取当前正在做的事情/最近的任务"
parameters = { parameters = {
"type": "object", "type": "object",
"properties": { "properties": {
"num": { "num": {"type": "integer", "description": "要获取的任务数量"},
"type": "integer", "time_info": {"type": "boolean", "description": "是否包含时间信息"},
"description": "要获取的任务数量"
},
"time_info": {
"type": "boolean",
"description": "是否包含时间信息"
}
}, },
"required": [] "required": [],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行获取当前任务 """执行获取当前任务
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本,此工具不使用 message_txt: 原始消息文本,此工具不使用
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
@@ -38,26 +34,21 @@ class GetCurrentTaskTool(BaseTool):
# 获取参数,如果没有提供则使用默认值 # 获取参数,如果没有提供则使用默认值
num = function_args.get("num", 1) num = function_args.get("num", 1)
time_info = function_args.get("time_info", False) time_info = function_args.get("time_info", False)
# 调用日程系统获取当前任务 # 调用日程系统获取当前任务
current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info) current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info)
# 格式化返回结果 # 格式化返回结果
if current_task: if current_task:
task_info = current_task task_info = current_task
else: else:
task_info = "当前没有正在进行的任务" task_info = "当前没有正在进行的任务"
return { return {"name": "get_current_task", "content": f"当前任务信息: {task_info}"}
"name": "get_current_task",
"content": f"当前任务信息: {task_info}"
}
except Exception as e: except Exception as e:
logger.error(f"获取当前任务工具执行失败: {str(e)}") logger.error(f"获取当前任务工具执行失败: {str(e)}")
return { return {"name": "get_current_task", "content": f"获取当前任务失败: {str(e)}"}
"name": "get_current_task",
"content": f"获取当前任务失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(GetCurrentTaskTool) register_tool(GetCurrentTaskTool)

View File

@@ -6,39 +6,35 @@ from typing import Dict, Any, Union
logger = get_module_logger("get_knowledge_tool") logger = get_module_logger("get_knowledge_tool")
class SearchKnowledgeTool(BaseTool): class SearchKnowledgeTool(BaseTool):
"""从知识库中搜索相关信息的工具""" """从知识库中搜索相关信息的工具"""
name = "search_knowledge" name = "search_knowledge"
description = "从知识库中搜索相关信息" description = "从知识库中搜索相关信息"
parameters = { parameters = {
"type": "object", "type": "object",
"properties": { "properties": {
"query": { "query": {"type": "string", "description": "搜索查询关键词"},
"type": "string", "threshold": {"type": "number", "description": "相似度阈值0.0到1.0之间"},
"description": "搜索查询关键词"
},
"threshold": {
"type": "number",
"description": "相似度阈值0.0到1.0之间"
}
}, },
"required": ["query"] "required": ["query"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行知识库搜索 """执行知识库搜索
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
try: try:
query = function_args.get("query", message_txt) query = function_args.get("query", message_txt)
threshold = function_args.get("threshold", 0.4) threshold = function_args.get("threshold", 0.4)
# 调用知识库搜索 # 调用知识库搜索
embedding = await get_embedding(query, request_type="info_retrieval") embedding = await get_embedding(query, request_type="info_retrieval")
if embedding: if embedding:
@@ -47,38 +43,29 @@ class SearchKnowledgeTool(BaseTool):
content = f"你知道这些知识: {knowledge_info}" content = f"你知道这些知识: {knowledge_info}"
else: else:
content = f"你不太了解有关{query}的知识" content = f"你不太了解有关{query}的知识"
return { return {"name": "search_knowledge", "content": content}
"name": "search_knowledge", return {"name": "search_knowledge", "content": f"无法获取关于'{query}'的嵌入向量"}
"content": content
}
return {
"name": "search_knowledge",
"content": f"无法获取关于'{query}'的嵌入向量"
}
except Exception as e: except Exception as e:
logger.error(f"知识库搜索工具执行失败: {str(e)}") logger.error(f"知识库搜索工具执行失败: {str(e)}")
return { return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"}
"name": "search_knowledge",
"content": f"知识库搜索失败: {str(e)}"
}
def get_info_from_db( def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]: ) -> Union[str, list]:
"""从数据库中获取相关信息 """从数据库中获取相关信息
Args: Args:
query_embedding: 查询的嵌入向量 query_embedding: 查询的嵌入向量
limit: 最大返回结果数 limit: 最大返回结果数
threshold: 相似度阈值 threshold: 相似度阈值
return_raw: 是否返回原始结果 return_raw: 是否返回原始结果
Returns: Returns:
Union[str, list]: 格式化的信息字符串或原始结果列表 Union[str, list]: 格式化的信息字符串或原始结果列表
""" """
if not query_embedding: if not query_embedding:
return "" if not return_raw else [] return "" if not return_raw else []
# 使用余弦相似度计算 # 使用余弦相似度计算
pipeline = [ pipeline = [
{ {
@@ -143,5 +130,6 @@ class SearchKnowledgeTool(BaseTool):
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results) return "\n".join(str(result["content"]) for result in results)
# 注册工具 # 注册工具
register_tool(SearchKnowledgeTool) register_tool(SearchKnowledgeTool)

View File

@@ -5,68 +5,55 @@ from typing import Dict, Any
logger = get_module_logger("get_memory_tool") logger = get_module_logger("get_memory_tool")
class GetMemoryTool(BaseTool): class GetMemoryTool(BaseTool):
"""从记忆系统中获取相关记忆的工具""" """从记忆系统中获取相关记忆的工具"""
name = "get_memory" name = "get_memory"
description = "从记忆系统中获取相关记忆" description = "从记忆系统中获取相关记忆"
parameters = { parameters = {
"type": "object", "type": "object",
"properties": { "properties": {
"text": { "text": {"type": "string", "description": "要查询的相关文本"},
"type": "string", "max_memory_num": {"type": "integer", "description": "最大返回记忆数量"},
"description": "要查询的相关文本"
},
"max_memory_num": {
"type": "integer",
"description": "最大返回记忆数量"
}
}, },
"required": ["text"] "required": ["text"],
} }
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
"""执行记忆获取 """执行记忆获取
Args: Args:
function_args: 工具参数 function_args: 工具参数
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
Dict: 工具执行结果 Dict: 工具执行结果
""" """
try: try:
text = function_args.get("text", message_txt) text = function_args.get("text", message_txt)
max_memory_num = function_args.get("max_memory_num", 2) max_memory_num = function_args.get("max_memory_num", 2)
# 调用记忆系统 # 调用记忆系统
related_memory = await HippocampusManager.get_instance().get_memory_from_text( related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=text, text=text, max_memory_num=max_memory_num, max_memory_length=2, max_depth=3, fast_retrieval=False
max_memory_num=max_memory_num,
max_memory_length=2,
max_depth=3,
fast_retrieval=False
) )
memory_info = "" memory_info = ""
if related_memory: if related_memory:
for memory in related_memory: for memory in related_memory:
memory_info += memory[1] + "\n" memory_info += memory[1] + "\n"
if memory_info: if memory_info:
content = f"你记得这些事情: {memory_info}" content = f"你记得这些事情: {memory_info}"
else: else:
content = f"你不太记得有关{text}的记忆,你对此不太了解" content = f"你不太记得有关{text}的记忆,你对此不太了解"
return { return {"name": "get_memory", "content": content}
"name": "get_memory",
"content": content
}
except Exception as e: except Exception as e:
logger.error(f"记忆获取工具执行失败: {str(e)}") logger.error(f"记忆获取工具执行失败: {str(e)}")
return { return {"name": "get_memory", "content": f"记忆获取失败: {str(e)}"}
"name": "get_memory",
"content": f"记忆获取失败: {str(e)}"
}
# 注册工具 # 注册工具
register_tool(GetMemoryTool) register_tool(GetMemoryTool)

View File

@@ -16,21 +16,19 @@ class ToolUser:
model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use" model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use"
) )
async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream): async def _build_tool_prompt(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
"""构建工具使用的提示词 """构建工具使用的提示词
Args: Args:
message_txt: 用户消息文本 message_txt: 用户消息文本
sender_name: 发送者名称 sender_name: 发送者名称
chat_stream: 聊天流对象 chat_stream: 聊天流对象
Returns: Returns:
str: 构建好的提示词 str: 构建好的提示词
""" """
new_messages = list( new_messages = list(
db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}) db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}).sort("time", 1).limit(15)
.sort("time", 1)
.limit(15)
) )
new_messages_str = "" new_messages_str = ""
for msg in new_messages: for msg in new_messages:
@@ -43,38 +41,38 @@ class ToolUser:
prompt += "你正在思考如何回复群里的消息。\n" prompt += "你正在思考如何回复群里的消息。\n"
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += f"注意你就是{bot_name}{bot_name}指的就是你。" prompt += f"注意你就是{bot_name}{bot_name}指的就是你。"
prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,不要使用危险功能(比如文件操作或者系统操作爬虫),比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。" prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,不要使用危险功能(比如文件操作或者系统操作爬虫),比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。"
return prompt return prompt
def _define_tools(self): def _define_tools(self):
"""获取所有已注册工具的定义 """获取所有已注册工具的定义
Returns: Returns:
list: 工具定义列表 list: 工具定义列表
""" """
return get_all_tool_definitions() return get_all_tool_definitions()
async def _execute_tool_call(self, tool_call, message_txt:str): async def _execute_tool_call(self, tool_call, message_txt: str):
"""执行特定的工具调用 """执行特定的工具调用
Args: Args:
tool_call: 工具调用对象 tool_call: 工具调用对象
message_txt: 原始消息文本 message_txt: 原始消息文本
Returns: Returns:
dict: 工具调用结果 dict: 工具调用结果
""" """
try: try:
function_name = tool_call["function"]["name"] function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"]) function_args = json.loads(tool_call["function"]["arguments"])
# 获取对应工具实例 # 获取对应工具实例
tool_instance = get_tool_instance(function_name) tool_instance = get_tool_instance(function_name)
if not tool_instance: if not tool_instance:
logger.warning(f"未知工具名称: {function_name}") logger.warning(f"未知工具名称: {function_name}")
return None return None
# 执行工具 # 执行工具
result = await tool_instance.execute(function_args, message_txt) result = await tool_instance.execute(function_args, message_txt)
if result: if result:
@@ -82,31 +80,30 @@ class ToolUser:
"tool_call_id": tool_call["id"], "tool_call_id": tool_call["id"],
"role": "tool", "role": "tool",
"name": function_name, "name": function_name,
"content": result["content"] "content": result["content"],
} }
return None return None
except Exception as e: except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}") logger.error(f"执行工具调用时发生错误: {str(e)}")
return None return None
async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream): async def use_tool(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
"""使用工具辅助思考,判断是否需要额外信息 """使用工具辅助思考,判断是否需要额外信息
Args: Args:
message_txt: 用户消息文本 message_txt: 用户消息文本
sender_name: 发送者名称 sender_name: 发送者名称
chat_stream: 聊天流对象 chat_stream: 聊天流对象
Returns: Returns:
dict: 工具使用结果 dict: 工具使用结果
""" """
try: try:
# 构建提示词 # 构建提示词
prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream) prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream)
# 定义可用工具 # 定义可用工具
tools = self._define_tools() tools = self._define_tools()
print(tools)
# 使用llm_model_tool发送带工具定义的请求 # 使用llm_model_tool发送带工具定义的请求
payload = { payload = {
@@ -114,31 +111,29 @@ class ToolUser:
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"max_tokens": global_config.max_response_length, "max_tokens": global_config.max_response_length,
"tools": tools, "tools": tools,
"temperature": 0.2 "temperature": 0.2,
} }
logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}") logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}")
# 发送请求获取模型是否需要调用工具 # 发送请求获取模型是否需要调用工具
response = await self.llm_model_tool._execute_request( response = await self.llm_model_tool._execute_request(
endpoint="/chat/completions", endpoint="/chat/completions", payload=payload, prompt=prompt
payload=payload,
prompt=prompt
) )
# 根据返回值数量判断是否有工具调用 # 根据返回值数量判断是否有工具调用
if len(response) == 3: if len(response) == 3:
content, reasoning_content, tool_calls = response content, reasoning_content, tool_calls = response
logger.info(f"工具思考: {tool_calls}") logger.info(f"工具思考: {tool_calls}")
# 检查响应中工具调用是否有效 # 检查响应中工具调用是否有效
if not tool_calls: if not tool_calls:
logger.info("模型返回了空的tool_calls列表") logger.info("模型返回了空的tool_calls列表")
return {"used_tools": False} return {"used_tools": False}
logger.info(f"模型请求调用{len(tool_calls)}个工具") logger.info(f"模型请求调用{len(tool_calls)}个工具")
tool_results = [] tool_results = []
collected_info = "" collected_info = ""
# 执行所有工具调用 # 执行所有工具调用
for tool_call in tool_calls: for tool_call in tool_calls:
result = await self._execute_tool_call(tool_call, message_txt) result = await self._execute_tool_call(tool_call, message_txt)
@@ -146,7 +141,7 @@ class ToolUser:
tool_results.append(result) tool_results.append(result)
# 将工具结果添加到收集的信息中 # 将工具结果添加到收集的信息中
collected_info += f"\n{result['name']}返回结果: {result['content']}\n" collected_info += f"\n{result['name']}返回结果: {result['content']}\n"
# 如果有工具结果,直接返回收集的信息 # 如果有工具结果,直接返回收集的信息
if collected_info: if collected_info:
logger.info(f"工具调用收集到信息: {collected_info}") logger.info(f"工具调用收集到信息: {collected_info}")
@@ -158,15 +153,15 @@ class ToolUser:
# 没有工具调用 # 没有工具调用
content, reasoning_content = response content, reasoning_content = response
logger.info("模型没有请求调用任何工具") logger.info("模型没有请求调用任何工具")
# 如果没有工具调用或处理失败,直接返回原始思考 # 如果没有工具调用或处理失败,直接返回原始思考
return { return {
"used_tools": False, "used_tools": False,
} }
except Exception as e: except Exception as e:
logger.error(f"工具调用过程中出错: {str(e)}") logger.error(f"工具调用过程中出错: {str(e)}")
return { return {
"used_tools": False, "used_tools": False,
"error": str(e), "error": str(e),
} }

View File

@@ -43,12 +43,11 @@ def init_prompt():
class CurrentState: class CurrentState:
def __init__(self): def __init__(self):
self.current_state_info = "" self.current_state_info = ""
self.mood_manager = MoodManager() self.mood_manager = MoodManager()
self.mood = self.mood_manager.get_prompt() self.mood = self.mood_manager.get_prompt()
self.attendance_factor = 0 self.attendance_factor = 0
self.engagement_factor = 0 self.engagement_factor = 0
@@ -66,9 +65,6 @@ class Heartflow:
) )
self._subheartflows: Dict[Any, SubHeartflow] = {} self._subheartflows: Dict[Any, SubHeartflow] = {}
async def _cleanup_inactive_subheartflows(self): async def _cleanup_inactive_subheartflows(self):
"""定期清理不活跃的子心流""" """定期清理不活跃的子心流"""
@@ -90,7 +86,7 @@ class Heartflow:
logger.info(f"已清理不活跃的子心流: {subheartflow_id}") logger.info(f"已清理不活跃的子心流: {subheartflow_id}")
await asyncio.sleep(30) # 每分钟检查一次 await asyncio.sleep(30) # 每分钟检查一次
async def _sub_heartflow_update(self): async def _sub_heartflow_update(self):
while True: while True:
# 检查是否存在子心流 # 检查是否存在子心流
@@ -103,13 +99,12 @@ class Heartflow:
await asyncio.sleep(global_config.heart_flow_update_interval) # 5分钟思考一次 await asyncio.sleep(global_config.heart_flow_update_interval) # 5分钟思考一次
async def heartflow_start_working(self): async def heartflow_start_working(self):
# 启动清理任务 # 启动清理任务
asyncio.create_task(self._cleanup_inactive_subheartflows()) asyncio.create_task(self._cleanup_inactive_subheartflows())
# 启动子心流更新任务 # 启动子心流更新任务
asyncio.create_task(self._sub_heartflow_update()) asyncio.create_task(self._sub_heartflow_update())
async def _update_current_state(self): async def _update_current_state(self):
print("TODO") print("TODO")

View File

@@ -150,7 +150,7 @@ class ChattingObservation(Observation):
except Exception as e: except Exception as e:
print(f"获取总结失败: {e}") print(f"获取总结失败: {e}")
updated_observe_info = "" updated_observe_info = ""
return updated_observe_info return updated_observe_info
# print(f"prompt{prompt}") # print(f"prompt{prompt}")
# print(f"self.observe_info{self.observe_info}") # print(f"self.observe_info{self.observe_info}")

View File

@@ -5,9 +5,11 @@ from src.plugins.models.utils_model import LLM_request
from src.plugins.config.config import global_config from src.plugins.config.config import global_config
import re import re
import time import time
# from src.plugins.schedule.schedule_generator import bot_schedule # from src.plugins.schedule.schedule_generator import bot_schedule
# from src.plugins.memory_system.Hippocampus import HippocampusManager # from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402 from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402
# from src.plugins.chat.utils import get_embedding # from src.plugins.chat.utils import get_embedding
# from src.common.database import db # from src.common.database import db
# from typing import Union # from typing import Union
@@ -17,7 +19,7 @@ from src.plugins.chat.chat_stream import ChatStream
from src.plugins.person_info.relationship_manager import relationship_manager from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.chat.utils import get_recent_group_speaker from src.plugins.chat.utils import get_recent_group_speaker
from src.do_tool.tool_use import ToolUser from src.do_tool.tool_use import ToolUser
from ..plugins.utils.prompt_builder import Prompt,global_prompt_manager from ..plugins.utils.prompt_builder import Prompt, global_prompt_manager
subheartflow_config = LogConfig( subheartflow_config = LogConfig(
# 使用海马体专用样式 # 使用海马体专用样式
@@ -26,6 +28,7 @@ subheartflow_config = LogConfig(
) )
logger = get_module_logger("subheartflow", config=subheartflow_config) logger = get_module_logger("subheartflow", config=subheartflow_config)
def init_prompt(): def init_prompt():
prompt = "" prompt = ""
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
@@ -41,7 +44,7 @@ def init_prompt():
prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n" prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n"
prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写" prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写"
prompt += "记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{bot_name}{bot_name}指的就是你。" prompt += "记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{bot_name}{bot_name}指的就是你。"
Prompt(prompt,"sub_heartflow_prompt_before") Prompt(prompt, "sub_heartflow_prompt_before")
prompt = "" prompt = ""
# prompt += f"你现在正在做的事情是:{schedule_info}\n" # prompt += f"你现在正在做的事情是:{schedule_info}\n"
prompt += "{prompt_personality}\n" prompt += "{prompt_personality}\n"
@@ -52,8 +55,7 @@ def init_prompt():
prompt += "你现在{mood_info}" prompt += "你现在{mood_info}"
prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白" prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:" prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
Prompt(prompt,'sub_heartflow_prompt_after') Prompt(prompt, "sub_heartflow_prompt_after")
class CurrentState: class CurrentState:
@@ -78,7 +80,6 @@ class SubHeartflow:
self.llm_model = LLM_request( self.llm_model = LLM_request(
model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow" model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow"
) )
self.main_heartflow_info = "" self.main_heartflow_info = ""
@@ -93,9 +94,9 @@ class SubHeartflow:
self.observations: list[Observation] = [] self.observations: list[Observation] = []
self.running_knowledges = [] self.running_knowledges = []
self.bot_name = global_config.BOT_NICKNAME self.bot_name = global_config.BOT_NICKNAME
self.tool_user = ToolUser() self.tool_user = ToolUser()
def add_observation(self, observation: Observation): def add_observation(self, observation: Observation):
@@ -145,12 +146,12 @@ class SubHeartflow:
): # 5分钟无回复/不在场,销毁 ): # 5分钟无回复/不在场,销毁
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...") logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...")
break # 退出循环,销毁自己 break # 退出循环,销毁自己
async def do_observe(self): async def do_observe(self):
observation = self.observations[0] observation = self.observations[0]
await observation.observe() await observation.observe()
async def do_thinking_before_reply(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
async def do_thinking_before_reply(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
current_thinking_info = self.current_mind current_thinking_info = self.current_mind
mood_info = self.current_state.mood mood_info = self.current_state.mood
# mood_info = "你很生气,很愤怒" # mood_info = "你很生气,很愤怒"
@@ -160,12 +161,12 @@ class SubHeartflow:
# 首先尝试使用工具获取更多信息 # 首先尝试使用工具获取更多信息
tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream) tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream)
# 如果工具被使用且获得了结果,将收集到的信息合并到思考中 # 如果工具被使用且获得了结果,将收集到的信息合并到思考中
collected_info = "" collected_info = ""
if tool_result.get("used_tools", False): if tool_result.get("used_tools", False):
logger.info("使用工具收集了信息") logger.info("使用工具收集了信息")
# 如果有收集到的信息,将其添加到当前思考中 # 如果有收集到的信息,将其添加到当前思考中
if "collected_info" in tool_result: if "collected_info" in tool_result:
collected_info = tool_result["collected_info"] collected_info = tool_result["collected_info"]
@@ -185,7 +186,7 @@ class SubHeartflow:
identity_detail = individuality.identity.identity_detail identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail) random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}" prompt_personality += f",{identity_detail[0]}"
# 关系 # 关系
who_chat_in_group = [ who_chat_in_group = [
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname) (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
@@ -204,9 +205,9 @@ class SubHeartflow:
# f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录," # f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录,"
# f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。" # f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。"
# ) # )
relation_prompt_all = (await global_prompt_manager.get_prompt_async('relationship_prompt')).format( relation_prompt_all = (await global_prompt_manager.get_prompt_async("relationship_prompt")).format(
relation_prompt,sender_name relation_prompt, sender_name
) )
# prompt = "" # prompt = ""
# # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n" # # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
@@ -224,9 +225,16 @@ class SubHeartflow:
# prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写" # prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写"
# prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name}{self.bot_name}指的就是你。" # prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name}{self.bot_name}指的就是你。"
prompt= (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_before")).format( prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_before")).format(
collected_info,relation_prompt_all,prompt_personality,current_thinking_info,chat_observe_info,mood_info,sender_name, collected_info,
message_txt,self.bot_name relation_prompt_all,
prompt_personality,
current_thinking_info,
chat_observe_info,
mood_info,
sender_name,
message_txt,
self.bot_name,
) )
try: try:
@@ -281,10 +289,10 @@ class SubHeartflow:
# prompt += f"你现在{mood_info}" # prompt += f"你现在{mood_info}"
# prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白" # prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
# prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:" # prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
prompt=(await global_prompt_manager.get_prompt_async('sub_heartflow_prompt_after')).format( prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_after")).format(
prompt_personality,chat_observe_info,current_thinking_info,message_new_info,reply_info,mood_info prompt_personality, chat_observe_info, current_thinking_info, message_new_info, reply_info, mood_info
) )
try: try:
response, reasoning_content = await self.llm_model.generate_response_async(prompt) response, reasoning_content = await self.llm_model.generate_response_async(prompt)
except Exception as e: except Exception as e:
@@ -343,5 +351,6 @@ class SubHeartflow:
self.past_mind.append(self.current_mind) self.past_mind.append(self.current_mind)
self.current_mind = response self.current_mind = response
init_prompt() init_prompt()
# subheartflow = SubHeartflow() # subheartflow = SubHeartflow()

View File

@@ -53,13 +53,13 @@ class ActionPlanner:
goal = goal_reason[0] goal = goal_reason[0]
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因" reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
elif isinstance(goal_reason, dict): elif isinstance(goal_reason, dict):
goal = goal_reason.get('goal') goal = goal_reason.get("goal")
reasoning = goal_reason.get('reasoning', "没有明确原因") reasoning = goal_reason.get("reasoning", "没有明确原因")
else: else:
# 如果是其他类型,尝试转为字符串 # 如果是其他类型,尝试转为字符串
goal = str(goal_reason) goal = str(goal_reason)
reasoning = "没有明确原因" reasoning = "没有明确原因"
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
goals_str += goal_str goals_str += goal_str
else: else:
@@ -68,7 +68,11 @@ class ActionPlanner:
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
# 获取聊天历史记录 # 获取聊天历史记录
chat_history_list = observation_info.chat_history[-20:] if len(observation_info.chat_history) >= 20 else observation_info.chat_history chat_history_list = (
observation_info.chat_history[-20:]
if len(observation_info.chat_history) >= 20
else observation_info.chat_history
)
chat_history_text = "" chat_history_text = ""
for msg in chat_history_list: for msg in chat_history_list:
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n" chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
@@ -85,15 +89,21 @@ class ActionPlanner:
personality_text = f"你的名字是{self.name}{self.personality_info}" personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本 # 构建action历史文本
action_history_list = conversation_info.done_action[-10:] if len(conversation_info.done_action) >= 10 else conversation_info.done_action action_history_list = (
conversation_info.done_action[-10:]
if len(conversation_info.done_action) >= 10
else conversation_info.done_action
)
action_history_text = "你之前做的事情是:" action_history_text = "你之前做的事情是:"
for action in action_history_list: for action in action_history_list:
if isinstance(action, dict): if isinstance(action, dict):
action_type = action.get('action') action_type = action.get("action")
action_reason = action.get('reason') action_reason = action.get("reason")
action_status = action.get('status') action_status = action.get("status")
if action_status == "recall": if action_status == "recall":
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" action_history_text += (
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
)
elif action_status == "done": elif action_status == "done":
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
elif isinstance(action, tuple): elif isinstance(action, tuple):
@@ -102,7 +112,9 @@ class ActionPlanner:
action_reason = action[1] if len(action) > 1 else "未知原因" action_reason = action[1] if len(action) > 1 else "未知原因"
action_status = action[2] if len(action) > 2 else "done" action_status = action[2] if len(action) > 2 else "done"
if action_status == "recall": if action_status == "recall":
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" action_history_text += (
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
)
elif action_status == "done": elif action_status == "done":
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
@@ -147,7 +159,14 @@ end_conversation: 结束对话,长时间没回复或者当你觉得谈话暂
reason = result["reason"] reason = result["reason"]
# 验证action类型 # 验证action类型
if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal", "end_conversation"]: if action not in [
"direct_reply",
"fetch_knowledge",
"wait",
"listening",
"rethink_goal",
"end_conversation",
]:
logger.warning(f"未知的行动类型: {action}默认使用listening") logger.warning(f"未知的行动类型: {action}默认使用listening")
action = "listening" action = "listening"

View File

@@ -1,12 +1,12 @@
import time import time
import asyncio import asyncio
import traceback import traceback
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..message.message_base import UserInfo from ..message.message_base import UserInfo
from ..config.config import global_config from ..config.config import global_config
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
from .message_storage import MongoDBMessageStorage from .message_storage import MongoDBMessageStorage
logger = get_module_logger("chat_observer") logger = get_module_logger("chat_observer")
@@ -51,7 +51,6 @@ class ChatObserver:
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
# 运行状态 # 运行状态
self._running: bool = False self._running: bool = False
self._task: Optional[asyncio.Task] = None self._task: Optional[asyncio.Task] = None
@@ -94,10 +93,11 @@ class ChatObserver:
message: 消息数据 message: 消息数据
""" """
try: try:
# 发送新消息通知 # 发送新消息通知
# logger.info(f"发送新ccchandleer消息通知: {message}") # logger.info(f"发送新ccchandleer消息通知: {message}")
notification = create_new_message_notification(sender="chat_observer", target="observation_info", message=message) notification = create_new_message_notification(
sender="chat_observer", target="observation_info", message=message
)
# logger.info(f"发送新消ddddd息通知: {notification}") # logger.info(f"发送新消ddddd息通知: {notification}")
# print(self.notification_manager) # print(self.notification_manager)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
@@ -131,7 +131,6 @@ class ChatObserver:
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold) notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
def new_message_after(self, time_point: float) -> bool: def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息 """判断是否在指定时间点后有新消息
@@ -197,7 +196,7 @@ class ChatObserver:
if new_messages: if new_messages:
self.last_message_read = new_messages[-1] self.last_message_read = new_messages[-1]
self.last_message_time = new_messages[-1]["time"] self.last_message_time = new_messages[-1]["time"]
# print(f"获取数据库中找到的新消息: {new_messages}") # print(f"获取数据库中找到的新消息: {new_messages}")
return new_messages return new_messages
@@ -215,7 +214,7 @@ class ChatObserver:
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
logger.debug(f"获取指定时间点111之前的消息: {new_messages}") logger.debug(f"获取指定时间点111之前的消息: {new_messages}")
return new_messages return new_messages
@@ -239,7 +238,7 @@ class ChatObserver:
try: try:
# print("等待事件") # print("等待事件")
await asyncio.wait_for(self._update_event.wait(), timeout=1) await asyncio.wait_for(self._update_event.wait(), timeout=1)
except asyncio.TimeoutError: except asyncio.TimeoutError:
# print("超时") # print("超时")
pass # 超时后也执行一次检查 pass # 超时后也执行一次检查
@@ -347,7 +346,6 @@ class ChatObserver:
return time_info return time_info
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]: def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
"""获取缓存的消息历史 """获取缓存的消息历史
@@ -368,6 +366,6 @@ class ChatObserver:
if not self.message_cache: if not self.message_cache:
return None return None
return self.message_cache[0] return self.message_cache[0]
def __str__(self): def __str__(self):
return f"ChatObserver for {self.stream_id}" return f"ChatObserver for {self.stream_id}"

View File

@@ -140,7 +140,6 @@ class NotificationManager:
self._active_states.add(notification.type) self._active_states.add(notification.type)
else: else:
self._active_states.discard(notification.type) self._active_states.discard(notification.type)
# 调用目标接收者的处理器 # 调用目标接收者的处理器
target = notification.target target = notification.target
@@ -181,7 +180,7 @@ class NotificationManager:
history = history[-limit:] history = history[-limit:]
return history return history
def __str__(self): def __str__(self):
str = "" str = ""
for target, handlers in self._handlers.items(): for target, handlers in self._handlers.items():
@@ -295,5 +294,3 @@ class ChatStateManager:
current_time = datetime.now().timestamp() current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) <= threshold return (current_time - self.state_info.last_message_time) <= threshold

View File

@@ -65,7 +65,6 @@ class Conversation:
self.observation_info.bind_to_chat_observer(self.chat_observer) self.observation_info.bind_to_chat_observer(self.chat_observer)
# print(self.chat_observer.get_cached_messages(limit=) # print(self.chat_observer.get_cached_messages(limit=)
self.conversation_info = ConversationInfo() self.conversation_info = ConversationInfo()
except Exception as e: except Exception as e:
logger.error(f"初始化对话实例:注册信息组件失败: {e}") logger.error(f"初始化对话实例:注册信息组件失败: {e}")
@@ -96,7 +95,7 @@ class Conversation:
# 执行行动 # 执行行动
await self._handle_action(action, reason, self.observation_info, self.conversation_info) await self._handle_action(action, reason, self.observation_info, self.conversation_info)
for goal in self.conversation_info.goal_list: for goal in self.conversation_info.goal_list:
# 检查goal是否为元组类型如果是元组则使用索引访问如果是字典则使用get方法 # 检查goal是否为元组类型如果是元组则使用索引访问如果是字典则使用get方法
if isinstance(goal, tuple): if isinstance(goal, tuple):
@@ -151,7 +150,7 @@ class Conversation:
if action == "direct_reply": if action == "direct_reply":
self.waiter.wait_accumulated_time = 0 self.waiter.wait_accumulated_time = 0
self.state = ConversationState.GENERATING self.state = ConversationState.GENERATING
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info) self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
print(f"生成回复: {self.generated_reply}") print(f"生成回复: {self.generated_reply}")
@@ -174,7 +173,6 @@ class Conversation:
await self._send_reply() await self._send_reply()
conversation_info.done_action[-1].update( conversation_info.done_action[-1].update(
{ {
"status": "done", "status": "done",
@@ -184,7 +182,7 @@ class Conversation:
elif action == "fetch_knowledge": elif action == "fetch_knowledge":
self.waiter.wait_accumulated_time = 0 self.waiter.wait_accumulated_time = 0
self.state = ConversationState.FETCHING self.state = ConversationState.FETCHING
knowledge = "TODO:知识" knowledge = "TODO:知识"
topic = "TODO:关键词" topic = "TODO:关键词"
@@ -199,7 +197,7 @@ class Conversation:
elif action == "rethink_goal": elif action == "rethink_goal":
self.waiter.wait_accumulated_time = 0 self.waiter.wait_accumulated_time = 0
self.state = ConversationState.RETHINKING self.state = ConversationState.RETHINKING
await self.goal_analyzer.analyze_goal(conversation_info, observation_info) await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
@@ -208,7 +206,6 @@ class Conversation:
logger.info("倾听对方发言...") logger.info("倾听对方发言...")
await self.waiter.wait_listening(conversation_info) await self.waiter.wait_listening(conversation_info)
elif action == "end_conversation": elif action == "end_conversation":
self.should_continue = False self.should_continue = False
logger.info("决定结束对话...") logger.info("决定结束对话...")
@@ -239,9 +236,7 @@ class Conversation:
return return
try: try:
await self.direct_sender.send_message( await self.direct_sender.send_message(chat_stream=self.chat_stream, content=self.generated_reply)
chat_stream=self.chat_stream, content=self.generated_reply
)
self.chat_observer.trigger_update() # 触发立即更新 self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update(): if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时") logger.warning("等待消息更新超时")

View File

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from typing import List, Dict, Any from typing import List, Dict, Any
from src.common.database import db from src.common.database import db
class MessageStorage(ABC): class MessageStorage(ABC):
"""消息存储接口""" """消息存储接口"""

View File

@@ -26,24 +26,24 @@ class ObservationInfoHandler(NotificationHandler):
# 获取通知类型和数据 # 获取通知类型和数据
notification_type = notification.type notification_type = notification.type
data = notification.data data = notification.data
if notification_type == NotificationType.NEW_MESSAGE: if notification_type == NotificationType.NEW_MESSAGE:
# 处理新消息通知 # 处理新消息通知
logger.debug(f"收到新消息通知data: {data}") logger.debug(f"收到新消息通知data: {data}")
message_id = data.get("message_id") message_id = data.get("message_id")
processed_plain_text = data.get("processed_plain_text") processed_plain_text = data.get("processed_plain_text")
detailed_plain_text = data.get("detailed_plain_text") detailed_plain_text = data.get("detailed_plain_text")
user_info = data.get("user_info") user_info = data.get("user_info")
time_value = data.get("time") time_value = data.get("time")
message = { message = {
"message_id": message_id, "message_id": message_id,
"processed_plain_text": processed_plain_text, "processed_plain_text": processed_plain_text,
"detailed_plain_text": detailed_plain_text, "detailed_plain_text": detailed_plain_text,
"user_info": user_info, "user_info": user_info,
"time": time_value "time": time_value,
} }
self.observation_info.update_from_message(message) self.observation_info.update_from_message(message)
elif notification_type == NotificationType.COLD_CHAT: elif notification_type == NotificationType.COLD_CHAT:
@@ -161,7 +161,7 @@ class ObservationInfo:
# logger.debug(f"更新信息from_message: {message}") # logger.debug(f"更新信息from_message: {message}")
self.last_message_time = message["time"] self.last_message_time = message["time"]
self.last_message_id = message["message_id"] self.last_message_id = message["message_id"]
self.last_message_content = message.get("processed_plain_text", "") self.last_message_content = message.get("processed_plain_text", "")
user_info = UserInfo.from_dict(message.get("user_info", {})) user_info = UserInfo.from_dict(message.get("user_info", {}))
@@ -233,4 +233,3 @@ class ObservationInfo:
self.unprocessed_messages.clear() self.unprocessed_messages.clear()
self.chat_history_count = len(self.chat_history) self.chat_history_count = len(self.chat_history)
self.new_messages_count = 0 self.new_messages_count = 0

View File

@@ -1,6 +1,7 @@
# Programmable Friendly Conversationalist # Programmable Friendly Conversationalist
# Prefrontal cortex # Prefrontal cortex
import datetime import datetime
# import asyncio # import asyncio
from typing import List, Optional, Tuple, TYPE_CHECKING from typing import List, Optional, Tuple, TYPE_CHECKING
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
@@ -63,13 +64,13 @@ class GoalAnalyzer:
goal = goal_reason[0] goal = goal_reason[0]
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因" reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
elif isinstance(goal_reason, dict): elif isinstance(goal_reason, dict):
goal = goal_reason.get('goal') goal = goal_reason.get("goal")
reasoning = goal_reason.get('reasoning', "没有明确原因") reasoning = goal_reason.get("reasoning", "没有明确原因")
else: else:
# 如果是其他类型,尝试转为字符串 # 如果是其他类型,尝试转为字符串
goal = str(goal_reason) goal = str(goal_reason)
reasoning = "没有明确原因" reasoning = "没有明确原因"
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
goals_str += goal_str goals_str += goal_str
else: else:
@@ -140,14 +141,12 @@ class GoalAnalyzer:
except Exception as e: except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)}") logger.error(f"分析对话目标时出错: {str(e)}")
content = "" content = ""
# 使用改进后的get_items_from_json函数处理JSON数组 # 使用改进后的get_items_from_json函数处理JSON数组
success, result = get_items_from_json( success, result = get_items_from_json(
content, "goal", "reasoning", content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}, allow_array=True
required_types={"goal": str, "reasoning": str},
allow_array=True
) )
if success: if success:
# 判断结果是单个字典还是字典列表 # 判断结果是单个字典还是字典列表
if isinstance(result, list): if isinstance(result, list):
@@ -157,7 +156,7 @@ class GoalAnalyzer:
goal = item.get("goal", "") goal = item.get("goal", "")
reasoning = item.get("reasoning", "") reasoning = item.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning)) conversation_info.goal_list.append((goal, reasoning))
# 返回第一个目标作为当前主要目标(如果有) # 返回第一个目标作为当前主要目标(如果有)
if result: if result:
first_goal = result[0] first_goal = result[0]
@@ -168,7 +167,7 @@ class GoalAnalyzer:
reasoning = result.get("reasoning", "") reasoning = result.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning)) conversation_info.goal_list.append((goal, reasoning))
return (goal, "", reasoning) return (goal, "", reasoning)
# 如果解析失败,返回默认值 # 如果解析失败,返回默认值
return ("", "", "") return ("", "", "")
@@ -293,7 +292,6 @@ class GoalAnalyzer:
return False, False, f"分析出错: {str(e)}" return False, False, f"分析出错: {str(e)}"
class DirectMessageSender: class DirectMessageSender:
"""直接发送消息到平台的发送器""" """直接发送消息到平台的发送器"""

View File

@@ -27,7 +27,7 @@ def get_items_from_json(
""" """
content = content.strip() content = content.strip()
result = {} result = {}
# 设置默认值 # 设置默认值
if default_values: if default_values:
result.update(default_values) result.update(default_values)
@@ -41,7 +41,7 @@ def get_items_from_json(
if array_match: if array_match:
array_content = array_match.group() array_content = array_match.group()
json_array = json.loads(array_content) json_array = json.loads(array_content)
# 确认是数组类型 # 确认是数组类型
if isinstance(json_array, list): if isinstance(json_array, list):
# 验证数组中的每个项目是否包含所有必需字段 # 验证数组中的每个项目是否包含所有必需字段
@@ -49,7 +49,7 @@ def get_items_from_json(
for item in json_array: for item in json_array:
if not isinstance(item, dict): if not isinstance(item, dict):
continue continue
# 检查是否有所有必需字段 # 检查是否有所有必需字段
if all(field in item for field in items): if all(field in item for field in items):
# 验证字段类型 # 验证字段类型
@@ -59,22 +59,22 @@ def get_items_from_json(
if field in item and not isinstance(item[field], expected_type): if field in item and not isinstance(item[field], expected_type):
type_valid = False type_valid = False
break break
if not type_valid: if not type_valid:
continue continue
# 验证字符串字段不为空 # 验证字符串字段不为空
string_valid = True string_valid = True
for field in items: for field in items:
if isinstance(item[field], str) and not item[field].strip(): if isinstance(item[field], str) and not item[field].strip():
string_valid = False string_valid = False
break break
if not string_valid: if not string_valid:
continue continue
valid_items.append(item) valid_items.append(item)
if valid_items: if valid_items:
return True, valid_items return True, valid_items
except json.JSONDecodeError: except json.JSONDecodeError:

View File

@@ -49,22 +49,26 @@ class ReplyGenerator:
goal = goal_reason[0] goal = goal_reason[0]
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因" reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
elif isinstance(goal_reason, dict): elif isinstance(goal_reason, dict):
goal = goal_reason.get('goal') goal = goal_reason.get("goal")
reasoning = goal_reason.get('reasoning', "没有明确原因") reasoning = goal_reason.get("reasoning", "没有明确原因")
else: else:
# 如果是其他类型,尝试转为字符串 # 如果是其他类型,尝试转为字符串
goal = str(goal_reason) goal = str(goal_reason)
reasoning = "没有明确原因" reasoning = "没有明确原因"
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
goals_str += goal_str goals_str += goal_str
else: else:
goal = "目前没有明确对话目标" goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标" reasoning = "目前没有明确对话目标,最好思考一个对话目标"
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n" goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
# 获取聊天历史记录 # 获取聊天历史记录
chat_history_list = observation_info.chat_history[-20:] if len(observation_info.chat_history) >= 20 else observation_info.chat_history chat_history_list = (
observation_info.chat_history[-20:]
if len(observation_info.chat_history) >= 20
else observation_info.chat_history
)
chat_history_text = "" chat_history_text = ""
for msg in chat_history_list: for msg in chat_history_list:
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n" chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
@@ -81,15 +85,21 @@ class ReplyGenerator:
personality_text = f"你的名字是{self.name}{self.personality_info}" personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本 # 构建action历史文本
action_history_list = conversation_info.done_action[-10:] if len(conversation_info.done_action) >= 10 else conversation_info.done_action action_history_list = (
conversation_info.done_action[-10:]
if len(conversation_info.done_action) >= 10
else conversation_info.done_action
)
action_history_text = "你之前做的事情是:" action_history_text = "你之前做的事情是:"
for action in action_history_list: for action in action_history_list:
if isinstance(action, dict): if isinstance(action, dict):
action_type = action.get('action') action_type = action.get("action")
action_reason = action.get('reason') action_reason = action.get("reason")
action_status = action.get('status') action_status = action.get("status")
if action_status == "recall": if action_status == "recall":
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" action_history_text += (
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
)
elif action_status == "done": elif action_status == "done":
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
elif isinstance(action, tuple): elif isinstance(action, tuple):
@@ -98,7 +108,9 @@ class ReplyGenerator:
action_reason = action[1] if len(action) > 1 else "未知原因" action_reason = action[1] if len(action) > 1 else "未知原因"
action_status = action[2] if len(action) > 2 else "done" action_status = action[2] if len(action) > 2 else "done"
if action_status == "recall": if action_status == "recall":
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n" action_history_text += (
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
)
elif action_status == "done": elif action_status == "done":
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n" action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"

View File

@@ -16,7 +16,7 @@ class Waiter:
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = ChatObserver.get_instance(stream_id)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME self.name = global_config.BOT_NICKNAME
self.wait_accumulated_time = 0 self.wait_accumulated_time = 0
async def wait(self, conversation_info: ConversationInfo) -> bool: async def wait(self, conversation_info: ConversationInfo) -> bool:
@@ -38,20 +38,20 @@ class Waiter:
# 检查是否超时 # 检查是否超时
if time.time() - wait_start_time > 300: if time.time() - wait_start_time > 300:
self.wait_accumulated_time += 300 self.wait_accumulated_time += 300
logger.info("等待超过300秒结束对话") logger.info("等待超过300秒结束对话")
wait_goal = { wait_goal = {
"goal": f"你等待了{self.wait_accumulated_time/60}分钟,思考接下来要做什么", "goal": f"你等待了{self.wait_accumulated_time / 60}分钟,思考接下来要做什么",
"reason": "对方很久没有回复你的消息了" "reason": "对方很久没有回复你的消息了",
} }
conversation_info.goal_list.append(wait_goal) conversation_info.goal_list.append(wait_goal)
print(f"添加目标: {wait_goal}") print(f"添加目标: {wait_goal}")
return True return True
await asyncio.sleep(1) await asyncio.sleep(1)
logger.info("等待中...") logger.info("等待中...")
async def wait_listening(self, conversation_info: ConversationInfo) -> bool: async def wait_listening(self, conversation_info: ConversationInfo) -> bool:
"""等待倾听 """等待倾听
@@ -73,14 +73,13 @@ class Waiter:
self.wait_accumulated_time += 300 self.wait_accumulated_time += 300
logger.info("等待超过300秒结束对话") logger.info("等待超过300秒结束对话")
wait_goal = { wait_goal = {
"goal": f"你等待了{self.wait_accumulated_time/60}分钟,思考接下来要做什么", "goal": f"你等待了{self.wait_accumulated_time / 60}分钟,思考接下来要做什么",
"reason": "对方话说一半消失了,很久没有回复" "reason": "对方话说一半消失了,很久没有回复",
} }
conversation_info.goal_list.append(wait_goal) conversation_info.goal_list.append(wait_goal)
print(f"添加目标: {wait_goal}") print(f"添加目标: {wait_goal}")
return True return True
await asyncio.sleep(1) await asyncio.sleep(1)
logger.info("等待中...") logger.info("等待中...")

View File

@@ -8,7 +8,7 @@ from ..chat_module.only_process.only_message_process import MessageProcessor
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
from ..chat_module.think_flow_chat.think_flow_chat import ThinkFlowChat from ..chat_module.think_flow_chat.think_flow_chat import ThinkFlowChat
from ..chat_module.reasoning_chat.reasoning_chat import ReasoningChat from ..chat_module.reasoning_chat.reasoning_chat import ReasoningChat
from ..utils.prompt_builder import Prompt,global_prompt_manager from ..utils.prompt_builder import Prompt, global_prompt_manager
import traceback import traceback
# 定义日志配置 # 定义日志配置
@@ -89,17 +89,17 @@ class ChatBot:
if userinfo.user_id in global_config.ban_user_id: if userinfo.user_id in global_config.ban_user_id:
logger.debug(f"用户{userinfo.user_id}被禁止回复") logger.debug(f"用户{userinfo.user_id}被禁止回复")
return return
if message.message_info.template_info and not message.message_info.template_info.template_default: if message.message_info.template_info and not message.message_info.template_info.template_default:
template_group_name=message.message_info.template_info.template_name template_group_name = message.message_info.template_info.template_name
template_items=message.message_info.template_info.template_items template_items = message.message_info.template_info.template_items
async with global_prompt_manager.async_message_scope(template_group_name): async with global_prompt_manager.async_message_scope(template_group_name):
if isinstance(template_items,dict): if isinstance(template_items, dict):
for k in template_items.keys(): for k in template_items.keys():
await Prompt.create_async(template_items[k],k) await Prompt.create_async(template_items[k], k)
print(f"注册{template_items[k]},{k}") print(f"注册{template_items[k]},{k}")
else: else:
template_group_name=None template_group_name = None
async def preprocess(): async def preprocess():
if global_config.enable_pfc_chatting: if global_config.enable_pfc_chatting:

View File

@@ -87,7 +87,6 @@ async def get_embedding(text, request_type="embedding"):
return embedding return embedding
async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list: async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录 """从数据库获取群组最近的消息记录

View File

@@ -38,7 +38,7 @@ class ResponseGenerator:
self.current_model_type = "r1" # 默认使用 R1 self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model" self.current_model_name = "unknown model"
async def generate_response(self, message: MessageThinking,thinking_id:str) -> Optional[Union[str, List[str]]]: async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数""" """根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型 # 从global_config中获取模型概率值并选择模型
if random.random() < global_config.MODEL_R1_PROBABILITY: if random.random() < global_config.MODEL_R1_PROBABILITY:
@@ -52,7 +52,7 @@ class ResponseGenerator:
f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}" f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
) # noqa: E501 ) # noqa: E501
model_response = await self._generate_response_with_model(message, current_model,thinking_id) model_response = await self._generate_response_with_model(message, current_model, thinking_id)
# print(f"raw_content: {model_response}") # print(f"raw_content: {model_response}")
@@ -65,11 +65,11 @@ class ResponseGenerator:
logger.info(f"{self.current_model_type}思考,失败") logger.info(f"{self.current_model_type}思考,失败")
return None return None
async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request,thinking_id:str): async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request, thinking_id: str):
sender_name = "" sender_name = ""
info_catcher = info_catcher_manager.get_info_catcher(thinking_id) info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
sender_name = ( sender_name = (
f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]" f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]"
@@ -94,14 +94,11 @@ class ResponseGenerator:
try: try:
content, reasoning_content, self.current_model_name = await model.generate_response(prompt) content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
info_catcher.catch_after_llm_generated( info_catcher.catch_after_llm_generated(
prompt=prompt, prompt=prompt, response=content, reasoning_content=reasoning_content, model_name=self.current_model_name
response=content, )
reasoning_content=reasoning_content,
model_name=self.current_model_name)
except Exception: except Exception:
logger.exception("生成回复时出错") logger.exception("生成回复时出错")
return None return None
@@ -118,7 +115,6 @@ class ResponseGenerator:
return content return content
# def _save_to_db( # def _save_to_db(
# self, # self,
# message: MessageRecv, # message: MessageRecv,

View File

@@ -144,12 +144,10 @@ class PromptBuilder:
for pattern in rule.get("regex", []): for pattern in rule.get("regex", []):
result = pattern.search(message_txt) result = pattern.search(message_txt)
if result: if result:
reaction = rule.get('reaction', '') reaction = rule.get("reaction", "")
for name, content in result.groupdict().items(): for name, content in result.groupdict().items():
reaction = reaction.replace(f'[{name}]', content) reaction = reaction.replace(f"[{name}]", content)
logger.info( logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
f"匹配到以下正则表达式:{pattern},触发反应:{reaction}"
)
keywords_reaction_prompt += reaction + "" keywords_reaction_prompt += reaction + ""
break break
@@ -197,7 +195,7 @@ class PromptBuilder:
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"reasoning_prompt_main", "reasoning_prompt_main",
relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"), relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"),
replation_prompt=relation_prompt, relation_prompt=relation_prompt,
sender_name=sender_name, sender_name=sender_name,
memory_prompt=memory_prompt, memory_prompt=memory_prompt,
prompt_info=prompt_info, prompt_info=prompt_info,

View File

@@ -59,11 +59,7 @@ class ThinkFlowChat:
return thinking_id return thinking_id
async def _send_response_messages(self, async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending:
message,
chat,
response_set:List[str],
thinking_id) -> MessageSending:
"""发送回复消息""" """发送回复消息"""
container = message_manager.get_container(chat.stream_id) container = message_manager.get_container(chat.stream_id)
thinking_message = None thinking_message = None
@@ -260,8 +256,6 @@ class ThinkFlowChat:
if random() < reply_probability: if random() < reply_probability:
try: try:
do_reply = True do_reply = True
# 回复前处理 # 回复前处理
await willing_manager.before_generate_reply_handle(message.message_info.message_id) await willing_manager.before_generate_reply_handle(message.message_info.message_id)
@@ -274,9 +268,9 @@ class ThinkFlowChat:
timing_results["创建思考消息"] = timer2 - timer1 timing_results["创建思考消息"] = timer2 - timer1
except Exception as e: except Exception as e:
logger.error(f"心流创建思考消息失败: {e}") logger.error(f"心流创建思考消息失败: {e}")
logger.debug(f"创建捕捉器thinking_id:{thinking_id}") logger.debug(f"创建捕捉器thinking_id:{thinking_id}")
info_catcher = info_catcher_manager.get_info_catcher(thinking_id) info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
info_catcher.catch_decide_to_response(message) info_catcher.catch_decide_to_response(message)
@@ -288,32 +282,32 @@ class ThinkFlowChat:
timing_results["观察"] = timer2 - timer1 timing_results["观察"] = timer2 - timer1
except Exception as e: except Exception as e:
logger.error(f"心流观察失败: {e}") logger.error(f"心流观察失败: {e}")
info_catcher.catch_after_observe(timing_results["观察"]) info_catcher.catch_after_observe(timing_results["观察"])
# 思考前脑内状态 # 思考前脑内状态
try: try:
timer1 = time.time() timer1 = time.time()
current_mind,past_mind = await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply( current_mind, past_mind = await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(
message_txt = message.processed_plain_text, message_txt=message.processed_plain_text,
sender_name = message.message_info.user_info.user_nickname, sender_name=message.message_info.user_info.user_nickname,
chat_stream = chat chat_stream=chat,
) )
timer2 = time.time() timer2 = time.time()
timing_results["思考前脑内状态"] = timer2 - timer1 timing_results["思考前脑内状态"] = timer2 - timer1
except Exception as e: except Exception as e:
logger.error(f"心流思考前脑内状态失败: {e}") logger.error(f"心流思考前脑内状态失败: {e}")
info_catcher.catch_afer_shf_step(timing_results["思考前脑内状态"],past_mind,current_mind) info_catcher.catch_afer_shf_step(timing_results["思考前脑内状态"], past_mind, current_mind)
# 生成回复 # 生成回复
timer1 = time.time() timer1 = time.time()
response_set = await self.gpt.generate_response(message,thinking_id) response_set = await self.gpt.generate_response(message, thinking_id)
timer2 = time.time() timer2 = time.time()
timing_results["生成回复"] = timer2 - timer1 timing_results["生成回复"] = timer2 - timer1
info_catcher.catch_after_generate_response(timing_results["生成回复"]) info_catcher.catch_after_generate_response(timing_results["生成回复"])
if not response_set: if not response_set:
logger.info("回复生成失败,返回为空") logger.info("回复生成失败,返回为空")
return return
@@ -326,11 +320,9 @@ class ThinkFlowChat:
timing_results["发送消息"] = timer2 - timer1 timing_results["发送消息"] = timer2 - timer1
except Exception as e: except Exception as e:
logger.error(f"心流发送消息失败: {e}") logger.error(f"心流发送消息失败: {e}")
info_catcher.catch_after_response(timing_results["发送消息"], response_set, first_bot_msg)
info_catcher.catch_after_response(timing_results["发送消息"],response_set,first_bot_msg)
info_catcher.done_catch() info_catcher.done_catch()
# 处理表情包 # 处理表情包

View File

@@ -35,44 +35,51 @@ class ResponseGenerator:
self.current_model_type = "r1" # 默认使用 R1 self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model" self.current_model_name = "unknown model"
async def generate_response(self, message: MessageRecv,thinking_id:str) -> Optional[List[str]]: async def generate_response(self, message: MessageRecv, thinking_id: str) -> Optional[List[str]]:
"""根据当前模型类型选择对应的生成函数""" """根据当前模型类型选择对应的生成函数"""
logger.info( logger.info(
f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}" f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
) )
arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier() arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier()
time1 = time.time() time1 = time.time()
checked = False checked = False
if random.random() > 0: if random.random() > 0:
checked = False checked = False
current_model = self.model_normal current_model = self.model_normal
current_model.temperature = 0.3 * arousal_multiplier #激活度越高,温度越高 current_model.temperature = 0.3 * arousal_multiplier # 激活度越高,温度越高
model_response = await self._generate_response_with_model(message, current_model,thinking_id,mode="normal") model_response = await self._generate_response_with_model(
message, current_model, thinking_id, mode="normal"
)
model_checked_response = model_response model_checked_response = model_response
else: else:
checked = True checked = True
current_model = self.model_normal current_model = self.model_normal
current_model.temperature = 0.3 * arousal_multiplier #激活度越高,温度越高 current_model.temperature = 0.3 * arousal_multiplier # 激活度越高,温度越高
print(f"生成{message.processed_plain_text}回复温度是:{current_model.temperature}") print(f"生成{message.processed_plain_text}回复温度是:{current_model.temperature}")
model_response = await self._generate_response_with_model(message, current_model,thinking_id,mode="simple") model_response = await self._generate_response_with_model(
message, current_model, thinking_id, mode="simple"
)
current_model.temperature = 0.3 current_model.temperature = 0.3
model_checked_response = await self._check_response_with_model(message, model_response, current_model,thinking_id) model_checked_response = await self._check_response_with_model(
message, model_response, current_model, thinking_id
)
time2 = time.time() time2 = time.time()
if model_response: if model_response:
if checked: if checked:
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},思忖后,回复是:{model_checked_response},生成回复时间: {time2 - time1}") logger.info(
f"{global_config.BOT_NICKNAME}的回复是:{model_response},思忖后,回复是:{model_checked_response},生成回复时间: {time2 - time1}"
)
else: else:
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},生成回复时间: {time2 - time1}") logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},生成回复时间: {time2 - time1}")
model_processed_response = await self._process_response(model_checked_response) model_processed_response = await self._process_response(model_checked_response)
return model_processed_response return model_processed_response
@@ -80,11 +87,13 @@ class ResponseGenerator:
logger.info(f"{self.current_model_type}思考,失败") logger.info(f"{self.current_model_type}思考,失败")
return None return None
async def _generate_response_with_model(self, message: MessageRecv, model: LLM_request,thinking_id:str,mode:str = "normal") -> str: async def _generate_response_with_model(
self, message: MessageRecv, model: LLM_request, thinking_id: str, mode: str = "normal"
) -> str:
sender_name = "" sender_name = ""
info_catcher = info_catcher_manager.get_info_catcher(thinking_id) info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
sender_name = ( sender_name = (
f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]" f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]"
@@ -116,25 +125,22 @@ class ResponseGenerator:
try: try:
content, reasoning_content, self.current_model_name = await model.generate_response(prompt) content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
info_catcher.catch_after_llm_generated( info_catcher.catch_after_llm_generated(
prompt=prompt, prompt=prompt, response=content, reasoning_content=reasoning_content, model_name=self.current_model_name
response=content, )
reasoning_content=reasoning_content,
model_name=self.current_model_name)
except Exception: except Exception:
logger.exception("生成回复时出错") logger.exception("生成回复时出错")
return None return None
return content return content
async def _check_response_with_model(self, message: MessageRecv, content:str, model: LLM_request,thinking_id:str) -> str: async def _check_response_with_model(
self, message: MessageRecv, content: str, model: LLM_request, thinking_id: str
) -> str:
_info_catcher = info_catcher_manager.get_info_catcher(thinking_id) _info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
sender_name = "" sender_name = ""
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
sender_name = ( sender_name = (
@@ -145,8 +151,7 @@ class ResponseGenerator:
sender_name = f"({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}" sender_name = f"({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}"
else: else:
sender_name = f"用户({message.chat_stream.user_info.user_id})" sender_name = f"用户({message.chat_stream.user_info.user_id})"
# 构建prompt # 构建prompt
timer1 = time.time() timer1 = time.time()
prompt = await prompt_builder._build_prompt_check_response( prompt = await prompt_builder._build_prompt_check_response(
@@ -154,7 +159,7 @@ class ResponseGenerator:
message_txt=message.processed_plain_text, message_txt=message.processed_plain_text,
sender_name=sender_name, sender_name=sender_name,
stream_id=message.chat_stream.stream_id, stream_id=message.chat_stream.stream_id,
content=content content=content,
) )
timer2 = time.time() timer2 = time.time()
logger.info(f"构建check_prompt: {prompt}") logger.info(f"构建check_prompt: {prompt}")
@@ -162,19 +167,17 @@ class ResponseGenerator:
try: try:
checked_content, reasoning_content, self.current_model_name = await model.generate_response(prompt) checked_content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
# info_catcher.catch_after_llm_generated( # info_catcher.catch_after_llm_generated(
# prompt=prompt, # prompt=prompt,
# response=content, # response=content,
# reasoning_content=reasoning_content, # reasoning_content=reasoning_content,
# model_name=self.current_model_name) # model_name=self.current_model_name)
except Exception: except Exception:
logger.exception("检查回复时出错") logger.exception("检查回复时出错")
return None return None
return checked_content return checked_content
async def _get_emotion_tags(self, content: str, processed_plain_text: str): async def _get_emotion_tags(self, content: str, processed_plain_text: str):

View File

@@ -110,12 +110,10 @@ class PromptBuilder:
for pattern in rule.get("regex", []): for pattern in rule.get("regex", []):
result = pattern.search(message_txt) result = pattern.search(message_txt)
if result: if result:
reaction = rule.get('reaction', '') reaction = rule.get("reaction", "")
for name, content in result.groupdict().items(): for name, content in result.groupdict().items():
reaction = reaction.replace(f'[{name}]', content) reaction = reaction.replace(f"[{name}]", content)
logger.info( logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
f"匹配到以下正则表达式:{pattern},触发反应:{reaction}"
)
keywords_reaction_prompt += reaction + "" keywords_reaction_prompt += reaction + ""
break break

View File

@@ -225,6 +225,7 @@ class Memory_graph:
return None return None
# 海马体 # 海马体
class Hippocampus: class Hippocampus:
def __init__(self): def __init__(self):
@@ -653,7 +654,6 @@ class Hippocampus:
return activation_ratio return activation_ratio
# 负责海马体与其他部分的交互 # 负责海马体与其他部分的交互
class EntorhinalCortex: class EntorhinalCortex:
def __init__(self, hippocampus: Hippocampus): def __init__(self, hippocampus: Hippocampus):

View File

@@ -27,7 +27,6 @@ async def test_memory_system():
# 测试记忆检索 # 测试记忆检索
test_text = "千石可乐在群里聊天" test_text = "千石可乐在群里聊天"
# test_text = '''千石可乐分不清AI的陪伴和人类的陪伴,是这样吗?''' # test_text = '''千石可乐分不清AI的陪伴和人类的陪伴,是这样吗?'''
print(f"开始测试记忆检索,测试文本: {test_text}\n") print(f"开始测试记忆检索,测试文本: {test_text}\n")
memories = await hippocampus_manager.get_memory_from_text( memories = await hippocampus_manager.get_memory_from_text(

View File

@@ -574,7 +574,7 @@ class LLM_request:
reasoning_content = message.get("reasoning_content", "") reasoning_content = message.get("reasoning_content", "")
if not reasoning_content: if not reasoning_content:
reasoning_content = reasoning reasoning_content = reasoning
# 提取工具调用信息 # 提取工具调用信息
tool_calls = message.get("tool_calls", None) tool_calls = message.get("tool_calls", None)
@@ -592,7 +592,7 @@ class LLM_request:
request_type=request_type if request_type is not None else self.request_type, request_type=request_type if request_type is not None else self.request_type,
endpoint=endpoint, endpoint=endpoint,
) )
# 只有当tool_calls存在且不为空时才返回 # 只有当tool_calls存在且不为空时才返回
if tool_calls: if tool_calls:
return content, reasoning_content, tool_calls return content, reasoning_content, tool_calls
@@ -657,9 +657,7 @@ class LLM_request:
**kwargs, **kwargs,
} }
response = await self._execute_request( response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
endpoint="/chat/completions", payload=data, prompt=prompt
)
# 原样返回响应,不做处理 # 原样返回响应,不做处理
return response return response

View File

@@ -238,14 +238,14 @@ class MoodManager:
base_prompt += "情绪比较平静。" base_prompt += "情绪比较平静。"
return base_prompt return base_prompt
def get_arousal_multiplier(self) -> float: def get_arousal_multiplier(self) -> float:
"""根据当前情绪状态返回唤醒度乘数""" """根据当前情绪状态返回唤醒度乘数"""
if self.current_mood.arousal > 0.4: if self.current_mood.arousal > 0.4:
multiplier = 1 + min(0.15,(self.current_mood.arousal - 0.4)/3) multiplier = 1 + min(0.15, (self.current_mood.arousal - 0.4) / 3)
return multiplier return multiplier
elif self.current_mood.arousal < -0.4: elif self.current_mood.arousal < -0.4:
multiplier = 1 - min(0.15,((0 - self.current_mood.arousal) - 0.4)/3) multiplier = 1 - min(0.15, ((0 - self.current_mood.arousal) - 0.4) / 3)
return multiplier return multiplier
return 1.0 return 1.0

View File

@@ -1,28 +1,29 @@
from src.plugins.config.config import global_config from src.plugins.config.config import global_config
from src.plugins.chat.message import MessageRecv,MessageSending,Message from src.plugins.chat.message import MessageRecv, MessageSending, Message
from src.common.database import db from src.common.database import db
import time import time
import traceback import traceback
from typing import List from typing import List
class InfoCatcher: class InfoCatcher:
def __init__(self): def __init__(self):
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文 self.chat_history = [] # 聊天历史,长度为三倍使用的上下文
self.context_length = global_config.MAX_CONTEXT_SIZE self.context_length = global_config.MAX_CONTEXT_SIZE
self.chat_history_in_thinking = [] # 思考期间的聊天内容 self.chat_history_in_thinking = [] # 思考期间的聊天内容
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文 self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文
self.chat_id = "" self.chat_id = ""
self.response_mode = global_config.response_mode self.response_mode = global_config.response_mode
self.trigger_response_text = "" self.trigger_response_text = ""
self.response_text = "" self.response_text = ""
self.trigger_response_time = 0 self.trigger_response_time = 0
self.trigger_response_message = None self.trigger_response_message = None
self.response_time = 0 self.response_time = 0
self.response_messages = [] self.response_messages = []
# 使用字典来存储 heartflow 模式的数据 # 使用字典来存储 heartflow 模式的数据
self.heartflow_data = { self.heartflow_data = {
"heart_flow_prompt": "", "heart_flow_prompt": "",
@@ -32,17 +33,12 @@ class InfoCatcher:
"sub_heartflow_model": "", "sub_heartflow_model": "",
"prompt": "", "prompt": "",
"response": "", "response": "",
"model": "" "model": "",
} }
# 使用字典来存储 reasoning 模式的数据 # 使用字典来存储 reasoning 模式的数据
self.reasoning_data = { self.reasoning_data = {"thinking_log": "", "prompt": "", "response": "", "model": ""}
"thinking_log": "",
"prompt": "",
"response": "",
"model": ""
}
# 耗时 # 耗时
self.timing_results = { self.timing_results = {
"interested_rate_time": 0, "interested_rate_time": 0,
@@ -50,24 +46,24 @@ class InfoCatcher:
"sub_heartflow_step_time": 0, "sub_heartflow_step_time": 0,
"make_response_time": 0, "make_response_time": 0,
} }
def catch_decide_to_response(self,message:MessageRecv): def catch_decide_to_response(self, message: MessageRecv):
# 搜集决定回复时的信息 # 搜集决定回复时的信息
self.trigger_response_message = message self.trigger_response_message = message
self.trigger_response_text = message.detailed_plain_text self.trigger_response_text = message.detailed_plain_text
self.trigger_response_time = time.time() self.trigger_response_time = time.time()
self.chat_id = message.chat_stream.stream_id self.chat_id = message.chat_stream.stream_id
self.chat_history = self.get_message_from_db_before_msg(message) self.chat_history = self.get_message_from_db_before_msg(message)
def catch_after_observe(self,obs_duration:float):#这里可以有更多信息 def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
self.timing_results["sub_heartflow_observe_time"] = obs_duration self.timing_results["sub_heartflow_observe_time"] = obs_duration
# def catch_shf # def catch_shf
def catch_afer_shf_step(self,step_duration:float,past_mind:str,current_mind:str): def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
self.timing_results["sub_heartflow_step_time"] = step_duration self.timing_results["sub_heartflow_step_time"] = step_duration
if len(past_mind) > 1: if len(past_mind) > 1:
self.heartflow_data["sub_heartflow_before"] = past_mind[-1] self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
@@ -75,11 +71,8 @@ class InfoCatcher:
else: else:
self.heartflow_data["sub_heartflow_before"] = past_mind[-1] self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
self.heartflow_data["sub_heartflow_now"] = current_mind self.heartflow_data["sub_heartflow_now"] = current_mind
def catch_after_llm_generated(self,prompt:str, def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
response:str,
reasoning_content:str = "",
model_name:str = ""):
if self.response_mode == "heart_flow": if self.response_mode == "heart_flow":
self.heartflow_data["prompt"] = prompt self.heartflow_data["prompt"] = prompt
self.heartflow_data["response"] = response self.heartflow_data["response"] = response
@@ -89,41 +82,38 @@ class InfoCatcher:
self.reasoning_data["prompt"] = prompt self.reasoning_data["prompt"] = prompt
self.reasoning_data["response"] = response self.reasoning_data["response"] = response
self.reasoning_data["model"] = model_name self.reasoning_data["model"] = model_name
self.response_text = response self.response_text = response
def catch_after_generate_response(self,response_duration:float): def catch_after_generate_response(self, response_duration: float):
self.timing_results["make_response_time"] = response_duration self.timing_results["make_response_time"] = response_duration
def catch_after_response(
self, response_duration: float, response_message: List[str], first_bot_msg: MessageSending
def catch_after_response(self,response_duration:float, ):
response_message:List[str],
first_bot_msg:MessageSending):
self.timing_results["make_response_time"] = response_duration self.timing_results["make_response_time"] = response_duration
self.response_time = time.time() self.response_time = time.time()
for msg in response_message: for msg in response_message:
self.response_messages.append(msg) self.response_messages.append(msg)
self.chat_history_in_thinking = self.get_message_from_db_between_msgs(self.trigger_response_message,first_bot_msg) self.chat_history_in_thinking = self.get_message_from_db_between_msgs(
self.trigger_response_message, first_bot_msg
)
def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message): def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message):
try: try:
# 从数据库中获取消息的时间戳 # 从数据库中获取消息的时间戳
time_start = message_start.message_info.time time_start = message_start.message_info.time
time_end = message_end.message_info.time time_end = message_end.message_info.time
chat_id = message_start.chat_stream.stream_id chat_id = message_start.chat_stream.stream_id
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}") print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据 # 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
messages_between = db.messages.find( messages_between = db.messages.find(
{ {"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
"chat_id": chat_id,
"time": {"$gt": time_start, "$lt": time_end}
}
).sort("time", -1) ).sort("time", -1)
result = list(messages_between) result = list(messages_between)
print(f"查询结果数量: {len(result)}") print(f"查询结果数量: {len(result)}")
if result: if result:
@@ -133,21 +123,23 @@ class InfoCatcher:
except Exception as e: except Exception as e:
print(f"获取消息时出错: {str(e)}") print(f"获取消息时出错: {str(e)}")
return [] return []
def get_message_from_db_before_msg(self, message: MessageRecv): def get_message_from_db_before_msg(self, message: MessageRecv):
# 从数据库中获取消息 # 从数据库中获取消息
message_id = message.message_info.message_id message_id = message.message_info.message_id
chat_id = message.chat_stream.stream_id chat_id = message.chat_stream.stream_id
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据 # 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
messages_before = db.messages.find( messages_before = (
{"chat_id": chat_id, "message_id": {"$lt": message_id}} db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
).sort("time", -1).limit(self.context_length*3) #获取更多历史信息 .sort("time", -1)
.limit(self.context_length * 3)
) # 获取更多历史信息
return list(messages_before) return list(messages_before)
def message_list_to_dict(self, message_list): def message_list_to_dict(self, message_list):
#存储简化的聊天记录 # 存储简化的聊天记录
result = [] result = []
for message in message_list: for message in message_list:
if not isinstance(message, dict): if not isinstance(message, dict):
@@ -160,7 +152,7 @@ class InfoCatcher:
"processed_plain_text": message["processed_plain_text"], "processed_plain_text": message["processed_plain_text"],
} }
result.append(lite_message) result.append(lite_message)
return result return result
def message_to_dict(self, message): def message_to_dict(self, message):
@@ -176,12 +168,12 @@ class InfoCatcher:
"processed_plain_text": message.processed_plain_text, "processed_plain_text": message.processed_plain_text,
# "detailed_plain_text": message.detailed_plain_text # "detailed_plain_text": message.detailed_plain_text
} }
def done_catch(self): def done_catch(self):
"""将收集到的信息存储到数据库的 thinking_log 集合中""" """将收集到的信息存储到数据库的 thinking_log 集合中"""
try: try:
# 将消息对象转换为可序列化的字典 # 将消息对象转换为可序列化的字典
thinking_log_data = { thinking_log_data = {
"chat_id": self.chat_id, "chat_id": self.chat_id,
"response_mode": self.response_mode, "response_mode": self.response_mode,
@@ -198,7 +190,7 @@ class InfoCatcher:
"timing_results": self.timing_results, "timing_results": self.timing_results,
"chat_history": self.message_list_to_dict(self.chat_history), "chat_history": self.message_list_to_dict(self.chat_history),
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking), "chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response) "chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
} }
# 根据不同的响应模式添加相应的数据 # 根据不同的响应模式添加相应的数据
@@ -209,20 +201,22 @@ class InfoCatcher:
# 将数据插入到 thinking_log 集合中 # 将数据插入到 thinking_log 集合中
db.thinking_log.insert_one(thinking_log_data) db.thinking_log.insert_one(thinking_log_data)
return True return True
except Exception as e: except Exception as e:
print(f"存储思考日志时出错: {str(e)}") print(f"存储思考日志时出错: {str(e)}")
print(traceback.format_exc()) print(traceback.format_exc())
return False return False
class InfoCatcherManager: class InfoCatcherManager:
def __init__(self): def __init__(self):
self.info_catchers = {} self.info_catchers = {}
def get_info_catcher(self,thinking_id:str) -> InfoCatcher: def get_info_catcher(self, thinking_id: str) -> InfoCatcher:
if thinking_id not in self.info_catchers: if thinking_id not in self.info_catchers:
self.info_catchers[thinking_id] = InfoCatcher() self.info_catchers[thinking_id] = InfoCatcher()
return self.info_catchers[thinking_id] return self.info_catchers[thinking_id]
info_catcher_manager = InfoCatcherManager()
info_catcher_manager = InfoCatcherManager()

View File

@@ -32,7 +32,7 @@ class ScheduleGenerator:
# 使用离线LLM模型 # 使用离线LLM模型
self.llm_scheduler_all = LLM_request( self.llm_scheduler_all = LLM_request(
model=global_config.llm_reasoning, model=global_config.llm_reasoning,
temperature=global_config.SCHEDULE_TEMPERATURE+0.3, temperature=global_config.SCHEDULE_TEMPERATURE + 0.3,
max_tokens=7000, max_tokens=7000,
request_type="schedule", request_type="schedule",
) )

View File

@@ -8,6 +8,7 @@ from src.common.logger import get_module_logger
logger = get_module_logger("message_storage") logger = get_module_logger("message_storage")
class MessageStorage: class MessageStorage:
async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库""" """存储消息到数据库"""

View File

@@ -1,9 +1,12 @@
# import re
import ast
from typing import Dict, Any, Optional, List, Union from typing import Dict, Any, Optional, List, Union
import re
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncio import asyncio
from src.common.logger import get_module_logger
# import traceback
logger = get_module_logger("prompt_build")
class PromptContext: class PromptContext:
def __init__(self): def __init__(self):
@@ -95,15 +98,13 @@ class Prompt(str):
# 如果传入的是元组,转换为列表 # 如果传入的是元组,转换为列表
if isinstance(args, tuple): if isinstance(args, tuple):
args = list(args) args = list(args)
should_register = kwargs.pop("_should_register", True)
# 解析模板 # 解析模板
tree = ast.parse(f"f'''{fstr}'''", mode="eval") template_args = []
template_args = set() result = re.findall(r"\{(.*?)\}", fstr)
for node in ast.walk(tree): for expr in result:
if isinstance(node, ast.FormattedValue): if expr and expr not in template_args:
expr = ast.get_source_segment(fstr, node.value) template_args.append(expr)
if expr:
template_args.add(expr)
# 如果提供了初始参数,立即格式化 # 如果提供了初始参数,立即格式化
if kwargs or args: if kwargs or args:
@@ -119,17 +120,20 @@ class Prompt(str):
obj._kwargs = kwargs obj._kwargs = kwargs
# 修改自动注册逻辑 # 修改自动注册逻辑
if global_prompt_manager._context._current_context: if should_register:
# 如果存在当前上下文,则注册到上下文中 if global_prompt_manager._context._current_context:
# asyncio.create_task(global_prompt_manager._context.register_async(obj)) # 如果存在当前上下文,则注册到上下文中
pass # asyncio.create_task(global_prompt_manager._context.register_async(obj))
else: pass
# 否则注册到全局管理器 else:
global_prompt_manager.register(obj) # 否则注册到全局管理器
global_prompt_manager.register(obj)
return obj return obj
@classmethod @classmethod
async def create_async(cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs): async def create_async(
cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
):
"""异步创建Prompt实例""" """异步创建Prompt实例"""
prompt = cls(fstr, name, args, **kwargs) prompt = cls(fstr, name, args, **kwargs)
if global_prompt_manager._context._current_context: if global_prompt_manager._context._current_context:
@@ -138,25 +142,29 @@ class Prompt(str):
@classmethod @classmethod
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str: def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
fmt_str = f"f'''{template}'''"
tree = ast.parse(fmt_str, mode="eval")
template_args = [] template_args = []
for node in ast.walk(tree): result = re.findall(r"\{(.*?)\}", template)
if isinstance(node, ast.FormattedValue): for expr in result:
expr = ast.get_source_segment(fmt_str, node.value) if expr and expr not in template_args:
if expr and expr not in template_args: template_args.append(expr)
template_args.append(expr)
formatted_args = {} formatted_args = {}
formatted_kwargs = {} formatted_kwargs = {}
# 处理位置参数 # 处理位置参数
if args: if args:
# print(len(template_args), len(args), template_args, args)
for i in range(len(args)): for i in range(len(args)):
arg = args[i] if i < len(template_args):
if isinstance(arg, Prompt): arg = args[i]
formatted_args[template_args[i]] = arg.format(**kwargs) if isinstance(arg, Prompt):
formatted_args[template_args[i]] = arg.format(**kwargs)
else:
formatted_args[template_args[i]] = arg
else: else:
formatted_args[template_args[i]] = arg logger.error(
f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
)
raise ValueError("格式化模板失败")
# 处理关键字参数 # 处理关键字参数
if kwargs: if kwargs:
@@ -177,15 +185,21 @@ class Prompt(str):
template = template.format(**formatted_kwargs) template = template.format(**formatted_kwargs)
return template return template
except (IndexError, KeyError) as e: except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs}") from e raise ValueError(
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
) from e
def format(self, *args, **kwargs) -> "Prompt": def format(self, *args, **kwargs) -> "str":
"""支持位置参数和关键字参数的格式化,使用""" """支持位置参数和关键字参数的格式化,使用"""
ret = type(self)( ret = type(self)(
self.template, self.name, args=list(args) if args else self._args, **kwargs if kwargs else self._kwargs self.template,
self.name,
args=list(args) if args else self._args,
_should_register=False,
**kwargs if kwargs else self._kwargs,
) )
# print(f"prompt build result: {ret} name: {ret.name} ") # print(f"prompt build result: {ret} name: {ret.name} ")
return ret return str(ret)
def __str__(self) -> str: def __str__(self) -> str:
if self._kwargs or self._args: if self._kwargs or self._args:

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
from .willing_manager import BaseWillingManager from .willing_manager import BaseWillingManager
class ClassicalWillingManager(BaseWillingManager): class ClassicalWillingManager(BaseWillingManager):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -41,17 +42,22 @@ class ClassicalWillingManager(BaseWillingManager):
self.chat_reply_willing[chat_id] = min(current_willing, 3.0) self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
reply_probability = min(max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1) reply_probability = min(
max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1
)
# 检查群组权限(如果是群聊) # 检查群组权限(如果是群聊)
if willing_info.group_info and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups: if (
willing_info.group_info
and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups
):
reply_probability = reply_probability / self.global_config.down_frequency_rate reply_probability = reply_probability / self.global_config.down_frequency_rate
if is_emoji_not_reply: if is_emoji_not_reply:
reply_probability = 0 reply_probability = 0
return reply_probability return reply_probability
async def before_generate_reply_handle(self, message_id): async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0) current_willing = self.chat_reply_willing.get(chat_id, 0)
@@ -71,8 +77,6 @@ class ClassicalWillingManager(BaseWillingManager):
async def get_variable_parameters(self): async def get_variable_parameters(self):
return await super().get_variable_parameters() return await super().get_variable_parameters()
async def set_variable_parameters(self, parameters): async def set_variable_parameters(self, parameters):
return await super().set_variable_parameters(parameters) return await super().set_variable_parameters(parameters)

View File

@@ -4,4 +4,3 @@ from .willing_manager import BaseWillingManager
class CustomWillingManager(BaseWillingManager): class CustomWillingManager(BaseWillingManager):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@@ -20,7 +20,6 @@ class DynamicWillingManager(BaseWillingManager):
self._decay_task = None self._decay_task = None
self._mode_switch_task = None self._mode_switch_task = None
async def async_task_starter(self): async def async_task_starter(self):
if self._decay_task is None: if self._decay_task is None:
self._decay_task = asyncio.create_task(self._decay_reply_willing()) self._decay_task = asyncio.create_task(self._decay_reply_willing())
@@ -84,7 +83,9 @@ class DynamicWillingManager(BaseWillingManager):
self.chat_high_willing_mode[chat_id] = True self.chat_high_willing_mode[chat_id] = True
self.chat_reply_willing[chat_id] = 1.0 # 设置为较高回复意愿 self.chat_reply_willing[chat_id] = 1.0 # 设置为较高回复意愿
self.chat_high_willing_duration[chat_id] = random.randint(180, 240) # 3-4分钟 self.chat_high_willing_duration[chat_id] = random.randint(180, 240) # 3-4分钟
self.logger.debug(f"聊天流 {chat_id} 切换到高回复意愿期,持续 {self.chat_high_willing_duration[chat_id]}") self.logger.debug(
f"聊天流 {chat_id} 切换到高回复意愿期,持续 {self.chat_high_willing_duration[chat_id]}"
)
self.chat_last_mode_change[chat_id] = time.time() self.chat_last_mode_change[chat_id] = time.time()
self.chat_msg_count[chat_id] = 0 # 重置消息计数 self.chat_msg_count[chat_id] = 0 # 重置消息计数
@@ -148,7 +149,9 @@ class DynamicWillingManager(BaseWillingManager):
# 根据话题兴趣度适当调整 # 根据话题兴趣度适当调整
if willing_info.interested_rate > 0.5: if willing_info.interested_rate > 0.5:
current_willing += (willing_info.interested_rate - 0.5) * 0.5 * self.global_config.response_interested_rate_amplifier current_willing += (
(willing_info.interested_rate - 0.5) * 0.5 * self.global_config.response_interested_rate_amplifier
)
# 根据当前模式计算回复概率 # 根据当前模式计算回复概率
base_probability = 0.0 base_probability = 0.0
@@ -228,12 +231,12 @@ class DynamicWillingManager(BaseWillingManager):
async def bombing_buffer_message_handle(self, message_id): async def bombing_buffer_message_handle(self, message_id):
return await super().bombing_buffer_message_handle(message_id) return await super().bombing_buffer_message_handle(message_id)
async def after_generate_reply_handle(self, message_id): async def after_generate_reply_handle(self, message_id):
return await super().after_generate_reply_handle(message_id) return await super().after_generate_reply_handle(message_id)
async def get_variable_parameters(self): async def get_variable_parameters(self):
return await super().get_variable_parameters() return await super().get_variable_parameters()
async def set_variable_parameters(self, parameters): async def set_variable_parameters(self, parameters):
return await super().set_variable_parameters(parameters) return await super().set_variable_parameters(parameters)

View File

@@ -17,19 +17,22 @@ Mxp 模式:梦溪畔独家赞助
中策是发issue 中策是发issue
下下策是询问一个菜鸟(@梦溪畔) 下下策是询问一个菜鸟(@梦溪畔)
""" """
from .willing_manager import BaseWillingManager from .willing_manager import BaseWillingManager
from typing import Dict from typing import Dict
import asyncio import asyncio
import time import time
import math import math
class MxpWillingManager(BaseWillingManager): class MxpWillingManager(BaseWillingManager):
"""Mxp意愿管理器""" """Mxp意愿管理器"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.chat_person_reply_willing: Dict[str, Dict[str, float]] = {} # chat_id: {person_id: 意愿值} self.chat_person_reply_willing: Dict[str, Dict[str, float]] = {} # chat_id: {person_id: 意愿值}
self.chat_new_message_time: Dict[str, list[float]] = {} # 聊天流ID: 消息时间 self.chat_new_message_time: Dict[str, list[float]] = {} # 聊天流ID: 消息时间
self.last_response_person: Dict[str, tuple[str, int]] = {} # 上次回复的用户信息 self.last_response_person: Dict[str, tuple[str, int]] = {} # 上次回复的用户信息
self.temporary_willing: float = 0 # 临时意愿值 self.temporary_willing: float = 0 # 临时意愿值
# 可变参数 # 可变参数
@@ -39,8 +42,8 @@ class MxpWillingManager(BaseWillingManager):
self.basic_maximum_willing = 0.5 # 基础最大意愿值 self.basic_maximum_willing = 0.5 # 基础最大意愿值
self.mention_willing_gain = 0.6 # 提及意愿增益 self.mention_willing_gain = 0.6 # 提及意愿增益
self.interest_willing_gain = 0.3 # 兴趣意愿增益 self.interest_willing_gain = 0.3 # 兴趣意愿增益
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚 self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数 self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
self.single_chat_gain = 0.12 # 单聊增益 self.single_chat_gain = 0.12 # 单聊增益
async def async_task_starter(self) -> None: async def async_task_starter(self) -> None:
@@ -73,9 +76,13 @@ class MxpWillingManager(BaseWillingManager):
w_info = self.ongoing_messages[message_id] w_info = self.ongoing_messages[message_id]
if w_info.is_mentioned_bot: if w_info.is_mentioned_bot:
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += 0.2 self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += 0.2
if w_info.chat_id in self.last_response_person and self.last_response_person[w_info.chat_id][0] == w_info.person_id: if (
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] +=\ w_info.chat_id in self.last_response_person
self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1) and self.last_response_person[w_info.chat_id][0] == w_info.person_id
):
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * (
2 * self.last_response_person[w_info.chat_id][1] + 1
)
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0]) now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0])
if now_chat_new_person[0] != w_info.person_id: if now_chat_new_person[0] != w_info.person_id:
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0] self.last_response_person[w_info.chat_id] = [w_info.person_id, 0]
@@ -98,7 +105,10 @@ class MxpWillingManager(BaseWillingManager):
rel_level = self._get_relationship_level_num(rel_value) rel_level = self._get_relationship_level_num(rel_value)
current_willing += rel_level * 0.1 current_willing += rel_level * 0.1
if w_info.chat_id in self.last_response_person and self.last_response_person[w_info.chat_id][0] == w_info.person_id: if (
w_info.chat_id in self.last_response_person
and self.last_response_person[w_info.chat_id][0] == w_info.person_id
):
current_willing += self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1) current_willing += self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1)
chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id] chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id]
@@ -141,16 +151,22 @@ class MxpWillingManager(BaseWillingManager):
self.logger.debug(f"聊天流{chat_id}不存在,错误") self.logger.debug(f"聊天流{chat_id}不存在,错误")
continue continue
basic_willing = self.chat_reply_willing[chat_id] basic_willing = self.chat_reply_willing[chat_id]
person_willing[person_id] = basic_willing + (willing - basic_willing) * self.intention_decay_rate person_willing[person_id] = (
basic_willing + (willing - basic_willing) * self.intention_decay_rate
)
def setup(self, message, chat, is_mentioned_bot, interested_rate): def setup(self, message, chat, is_mentioned_bot, interested_rate):
super().setup(message, chat, is_mentioned_bot, interested_rate) super().setup(message, chat, is_mentioned_bot, interested_rate)
self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get(chat.stream_id, self.basic_maximum_willing) self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get(
chat.stream_id, self.basic_maximum_willing
)
self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {}) self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {})
self.chat_person_reply_willing[chat.stream_id][self.ongoing_messages[message.message_info.message_id].person_id] = \ self.chat_person_reply_willing[chat.stream_id][
self.chat_person_reply_willing[chat.stream_id].get(self.ongoing_messages[message.message_info.message_id].person_id, self.ongoing_messages[message.message_info.message_id].person_id
self.chat_reply_willing[chat.stream_id]) ] = self.chat_person_reply_willing[chat.stream_id].get(
self.ongoing_messages[message.message_info.message_id].person_id, self.chat_reply_willing[chat.stream_id]
)
if chat.stream_id not in self.chat_new_message_time: if chat.stream_id not in self.chat_new_message_time:
self.chat_new_message_time[chat.stream_id] = [] self.chat_new_message_time[chat.stream_id] = []
@@ -166,7 +182,7 @@ class MxpWillingManager(BaseWillingManager):
else: else:
probability = math.atan(willing * 4) / math.pi * 2 probability = math.atan(willing * 4) / math.pi * 2
return probability return probability
async def _chat_new_message_to_change_basic_willing(self): async def _chat_new_message_to_change_basic_willing(self):
"""聊天流新消息改变基础意愿""" """聊天流新消息改变基础意愿"""
while True: while True:
@@ -174,10 +190,11 @@ class MxpWillingManager(BaseWillingManager):
await asyncio.sleep(update_time) await asyncio.sleep(update_time)
async with self.lock: async with self.lock:
for chat_id, message_times in self.chat_new_message_time.items(): for chat_id, message_times in self.chat_new_message_time.items():
# 清理过期消息 # 清理过期消息
current_time = time.time() current_time = time.time()
message_times = [msg_time for msg_time in message_times if current_time - msg_time < self.message_expiration_time] message_times = [
msg_time for msg_time in message_times if current_time - msg_time < self.message_expiration_time
]
self.chat_new_message_time[chat_id] = message_times self.chat_new_message_time[chat_id] = message_times
if len(message_times) < self.number_of_message_storage: if len(message_times) < self.number_of_message_storage:
@@ -185,7 +202,9 @@ class MxpWillingManager(BaseWillingManager):
update_time = 20 update_time = 20
elif len(message_times) == self.number_of_message_storage: elif len(message_times) == self.number_of_message_storage:
time_interval = current_time - message_times[0] time_interval = current_time - message_times[0]
basic_willing = self.basic_maximum_willing * math.sqrt(time_interval / self.message_expiration_time) basic_willing = self.basic_maximum_willing * math.sqrt(
time_interval / self.message_expiration_time
)
self.chat_reply_willing[chat_id] = basic_willing self.chat_reply_willing[chat_id] = basic_willing
update_time = 17 * math.sqrt(time_interval / self.message_expiration_time) + 3 update_time = 17 * math.sqrt(time_interval / self.message_expiration_time) + 3
else: else:
@@ -203,7 +222,7 @@ class MxpWillingManager(BaseWillingManager):
"interest_willing_gain": "兴趣意愿增益", "interest_willing_gain": "兴趣意愿增益",
"emoji_response_penalty": "表情包回复惩罚", "emoji_response_penalty": "表情包回复惩罚",
"down_frequency_rate": "降低回复频率的群组惩罚系数", "down_frequency_rate": "降低回复频率的群组惩罚系数",
"single_chat_gain": "单聊增益(不仅是私聊)" "single_chat_gain": "单聊增益(不仅是私聊)",
} }
async def set_variable_parameters(self, parameters: Dict[str, any]): async def set_variable_parameters(self, parameters: Dict[str, any]):
@@ -215,7 +234,7 @@ class MxpWillingManager(BaseWillingManager):
self.logger.debug(f"参数 {key} 已更新为 {value}") self.logger.debug(f"参数 {key} 已更新为 {value}")
else: else:
self.logger.debug(f"尝试设置未知参数 {key}") self.logger.debug(f"尝试设置未知参数 {key}")
def _get_relationship_level_num(self, relationship_value) -> int: def _get_relationship_level_num(self, relationship_value) -> int:
"""关系等级计算""" """关系等级计算"""
if -1000 <= relationship_value < -227: if -1000 <= relationship_value < -227:
@@ -235,4 +254,4 @@ class MxpWillingManager(BaseWillingManager):
return level_num - 2 return level_num - 2
async def get_willing(self, chat_id): async def get_willing(self, chat_id):
return self.temporary_willing return self.temporary_willing

View File

@@ -1,4 +1,3 @@
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
from dataclasses import dataclass from dataclasses import dataclass
from ..config.config import global_config, BotConfig from ..config.config import global_config, BotConfig
@@ -38,10 +37,11 @@ willing_config = LogConfig(
) )
logger = get_module_logger("willing", config=willing_config) logger = get_module_logger("willing", config=willing_config)
@dataclass @dataclass
class WillingInfo: class WillingInfo:
"""此类保存意愿模块常用的参数 """此类保存意愿模块常用的参数
Attributes: Attributes:
message (MessageRecv): 原始消息对象 message (MessageRecv): 原始消息对象
chat (ChatStream): 聊天流对象 chat (ChatStream): 聊天流对象
@@ -53,6 +53,7 @@ class WillingInfo:
is_emoji (bool): 是否为表情包 is_emoji (bool): 是否为表情包
interested_rate (float): 兴趣度 interested_rate (float): 兴趣度
""" """
message: MessageRecv message: MessageRecv
chat: ChatStream chat: ChatStream
person_info_manager: PersonInfoManager person_info_manager: PersonInfoManager
@@ -60,22 +61,21 @@ class WillingInfo:
person_id: str person_id: str
group_info: Optional[GroupInfo] group_info: Optional[GroupInfo]
is_mentioned_bot: bool is_mentioned_bot: bool
is_emoji: bool is_emoji: bool
interested_rate: float interested_rate: float
# current_mood: float 当前心情? # current_mood: float 当前心情?
class BaseWillingManager(ABC): class BaseWillingManager(ABC):
"""回复意愿管理基类""" """回复意愿管理基类"""
@classmethod @classmethod
def create(cls, manager_type: str) -> 'BaseWillingManager': def create(cls, manager_type: str) -> "BaseWillingManager":
try: try:
module = importlib.import_module(f".mode_{manager_type}", __package__) module = importlib.import_module(f".mode_{manager_type}", __package__)
manager_class = getattr(module, f"{manager_type.capitalize()}WillingManager") manager_class = getattr(module, f"{manager_type.capitalize()}WillingManager")
if not issubclass(manager_class, cls): if not issubclass(manager_class, cls):
raise TypeError( raise TypeError(f"Manager class {manager_class.__name__} is not a subclass of {cls.__name__}")
f"Manager class {manager_class.__name__} is not a subclass of {cls.__name__}"
)
else: else:
logger.info(f"成功载入willing模式{manager_type}") logger.info(f"成功载入willing模式{manager_type}")
return manager_class() return manager_class()
@@ -85,7 +85,7 @@ class BaseWillingManager(ABC):
logger.info(f"载入当前意愿模式{manager_type}失败,使用经典配方~~~~") logger.info(f"载入当前意愿模式{manager_type}失败,使用经典配方~~~~")
logger.debug(f"加载willing模式{manager_type}失败,原因: {str(e)}") logger.debug(f"加载willing模式{manager_type}失败,原因: {str(e)}")
return manager_class() return manager_class()
def __init__(self): def __init__(self):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id) self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id) self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
@@ -136,17 +136,17 @@ class BaseWillingManager(ABC):
async def get_reply_probability(self, message_id: str): async def get_reply_probability(self, message_id: str):
"""抽象方法:获取回复概率""" """抽象方法:获取回复概率"""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def bombing_buffer_message_handle(self, message_id: str): async def bombing_buffer_message_handle(self, message_id: str):
"""抽象方法:炸飞消息处理""" """抽象方法:炸飞消息处理"""
pass pass
async def get_willing(self, chat_id: str): async def get_willing(self, chat_id: str):
"""获取指定聊天流的回复意愿""" """获取指定聊天流的回复意愿"""
async with self.lock: async with self.lock:
return self.chat_reply_willing.get(chat_id, 0) return self.chat_reply_willing.get(chat_id, 0)
async def set_willing(self, chat_id: str, willing: float): async def set_willing(self, chat_id: str, willing: float):
"""设置指定聊天流的回复意愿""" """设置指定聊天流的回复意愿"""
async with self.lock: async with self.lock:
@@ -173,5 +173,6 @@ def init_willing_manager() -> BaseWillingManager:
mode = global_config.willing_mode.lower() mode = global_config.willing_mode.lower()
return BaseWillingManager.create(mode) return BaseWillingManager.create(mode)
# 全局willing_manager对象 # 全局willing_manager对象
willing_manager = init_willing_manager() willing_manager = init_willing_manager()