Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
12
.github/workflows/ruff.yml
vendored
12
.github/workflows/ruff.yml
vendored
@@ -1,5 +1,9 @@
|
|||||||
name: Ruff
|
name: Ruff
|
||||||
on: [ push, pull_request ]
|
on: [ push, pull_request ]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
ruff:
|
ruff:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -8,4 +12,12 @@ jobs:
|
|||||||
- uses: astral-sh/ruff-action@v3
|
- uses: astral-sh/ruff-action@v3
|
||||||
- run: ruff check --fix
|
- run: ruff check --fix
|
||||||
- run: ruff format
|
- run: ruff format
|
||||||
|
- name: Commit changes
|
||||||
|
if: success()
|
||||||
|
run: |
|
||||||
|
git config --local user.email "github-actions[bot]@users.noreply.github.com"
|
||||||
|
git config --local user.name "github-actions[bot]"
|
||||||
|
git add -A
|
||||||
|
git diff --quiet && git diff --staged --quiet || git commit -m "🤖 自动格式化代码 [skip ci]"
|
||||||
|
git push
|
||||||
|
|
||||||
|
|||||||
@@ -283,17 +283,13 @@ WILLING_STYLE_CONFIG = {
|
|||||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
|
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
|
||||||
},
|
},
|
||||||
"simple": {
|
"simple": {
|
||||||
"console_format": (
|
"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}"), # noqa: E501
|
||||||
"<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | {message}"
|
|
||||||
), # noqa: E501
|
|
||||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
|
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
CONFIRM_STYLE_CONFIG = {
|
CONFIRM_STYLE_CONFIG = {
|
||||||
"console_format": (
|
"console_format": ("<RED>{message}</RED>"), # noqa: E501
|
||||||
"<RED>{message}</RED>"
|
|
||||||
), # noqa: E501
|
|
||||||
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"),
|
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | EULA与PRIVACY确认 | {message}"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,17 +4,17 @@ from src.do_tool.tool_can_use.base_tool import (
|
|||||||
discover_tools,
|
discover_tools,
|
||||||
get_all_tool_definitions,
|
get_all_tool_definitions,
|
||||||
get_tool_instance,
|
get_tool_instance,
|
||||||
TOOL_REGISTRY
|
TOOL_REGISTRY,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseTool',
|
"BaseTool",
|
||||||
'register_tool',
|
"register_tool",
|
||||||
'discover_tools',
|
"discover_tools",
|
||||||
'get_all_tool_definitions',
|
"get_all_tool_definitions",
|
||||||
'get_tool_instance',
|
"get_tool_instance",
|
||||||
'TOOL_REGISTRY'
|
"TOOL_REGISTRY",
|
||||||
]
|
]
|
||||||
|
|
||||||
# 自动发现并注册工具
|
# 自动发现并注册工具
|
||||||
discover_tools()
|
discover_tools()
|
||||||
|
|||||||
@@ -10,41 +10,39 @@ logger = get_module_logger("base_tool")
|
|||||||
# 工具注册表
|
# 工具注册表
|
||||||
TOOL_REGISTRY = {}
|
TOOL_REGISTRY = {}
|
||||||
|
|
||||||
|
|
||||||
class BaseTool:
|
class BaseTool:
|
||||||
"""所有工具的基类"""
|
"""所有工具的基类"""
|
||||||
|
|
||||||
# 工具名称,子类必须重写
|
# 工具名称,子类必须重写
|
||||||
name = None
|
name = None
|
||||||
# 工具描述,子类必须重写
|
# 工具描述,子类必须重写
|
||||||
description = None
|
description = None
|
||||||
# 工具参数定义,子类必须重写
|
# 工具参数定义,子类必须重写
|
||||||
parameters = None
|
parameters = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_definition(cls) -> Dict[str, Any]:
|
def get_tool_definition(cls) -> Dict[str, Any]:
|
||||||
"""获取工具定义,用于LLM工具调用
|
"""获取工具定义,用于LLM工具调用
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 工具定义字典
|
Dict: 工具定义字典
|
||||||
"""
|
"""
|
||||||
if not cls.name or not cls.description or not cls.parameters:
|
if not cls.name or not cls.description or not cls.parameters:
|
||||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||||
"name": cls.name,
|
|
||||||
"description": cls.description,
|
|
||||||
"parameters": cls.parameters
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||||
"""执行工具函数
|
"""执行工具函数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
function_args: 工具调用参数
|
function_args: 工具调用参数
|
||||||
message_txt: 原始消息文本
|
message_txt: 原始消息文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 工具执行结果
|
Dict: 工具执行结果
|
||||||
"""
|
"""
|
||||||
@@ -53,17 +51,17 @@ class BaseTool:
|
|||||||
|
|
||||||
def register_tool(tool_class: Type[BaseTool]):
|
def register_tool(tool_class: Type[BaseTool]):
|
||||||
"""注册工具到全局注册表
|
"""注册工具到全局注册表
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_class: 工具类
|
tool_class: 工具类
|
||||||
"""
|
"""
|
||||||
if not issubclass(tool_class, BaseTool):
|
if not issubclass(tool_class, BaseTool):
|
||||||
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
|
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
|
||||||
|
|
||||||
tool_name = tool_class.name
|
tool_name = tool_class.name
|
||||||
if not tool_name:
|
if not tool_name:
|
||||||
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
|
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
|
||||||
|
|
||||||
TOOL_REGISTRY[tool_name] = tool_class
|
TOOL_REGISTRY[tool_name] = tool_class
|
||||||
logger.info(f"已注册工具: {tool_name}")
|
logger.info(f"已注册工具: {tool_name}")
|
||||||
|
|
||||||
@@ -73,27 +71,27 @@ def discover_tools():
|
|||||||
# 获取当前目录路径
|
# 获取当前目录路径
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
package_name = os.path.basename(current_dir)
|
package_name = os.path.basename(current_dir)
|
||||||
|
|
||||||
# 遍历包中的所有模块
|
# 遍历包中的所有模块
|
||||||
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
|
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
|
||||||
# 跳过当前模块和__pycache__
|
# 跳过当前模块和__pycache__
|
||||||
if module_name == "base_tool" or module_name.startswith("__"):
|
if module_name == "base_tool" or module_name.startswith("__"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 导入模块
|
# 导入模块
|
||||||
module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}")
|
module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}")
|
||||||
|
|
||||||
# 查找模块中的工具类
|
# 查找模块中的工具类
|
||||||
for _, obj in inspect.getmembers(module):
|
for _, obj in inspect.getmembers(module):
|
||||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||||
register_tool(obj)
|
register_tool(obj)
|
||||||
|
|
||||||
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
|
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
|
||||||
|
|
||||||
|
|
||||||
def get_all_tool_definitions() -> List[Dict[str, Any]]:
|
def get_all_tool_definitions() -> List[Dict[str, Any]]:
|
||||||
"""获取所有已注册工具的定义
|
"""获取所有已注册工具的定义
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict]: 工具定义列表
|
List[Dict]: 工具定义列表
|
||||||
"""
|
"""
|
||||||
@@ -102,14 +100,14 @@ def get_all_tool_definitions() -> List[Dict[str, Any]]:
|
|||||||
|
|
||||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||||
"""获取指定名称的工具实例
|
"""获取指定名称的工具实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name: 工具名称
|
tool_name: 工具名称
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[BaseTool]: 工具实例,如果找不到则返回None
|
Optional[BaseTool]: 工具实例,如果找不到则返回None
|
||||||
"""
|
"""
|
||||||
tool_class = TOOL_REGISTRY.get(tool_name)
|
tool_class = TOOL_REGISTRY.get(tool_name)
|
||||||
if not tool_class:
|
if not tool_class:
|
||||||
return None
|
return None
|
||||||
return tool_class()
|
return tool_class()
|
||||||
|
|||||||
@@ -4,29 +4,25 @@ from typing import Dict, Any
|
|||||||
|
|
||||||
logger = get_module_logger("fibonacci_sequence_tool")
|
logger = get_module_logger("fibonacci_sequence_tool")
|
||||||
|
|
||||||
|
|
||||||
class FibonacciSequenceTool(BaseTool):
|
class FibonacciSequenceTool(BaseTool):
|
||||||
"""生成斐波那契数列的工具"""
|
"""生成斐波那契数列的工具"""
|
||||||
|
|
||||||
name = "fibonacci_sequence"
|
name = "fibonacci_sequence"
|
||||||
description = "生成指定长度的斐波那契数列"
|
description = "生成指定长度的斐波那契数列"
|
||||||
parameters = {
|
parameters = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {"n": {"type": "integer", "description": "斐波那契数列的长度", "minimum": 1}},
|
||||||
"n": {
|
"required": ["n"],
|
||||||
"type": "integer",
|
|
||||||
"description": "斐波那契数列的长度",
|
|
||||||
"minimum": 1
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["n"]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||||
"""执行工具功能
|
"""执行工具功能
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
function_args: 工具参数
|
function_args: 工具参数
|
||||||
message_txt: 原始消息文本
|
message_txt: 原始消息文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 工具执行结果
|
Dict: 工具执行结果
|
||||||
"""
|
"""
|
||||||
@@ -34,23 +30,18 @@ class FibonacciSequenceTool(BaseTool):
|
|||||||
n = function_args.get("n")
|
n = function_args.get("n")
|
||||||
if n <= 0:
|
if n <= 0:
|
||||||
raise ValueError("参数n必须大于0")
|
raise ValueError("参数n必须大于0")
|
||||||
|
|
||||||
sequence = []
|
sequence = []
|
||||||
a, b = 0, 1
|
a, b = 0, 1
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
sequence.append(a)
|
sequence.append(a)
|
||||||
a, b = b, a + b
|
a, b = b, a + b
|
||||||
|
|
||||||
return {
|
return {"name": self.name, "content": sequence}
|
||||||
"name": self.name,
|
|
||||||
"content": sequence
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"fibonacci_sequence工具执行失败: {str(e)}")
|
logger.error(f"fibonacci_sequence工具执行失败: {str(e)}")
|
||||||
return {
|
return {"name": self.name, "content": f"执行失败: {str(e)}"}
|
||||||
"name": self.name,
|
|
||||||
"content": f"执行失败: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 注册工具
|
# 注册工具
|
||||||
register_tool(FibonacciSequenceTool)
|
register_tool(FibonacciSequenceTool)
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from typing import Dict, Any
|
|||||||
|
|
||||||
logger = get_module_logger("generate_buddha_emoji_tool")
|
logger = get_module_logger("generate_buddha_emoji_tool")
|
||||||
|
|
||||||
|
|
||||||
class GenerateBuddhaEmojiTool(BaseTool):
|
class GenerateBuddhaEmojiTool(BaseTool):
|
||||||
"""生成佛祖颜文字的工具类"""
|
"""生成佛祖颜文字的工具类"""
|
||||||
|
|
||||||
name = "generate_buddha_emoji"
|
name = "generate_buddha_emoji"
|
||||||
description = "生成一个佛祖的颜文字表情"
|
description = "生成一个佛祖的颜文字表情"
|
||||||
parameters = {
|
parameters = {
|
||||||
@@ -13,32 +15,27 @@ class GenerateBuddhaEmojiTool(BaseTool):
|
|||||||
"properties": {
|
"properties": {
|
||||||
# 无参数
|
# 无参数
|
||||||
},
|
},
|
||||||
"required": []
|
"required": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||||
"""执行工具功能,生成佛祖颜文字
|
"""执行工具功能,生成佛祖颜文字
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
function_args: 工具参数
|
function_args: 工具参数
|
||||||
message_txt: 原始消息文本
|
message_txt: 原始消息文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 工具执行结果
|
Dict: 工具执行结果
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
buddha_emoji = "这是一个佛祖emoji:༼ つ ◕_◕ ༽つ"
|
buddha_emoji = "这是一个佛祖emoji:༼ つ ◕_◕ ༽つ"
|
||||||
|
|
||||||
return {
|
return {"name": self.name, "content": buddha_emoji}
|
||||||
"name": self.name,
|
|
||||||
"content": buddha_emoji
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"generate_buddha_emoji工具执行失败: {str(e)}")
|
logger.error(f"generate_buddha_emoji工具执行失败: {str(e)}")
|
||||||
return {
|
return {"name": self.name, "content": f"执行失败: {str(e)}"}
|
||||||
"name": self.name,
|
|
||||||
"content": f"执行失败: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 注册工具
|
# 注册工具
|
||||||
register_tool(GenerateBuddhaEmojiTool)
|
register_tool(GenerateBuddhaEmojiTool)
|
||||||
|
|||||||
@@ -4,23 +4,21 @@ from typing import Dict, Any
|
|||||||
|
|
||||||
logger = get_module_logger("generate_cmd_tutorial_tool")
|
logger = get_module_logger("generate_cmd_tutorial_tool")
|
||||||
|
|
||||||
|
|
||||||
class GenerateCmdTutorialTool(BaseTool):
|
class GenerateCmdTutorialTool(BaseTool):
|
||||||
"""生成Windows CMD基本操作教程的工具"""
|
"""生成Windows CMD基本操作教程的工具"""
|
||||||
|
|
||||||
name = "generate_cmd_tutorial"
|
name = "generate_cmd_tutorial"
|
||||||
description = "生成关于Windows命令提示符(CMD)的基本操作教程,包括常用命令和使用方法"
|
description = "生成关于Windows命令提示符(CMD)的基本操作教程,包括常用命令和使用方法"
|
||||||
parameters = {
|
parameters = {"type": "object", "properties": {}, "required": []}
|
||||||
"type": "object",
|
|
||||||
"properties": {},
|
|
||||||
"required": []
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||||
"""执行工具功能
|
"""执行工具功能
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
function_args: 工具参数
|
function_args: 工具参数
|
||||||
message_txt: 原始消息文本
|
message_txt: 原始消息文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 工具执行结果
|
Dict: 工具执行结果
|
||||||
"""
|
"""
|
||||||
@@ -57,17 +55,12 @@ class GenerateCmdTutorialTool(BaseTool):
|
|||||||
|
|
||||||
注意:使用命令时要小心,特别是删除操作。
|
注意:使用命令时要小心,特别是删除操作。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return {
|
return {"name": self.name, "content": tutorial_content}
|
||||||
"name": self.name,
|
|
||||||
"content": tutorial_content
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"generate_cmd_tutorial工具执行失败: {str(e)}")
|
logger.error(f"generate_cmd_tutorial工具执行失败: {str(e)}")
|
||||||
return {
|
return {"name": self.name, "content": f"执行失败: {str(e)}"}
|
||||||
"name": self.name,
|
|
||||||
"content": f"执行失败: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 注册工具
|
# 注册工具
|
||||||
register_tool(GenerateCmdTutorialTool)
|
register_tool(GenerateCmdTutorialTool)
|
||||||
|
|||||||
@@ -5,32 +5,28 @@ from typing import Dict, Any
|
|||||||
|
|
||||||
logger = get_module_logger("get_current_task_tool")
|
logger = get_module_logger("get_current_task_tool")
|
||||||
|
|
||||||
|
|
||||||
class GetCurrentTaskTool(BaseTool):
|
class GetCurrentTaskTool(BaseTool):
|
||||||
"""获取当前正在做的事情/最近的任务工具"""
|
"""获取当前正在做的事情/最近的任务工具"""
|
||||||
|
|
||||||
name = "get_current_task"
|
name = "get_current_task"
|
||||||
description = "获取当前正在做的事情/最近的任务"
|
description = "获取当前正在做的事情/最近的任务"
|
||||||
parameters = {
|
parameters = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"num": {
|
"num": {"type": "integer", "description": "要获取的任务数量"},
|
||||||
"type": "integer",
|
"time_info": {"type": "boolean", "description": "是否包含时间信息"},
|
||||||
"description": "要获取的任务数量"
|
|
||||||
},
|
|
||||||
"time_info": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "是否包含时间信息"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": []
|
"required": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||||
"""执行获取当前任务
|
"""执行获取当前任务
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
function_args: 工具参数
|
function_args: 工具参数
|
||||||
message_txt: 原始消息文本,此工具不使用
|
message_txt: 原始消息文本,此工具不使用
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 工具执行结果
|
Dict: 工具执行结果
|
||||||
"""
|
"""
|
||||||
@@ -38,26 +34,21 @@ class GetCurrentTaskTool(BaseTool):
|
|||||||
# 获取参数,如果没有提供则使用默认值
|
# 获取参数,如果没有提供则使用默认值
|
||||||
num = function_args.get("num", 1)
|
num = function_args.get("num", 1)
|
||||||
time_info = function_args.get("time_info", False)
|
time_info = function_args.get("time_info", False)
|
||||||
|
|
||||||
# 调用日程系统获取当前任务
|
# 调用日程系统获取当前任务
|
||||||
current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info)
|
current_task = bot_schedule.get_current_num_task(num=num, time_info=time_info)
|
||||||
|
|
||||||
# 格式化返回结果
|
# 格式化返回结果
|
||||||
if current_task:
|
if current_task:
|
||||||
task_info = current_task
|
task_info = current_task
|
||||||
else:
|
else:
|
||||||
task_info = "当前没有正在进行的任务"
|
task_info = "当前没有正在进行的任务"
|
||||||
|
|
||||||
return {
|
return {"name": "get_current_task", "content": f"当前任务信息: {task_info}"}
|
||||||
"name": "get_current_task",
|
|
||||||
"content": f"当前任务信息: {task_info}"
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取当前任务工具执行失败: {str(e)}")
|
logger.error(f"获取当前任务工具执行失败: {str(e)}")
|
||||||
return {
|
return {"name": "get_current_task", "content": f"获取当前任务失败: {str(e)}"}
|
||||||
"name": "get_current_task",
|
|
||||||
"content": f"获取当前任务失败: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 注册工具
|
# 注册工具
|
||||||
register_tool(GetCurrentTaskTool)
|
register_tool(GetCurrentTaskTool)
|
||||||
|
|||||||
@@ -6,39 +6,35 @@ from typing import Dict, Any, Union
|
|||||||
|
|
||||||
logger = get_module_logger("get_knowledge_tool")
|
logger = get_module_logger("get_knowledge_tool")
|
||||||
|
|
||||||
|
|
||||||
class SearchKnowledgeTool(BaseTool):
|
class SearchKnowledgeTool(BaseTool):
|
||||||
"""从知识库中搜索相关信息的工具"""
|
"""从知识库中搜索相关信息的工具"""
|
||||||
|
|
||||||
name = "search_knowledge"
|
name = "search_knowledge"
|
||||||
description = "从知识库中搜索相关信息"
|
description = "从知识库中搜索相关信息"
|
||||||
parameters = {
|
parameters = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"query": {
|
"query": {"type": "string", "description": "搜索查询关键词"},
|
||||||
"type": "string",
|
"threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
|
||||||
"description": "搜索查询关键词"
|
|
||||||
},
|
|
||||||
"threshold": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "相似度阈值,0.0到1.0之间"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["query"]
|
"required": ["query"],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||||
"""执行知识库搜索
|
"""执行知识库搜索
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
function_args: 工具参数
|
function_args: 工具参数
|
||||||
message_txt: 原始消息文本
|
message_txt: 原始消息文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 工具执行结果
|
Dict: 工具执行结果
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
query = function_args.get("query", message_txt)
|
query = function_args.get("query", message_txt)
|
||||||
threshold = function_args.get("threshold", 0.4)
|
threshold = function_args.get("threshold", 0.4)
|
||||||
|
|
||||||
# 调用知识库搜索
|
# 调用知识库搜索
|
||||||
embedding = await get_embedding(query, request_type="info_retrieval")
|
embedding = await get_embedding(query, request_type="info_retrieval")
|
||||||
if embedding:
|
if embedding:
|
||||||
@@ -47,38 +43,29 @@ class SearchKnowledgeTool(BaseTool):
|
|||||||
content = f"你知道这些知识: {knowledge_info}"
|
content = f"你知道这些知识: {knowledge_info}"
|
||||||
else:
|
else:
|
||||||
content = f"你不太了解有关{query}的知识"
|
content = f"你不太了解有关{query}的知识"
|
||||||
return {
|
return {"name": "search_knowledge", "content": content}
|
||||||
"name": "search_knowledge",
|
return {"name": "search_knowledge", "content": f"无法获取关于'{query}'的嵌入向量"}
|
||||||
"content": content
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"name": "search_knowledge",
|
|
||||||
"content": f"无法获取关于'{query}'的嵌入向量"
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||||
return {
|
return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"}
|
||||||
"name": "search_knowledge",
|
|
||||||
"content": f"知识库搜索失败: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_info_from_db(
|
def get_info_from_db(
|
||||||
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||||
) -> Union[str, list]:
|
) -> Union[str, list]:
|
||||||
"""从数据库中获取相关信息
|
"""从数据库中获取相关信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_embedding: 查询的嵌入向量
|
query_embedding: 查询的嵌入向量
|
||||||
limit: 最大返回结果数
|
limit: 最大返回结果数
|
||||||
threshold: 相似度阈值
|
threshold: 相似度阈值
|
||||||
return_raw: 是否返回原始结果
|
return_raw: 是否返回原始结果
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Union[str, list]: 格式化的信息字符串或原始结果列表
|
Union[str, list]: 格式化的信息字符串或原始结果列表
|
||||||
"""
|
"""
|
||||||
if not query_embedding:
|
if not query_embedding:
|
||||||
return "" if not return_raw else []
|
return "" if not return_raw else []
|
||||||
|
|
||||||
# 使用余弦相似度计算
|
# 使用余弦相似度计算
|
||||||
pipeline = [
|
pipeline = [
|
||||||
{
|
{
|
||||||
@@ -143,5 +130,6 @@ class SearchKnowledgeTool(BaseTool):
|
|||||||
# 返回所有找到的内容,用换行分隔
|
# 返回所有找到的内容,用换行分隔
|
||||||
return "\n".join(str(result["content"]) for result in results)
|
return "\n".join(str(result["content"]) for result in results)
|
||||||
|
|
||||||
|
|
||||||
# 注册工具
|
# 注册工具
|
||||||
register_tool(SearchKnowledgeTool)
|
register_tool(SearchKnowledgeTool)
|
||||||
|
|||||||
@@ -5,68 +5,55 @@ from typing import Dict, Any
|
|||||||
|
|
||||||
logger = get_module_logger("get_memory_tool")
|
logger = get_module_logger("get_memory_tool")
|
||||||
|
|
||||||
|
|
||||||
class GetMemoryTool(BaseTool):
|
class GetMemoryTool(BaseTool):
|
||||||
"""从记忆系统中获取相关记忆的工具"""
|
"""从记忆系统中获取相关记忆的工具"""
|
||||||
|
|
||||||
name = "get_memory"
|
name = "get_memory"
|
||||||
description = "从记忆系统中获取相关记忆"
|
description = "从记忆系统中获取相关记忆"
|
||||||
parameters = {
|
parameters = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"text": {
|
"text": {"type": "string", "description": "要查询的相关文本"},
|
||||||
"type": "string",
|
"max_memory_num": {"type": "integer", "description": "最大返回记忆数量"},
|
||||||
"description": "要查询的相关文本"
|
|
||||||
},
|
|
||||||
"max_memory_num": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "最大返回记忆数量"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["text"]
|
"required": ["text"],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||||
"""执行记忆获取
|
"""执行记忆获取
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
function_args: 工具参数
|
function_args: 工具参数
|
||||||
message_txt: 原始消息文本
|
message_txt: 原始消息文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 工具执行结果
|
Dict: 工具执行结果
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
text = function_args.get("text", message_txt)
|
text = function_args.get("text", message_txt)
|
||||||
max_memory_num = function_args.get("max_memory_num", 2)
|
max_memory_num = function_args.get("max_memory_num", 2)
|
||||||
|
|
||||||
# 调用记忆系统
|
# 调用记忆系统
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
text=text,
|
text=text, max_memory_num=max_memory_num, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
max_memory_num=max_memory_num,
|
|
||||||
max_memory_length=2,
|
|
||||||
max_depth=3,
|
|
||||||
fast_retrieval=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_info = ""
|
memory_info = ""
|
||||||
if related_memory:
|
if related_memory:
|
||||||
for memory in related_memory:
|
for memory in related_memory:
|
||||||
memory_info += memory[1] + "\n"
|
memory_info += memory[1] + "\n"
|
||||||
|
|
||||||
if memory_info:
|
if memory_info:
|
||||||
content = f"你记得这些事情: {memory_info}"
|
content = f"你记得这些事情: {memory_info}"
|
||||||
else:
|
else:
|
||||||
content = f"你不太记得有关{text}的记忆,你对此不太了解"
|
content = f"你不太记得有关{text}的记忆,你对此不太了解"
|
||||||
|
|
||||||
return {
|
return {"name": "get_memory", "content": content}
|
||||||
"name": "get_memory",
|
|
||||||
"content": content
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记忆获取工具执行失败: {str(e)}")
|
logger.error(f"记忆获取工具执行失败: {str(e)}")
|
||||||
return {
|
return {"name": "get_memory", "content": f"记忆获取失败: {str(e)}"}
|
||||||
"name": "get_memory",
|
|
||||||
"content": f"记忆获取失败: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 注册工具
|
# 注册工具
|
||||||
register_tool(GetMemoryTool)
|
register_tool(GetMemoryTool)
|
||||||
|
|||||||
@@ -16,21 +16,19 @@ class ToolUser:
|
|||||||
model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use"
|
model=global_config.llm_heartflow, temperature=0.2, max_tokens=1000, request_type="tool_use"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _build_tool_prompt(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
|
async def _build_tool_prompt(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
|
||||||
"""构建工具使用的提示词
|
"""构建工具使用的提示词
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_txt: 用户消息文本
|
message_txt: 用户消息文本
|
||||||
sender_name: 发送者名称
|
sender_name: 发送者名称
|
||||||
chat_stream: 聊天流对象
|
chat_stream: 聊天流对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 构建好的提示词
|
str: 构建好的提示词
|
||||||
"""
|
"""
|
||||||
new_messages = list(
|
new_messages = list(
|
||||||
db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}})
|
db.messages.find({"chat_id": chat_stream.stream_id, "time": {"$gt": time.time()}}).sort("time", 1).limit(15)
|
||||||
.sort("time", 1)
|
|
||||||
.limit(15)
|
|
||||||
)
|
)
|
||||||
new_messages_str = ""
|
new_messages_str = ""
|
||||||
for msg in new_messages:
|
for msg in new_messages:
|
||||||
@@ -43,38 +41,38 @@ class ToolUser:
|
|||||||
prompt += "你正在思考如何回复群里的消息。\n"
|
prompt += "你正在思考如何回复群里的消息。\n"
|
||||||
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
|
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
|
||||||
prompt += f"注意你就是{bot_name},{bot_name}指的就是你。"
|
prompt += f"注意你就是{bot_name},{bot_name}指的就是你。"
|
||||||
|
|
||||||
prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,不要使用危险功能(比如文件操作或者系统操作爬虫),比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。"
|
prompt += "你现在需要对群里的聊天内容进行回复,现在请你思考,你是否需要额外的信息,或者一些工具来帮你回复,不要使用危险功能(比如文件操作或者系统操作爬虫),比如回忆或者搜寻已有的知识,或者了解你现在正在做什么,请输出你需要的工具,或者你需要的额外信息。"
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def _define_tools(self):
|
def _define_tools(self):
|
||||||
"""获取所有已注册工具的定义
|
"""获取所有已注册工具的定义
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: 工具定义列表
|
list: 工具定义列表
|
||||||
"""
|
"""
|
||||||
return get_all_tool_definitions()
|
return get_all_tool_definitions()
|
||||||
|
|
||||||
async def _execute_tool_call(self, tool_call, message_txt:str):
|
async def _execute_tool_call(self, tool_call, message_txt: str):
|
||||||
"""执行特定的工具调用
|
"""执行特定的工具调用
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_call: 工具调用对象
|
tool_call: 工具调用对象
|
||||||
message_txt: 原始消息文本
|
message_txt: 原始消息文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: 工具调用结果
|
dict: 工具调用结果
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
function_name = tool_call["function"]["name"]
|
function_name = tool_call["function"]["name"]
|
||||||
function_args = json.loads(tool_call["function"]["arguments"])
|
function_args = json.loads(tool_call["function"]["arguments"])
|
||||||
|
|
||||||
# 获取对应工具实例
|
# 获取对应工具实例
|
||||||
tool_instance = get_tool_instance(function_name)
|
tool_instance = get_tool_instance(function_name)
|
||||||
if not tool_instance:
|
if not tool_instance:
|
||||||
logger.warning(f"未知工具名称: {function_name}")
|
logger.warning(f"未知工具名称: {function_name}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 执行工具
|
# 执行工具
|
||||||
result = await tool_instance.execute(function_args, message_txt)
|
result = await tool_instance.execute(function_args, message_txt)
|
||||||
if result:
|
if result:
|
||||||
@@ -82,31 +80,30 @@ class ToolUser:
|
|||||||
"tool_call_id": tool_call["id"],
|
"tool_call_id": tool_call["id"],
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"name": function_name,
|
"name": function_name,
|
||||||
"content": result["content"]
|
"content": result["content"],
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def use_tool(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
|
async def use_tool(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
|
||||||
"""使用工具辅助思考,判断是否需要额外信息
|
"""使用工具辅助思考,判断是否需要额外信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_txt: 用户消息文本
|
message_txt: 用户消息文本
|
||||||
sender_name: 发送者名称
|
sender_name: 发送者名称
|
||||||
chat_stream: 聊天流对象
|
chat_stream: 聊天流对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: 工具使用结果
|
dict: 工具使用结果
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 构建提示词
|
# 构建提示词
|
||||||
prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream)
|
prompt = await self._build_tool_prompt(message_txt, sender_name, chat_stream)
|
||||||
|
|
||||||
# 定义可用工具
|
# 定义可用工具
|
||||||
tools = self._define_tools()
|
tools = self._define_tools()
|
||||||
print(tools)
|
|
||||||
|
|
||||||
# 使用llm_model_tool发送带工具定义的请求
|
# 使用llm_model_tool发送带工具定义的请求
|
||||||
payload = {
|
payload = {
|
||||||
@@ -114,31 +111,29 @@ class ToolUser:
|
|||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
"max_tokens": global_config.max_response_length,
|
"max_tokens": global_config.max_response_length,
|
||||||
"tools": tools,
|
"tools": tools,
|
||||||
"temperature": 0.2
|
"temperature": 0.2,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}")
|
logger.debug(f"发送工具调用请求,模型: {self.llm_model_tool.model_name}")
|
||||||
# 发送请求获取模型是否需要调用工具
|
# 发送请求获取模型是否需要调用工具
|
||||||
response = await self.llm_model_tool._execute_request(
|
response = await self.llm_model_tool._execute_request(
|
||||||
endpoint="/chat/completions",
|
endpoint="/chat/completions", payload=payload, prompt=prompt
|
||||||
payload=payload,
|
|
||||||
prompt=prompt
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据返回值数量判断是否有工具调用
|
# 根据返回值数量判断是否有工具调用
|
||||||
if len(response) == 3:
|
if len(response) == 3:
|
||||||
content, reasoning_content, tool_calls = response
|
content, reasoning_content, tool_calls = response
|
||||||
logger.info(f"工具思考: {tool_calls}")
|
logger.info(f"工具思考: {tool_calls}")
|
||||||
|
|
||||||
# 检查响应中工具调用是否有效
|
# 检查响应中工具调用是否有效
|
||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
logger.info("模型返回了空的tool_calls列表")
|
logger.info("模型返回了空的tool_calls列表")
|
||||||
return {"used_tools": False}
|
return {"used_tools": False}
|
||||||
|
|
||||||
logger.info(f"模型请求调用{len(tool_calls)}个工具")
|
logger.info(f"模型请求调用{len(tool_calls)}个工具")
|
||||||
tool_results = []
|
tool_results = []
|
||||||
collected_info = ""
|
collected_info = ""
|
||||||
|
|
||||||
# 执行所有工具调用
|
# 执行所有工具调用
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
result = await self._execute_tool_call(tool_call, message_txt)
|
result = await self._execute_tool_call(tool_call, message_txt)
|
||||||
@@ -146,7 +141,7 @@ class ToolUser:
|
|||||||
tool_results.append(result)
|
tool_results.append(result)
|
||||||
# 将工具结果添加到收集的信息中
|
# 将工具结果添加到收集的信息中
|
||||||
collected_info += f"\n{result['name']}返回结果: {result['content']}\n"
|
collected_info += f"\n{result['name']}返回结果: {result['content']}\n"
|
||||||
|
|
||||||
# 如果有工具结果,直接返回收集的信息
|
# 如果有工具结果,直接返回收集的信息
|
||||||
if collected_info:
|
if collected_info:
|
||||||
logger.info(f"工具调用收集到信息: {collected_info}")
|
logger.info(f"工具调用收集到信息: {collected_info}")
|
||||||
@@ -158,15 +153,15 @@ class ToolUser:
|
|||||||
# 没有工具调用
|
# 没有工具调用
|
||||||
content, reasoning_content = response
|
content, reasoning_content = response
|
||||||
logger.info("模型没有请求调用任何工具")
|
logger.info("模型没有请求调用任何工具")
|
||||||
|
|
||||||
# 如果没有工具调用或处理失败,直接返回原始思考
|
# 如果没有工具调用或处理失败,直接返回原始思考
|
||||||
return {
|
return {
|
||||||
"used_tools": False,
|
"used_tools": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"工具调用过程中出错: {str(e)}")
|
logger.error(f"工具调用过程中出错: {str(e)}")
|
||||||
return {
|
return {
|
||||||
"used_tools": False,
|
"used_tools": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,12 +43,11 @@ def init_prompt():
|
|||||||
|
|
||||||
class CurrentState:
|
class CurrentState:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
||||||
self.current_state_info = ""
|
self.current_state_info = ""
|
||||||
|
|
||||||
self.mood_manager = MoodManager()
|
self.mood_manager = MoodManager()
|
||||||
self.mood = self.mood_manager.get_prompt()
|
self.mood = self.mood_manager.get_prompt()
|
||||||
|
|
||||||
self.attendance_factor = 0
|
self.attendance_factor = 0
|
||||||
self.engagement_factor = 0
|
self.engagement_factor = 0
|
||||||
|
|
||||||
@@ -66,9 +65,6 @@ class Heartflow:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._subheartflows: Dict[Any, SubHeartflow] = {}
|
self._subheartflows: Dict[Any, SubHeartflow] = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _cleanup_inactive_subheartflows(self):
|
async def _cleanup_inactive_subheartflows(self):
|
||||||
"""定期清理不活跃的子心流"""
|
"""定期清理不活跃的子心流"""
|
||||||
@@ -90,7 +86,7 @@ class Heartflow:
|
|||||||
logger.info(f"已清理不活跃的子心流: {subheartflow_id}")
|
logger.info(f"已清理不活跃的子心流: {subheartflow_id}")
|
||||||
|
|
||||||
await asyncio.sleep(30) # 每分钟检查一次
|
await asyncio.sleep(30) # 每分钟检查一次
|
||||||
|
|
||||||
async def _sub_heartflow_update(self):
|
async def _sub_heartflow_update(self):
|
||||||
while True:
|
while True:
|
||||||
# 检查是否存在子心流
|
# 检查是否存在子心流
|
||||||
@@ -103,13 +99,12 @@ class Heartflow:
|
|||||||
await asyncio.sleep(global_config.heart_flow_update_interval) # 5分钟思考一次
|
await asyncio.sleep(global_config.heart_flow_update_interval) # 5分钟思考一次
|
||||||
|
|
||||||
async def heartflow_start_working(self):
|
async def heartflow_start_working(self):
|
||||||
|
|
||||||
# 启动清理任务
|
# 启动清理任务
|
||||||
asyncio.create_task(self._cleanup_inactive_subheartflows())
|
asyncio.create_task(self._cleanup_inactive_subheartflows())
|
||||||
|
|
||||||
# 启动子心流更新任务
|
# 启动子心流更新任务
|
||||||
asyncio.create_task(self._sub_heartflow_update())
|
asyncio.create_task(self._sub_heartflow_update())
|
||||||
|
|
||||||
async def _update_current_state(self):
|
async def _update_current_state(self):
|
||||||
print("TODO")
|
print("TODO")
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ class ChattingObservation(Observation):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取总结失败: {e}")
|
print(f"获取总结失败: {e}")
|
||||||
updated_observe_info = ""
|
updated_observe_info = ""
|
||||||
|
|
||||||
return updated_observe_info
|
return updated_observe_info
|
||||||
# print(f"prompt:{prompt}")
|
# print(f"prompt:{prompt}")
|
||||||
# print(f"self.observe_info:{self.observe_info}")
|
# print(f"self.observe_info:{self.observe_info}")
|
||||||
|
|||||||
@@ -5,9 +5,11 @@ from src.plugins.models.utils_model import LLM_request
|
|||||||
from src.plugins.config.config import global_config
|
from src.plugins.config.config import global_config
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
||||||
# from src.plugins.schedule.schedule_generator import bot_schedule
|
# from src.plugins.schedule.schedule_generator import bot_schedule
|
||||||
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||||
from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402
|
from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402
|
||||||
|
|
||||||
# from src.plugins.chat.utils import get_embedding
|
# from src.plugins.chat.utils import get_embedding
|
||||||
# from src.common.database import db
|
# from src.common.database import db
|
||||||
# from typing import Union
|
# from typing import Union
|
||||||
@@ -17,7 +19,7 @@ from src.plugins.chat.chat_stream import ChatStream
|
|||||||
from src.plugins.person_info.relationship_manager import relationship_manager
|
from src.plugins.person_info.relationship_manager import relationship_manager
|
||||||
from src.plugins.chat.utils import get_recent_group_speaker
|
from src.plugins.chat.utils import get_recent_group_speaker
|
||||||
from src.do_tool.tool_use import ToolUser
|
from src.do_tool.tool_use import ToolUser
|
||||||
from ..plugins.utils.prompt_builder import Prompt,global_prompt_manager
|
from ..plugins.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
|
|
||||||
subheartflow_config = LogConfig(
|
subheartflow_config = LogConfig(
|
||||||
# 使用海马体专用样式
|
# 使用海马体专用样式
|
||||||
@@ -26,6 +28,7 @@ subheartflow_config = LogConfig(
|
|||||||
)
|
)
|
||||||
logger = get_module_logger("subheartflow", config=subheartflow_config)
|
logger = get_module_logger("subheartflow", config=subheartflow_config)
|
||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
prompt = ""
|
prompt = ""
|
||||||
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
|
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
|
||||||
@@ -41,7 +44,7 @@ def init_prompt():
|
|||||||
prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n"
|
prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,尽量不要说你说过的话\n"
|
||||||
prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写"
|
prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写"
|
||||||
prompt += "记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{bot_name},{bot_name}指的就是你。"
|
prompt += "记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{bot_name},{bot_name}指的就是你。"
|
||||||
Prompt(prompt,"sub_heartflow_prompt_before")
|
Prompt(prompt, "sub_heartflow_prompt_before")
|
||||||
prompt = ""
|
prompt = ""
|
||||||
# prompt += f"你现在正在做的事情是:{schedule_info}\n"
|
# prompt += f"你现在正在做的事情是:{schedule_info}\n"
|
||||||
prompt += "{prompt_personality}\n"
|
prompt += "{prompt_personality}\n"
|
||||||
@@ -52,8 +55,7 @@ def init_prompt():
|
|||||||
prompt += "你现在{mood_info}"
|
prompt += "你现在{mood_info}"
|
||||||
prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
|
prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
|
||||||
prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
|
prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
|
||||||
Prompt(prompt,'sub_heartflow_prompt_after')
|
Prompt(prompt, "sub_heartflow_prompt_after")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CurrentState:
|
class CurrentState:
|
||||||
@@ -78,7 +80,6 @@ class SubHeartflow:
|
|||||||
self.llm_model = LLM_request(
|
self.llm_model = LLM_request(
|
||||||
model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow"
|
model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
self.main_heartflow_info = ""
|
self.main_heartflow_info = ""
|
||||||
|
|
||||||
@@ -93,9 +94,9 @@ class SubHeartflow:
|
|||||||
self.observations: list[Observation] = []
|
self.observations: list[Observation] = []
|
||||||
|
|
||||||
self.running_knowledges = []
|
self.running_knowledges = []
|
||||||
|
|
||||||
self.bot_name = global_config.BOT_NICKNAME
|
self.bot_name = global_config.BOT_NICKNAME
|
||||||
|
|
||||||
self.tool_user = ToolUser()
|
self.tool_user = ToolUser()
|
||||||
|
|
||||||
def add_observation(self, observation: Observation):
|
def add_observation(self, observation: Observation):
|
||||||
@@ -145,12 +146,12 @@ class SubHeartflow:
|
|||||||
): # 5分钟无回复/不在场,销毁
|
): # 5分钟无回复/不在场,销毁
|
||||||
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活,正在销毁...")
|
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活,正在销毁...")
|
||||||
break # 退出循环,销毁自己
|
break # 退出循环,销毁自己
|
||||||
|
|
||||||
async def do_observe(self):
|
async def do_observe(self):
|
||||||
observation = self.observations[0]
|
observation = self.observations[0]
|
||||||
await observation.observe()
|
await observation.observe()
|
||||||
|
|
||||||
|
async def do_thinking_before_reply(self, message_txt: str, sender_name: str, chat_stream: ChatStream):
|
||||||
async def do_thinking_before_reply(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
|
|
||||||
current_thinking_info = self.current_mind
|
current_thinking_info = self.current_mind
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
# mood_info = "你很生气,很愤怒"
|
# mood_info = "你很生气,很愤怒"
|
||||||
@@ -160,12 +161,12 @@ class SubHeartflow:
|
|||||||
|
|
||||||
# 首先尝试使用工具获取更多信息
|
# 首先尝试使用工具获取更多信息
|
||||||
tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream)
|
tool_result = await self.tool_user.use_tool(message_txt, sender_name, chat_stream)
|
||||||
|
|
||||||
# 如果工具被使用且获得了结果,将收集到的信息合并到思考中
|
# 如果工具被使用且获得了结果,将收集到的信息合并到思考中
|
||||||
collected_info = ""
|
collected_info = ""
|
||||||
if tool_result.get("used_tools", False):
|
if tool_result.get("used_tools", False):
|
||||||
logger.info("使用工具收集了信息")
|
logger.info("使用工具收集了信息")
|
||||||
|
|
||||||
# 如果有收集到的信息,将其添加到当前思考中
|
# 如果有收集到的信息,将其添加到当前思考中
|
||||||
if "collected_info" in tool_result:
|
if "collected_info" in tool_result:
|
||||||
collected_info = tool_result["collected_info"]
|
collected_info = tool_result["collected_info"]
|
||||||
@@ -185,7 +186,7 @@ class SubHeartflow:
|
|||||||
identity_detail = individuality.identity.identity_detail
|
identity_detail = individuality.identity.identity_detail
|
||||||
random.shuffle(identity_detail)
|
random.shuffle(identity_detail)
|
||||||
prompt_personality += f",{identity_detail[0]}"
|
prompt_personality += f",{identity_detail[0]}"
|
||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
who_chat_in_group = [
|
who_chat_in_group = [
|
||||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
||||||
@@ -204,9 +205,9 @@ class SubHeartflow:
|
|||||||
# f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录,"
|
# f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录,"
|
||||||
# f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。"
|
# f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。"
|
||||||
# )
|
# )
|
||||||
relation_prompt_all = (await global_prompt_manager.get_prompt_async('relationship_prompt')).format(
|
relation_prompt_all = (await global_prompt_manager.get_prompt_async("relationship_prompt")).format(
|
||||||
relation_prompt,sender_name
|
relation_prompt, sender_name
|
||||||
)
|
)
|
||||||
|
|
||||||
# prompt = ""
|
# prompt = ""
|
||||||
# # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
|
# # prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
|
||||||
@@ -224,9 +225,16 @@ class SubHeartflow:
|
|||||||
# prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写"
|
# prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写"
|
||||||
# prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。"
|
# prompt += f"记得结合上述的消息,生成内心想法,文字不要浮夸,注意你就是{self.bot_name},{self.bot_name}指的就是你。"
|
||||||
|
|
||||||
prompt= (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_before")).format(
|
prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_before")).format(
|
||||||
collected_info,relation_prompt_all,prompt_personality,current_thinking_info,chat_observe_info,mood_info,sender_name,
|
collected_info,
|
||||||
message_txt,self.bot_name
|
relation_prompt_all,
|
||||||
|
prompt_personality,
|
||||||
|
current_thinking_info,
|
||||||
|
chat_observe_info,
|
||||||
|
mood_info,
|
||||||
|
sender_name,
|
||||||
|
message_txt,
|
||||||
|
self.bot_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -281,10 +289,10 @@ class SubHeartflow:
|
|||||||
# prompt += f"你现在{mood_info}"
|
# prompt += f"你现在{mood_info}"
|
||||||
# prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
|
# prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
|
||||||
# prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
|
# prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
|
||||||
prompt=(await global_prompt_manager.get_prompt_async('sub_heartflow_prompt_after')).format(
|
prompt = (await global_prompt_manager.get_prompt_async("sub_heartflow_prompt_after")).format(
|
||||||
prompt_personality,chat_observe_info,current_thinking_info,message_new_info,reply_info,mood_info
|
prompt_personality, chat_observe_info, current_thinking_info, message_new_info, reply_info, mood_info
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
response, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -343,5 +351,6 @@ class SubHeartflow:
|
|||||||
self.past_mind.append(self.current_mind)
|
self.past_mind.append(self.current_mind)
|
||||||
self.current_mind = response
|
self.current_mind = response
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
# subheartflow = SubHeartflow()
|
# subheartflow = SubHeartflow()
|
||||||
|
|||||||
@@ -53,13 +53,13 @@ class ActionPlanner:
|
|||||||
goal = goal_reason[0]
|
goal = goal_reason[0]
|
||||||
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
|
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
|
||||||
elif isinstance(goal_reason, dict):
|
elif isinstance(goal_reason, dict):
|
||||||
goal = goal_reason.get('goal')
|
goal = goal_reason.get("goal")
|
||||||
reasoning = goal_reason.get('reasoning', "没有明确原因")
|
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||||
else:
|
else:
|
||||||
# 如果是其他类型,尝试转为字符串
|
# 如果是其他类型,尝试转为字符串
|
||||||
goal = str(goal_reason)
|
goal = str(goal_reason)
|
||||||
reasoning = "没有明确原因"
|
reasoning = "没有明确原因"
|
||||||
|
|
||||||
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||||
goals_str += goal_str
|
goals_str += goal_str
|
||||||
else:
|
else:
|
||||||
@@ -68,7 +68,11 @@ class ActionPlanner:
|
|||||||
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||||
|
|
||||||
# 获取聊天历史记录
|
# 获取聊天历史记录
|
||||||
chat_history_list = observation_info.chat_history[-20:] if len(observation_info.chat_history) >= 20 else observation_info.chat_history
|
chat_history_list = (
|
||||||
|
observation_info.chat_history[-20:]
|
||||||
|
if len(observation_info.chat_history) >= 20
|
||||||
|
else observation_info.chat_history
|
||||||
|
)
|
||||||
chat_history_text = ""
|
chat_history_text = ""
|
||||||
for msg in chat_history_list:
|
for msg in chat_history_list:
|
||||||
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
|
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
|
||||||
@@ -85,15 +89,21 @@ class ActionPlanner:
|
|||||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||||
|
|
||||||
# 构建action历史文本
|
# 构建action历史文本
|
||||||
action_history_list = conversation_info.done_action[-10:] if len(conversation_info.done_action) >= 10 else conversation_info.done_action
|
action_history_list = (
|
||||||
|
conversation_info.done_action[-10:]
|
||||||
|
if len(conversation_info.done_action) >= 10
|
||||||
|
else conversation_info.done_action
|
||||||
|
)
|
||||||
action_history_text = "你之前做的事情是:"
|
action_history_text = "你之前做的事情是:"
|
||||||
for action in action_history_list:
|
for action in action_history_list:
|
||||||
if isinstance(action, dict):
|
if isinstance(action, dict):
|
||||||
action_type = action.get('action')
|
action_type = action.get("action")
|
||||||
action_reason = action.get('reason')
|
action_reason = action.get("reason")
|
||||||
action_status = action.get('status')
|
action_status = action.get("status")
|
||||||
if action_status == "recall":
|
if action_status == "recall":
|
||||||
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
|
action_history_text += (
|
||||||
|
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
|
||||||
|
)
|
||||||
elif action_status == "done":
|
elif action_status == "done":
|
||||||
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
|
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
|
||||||
elif isinstance(action, tuple):
|
elif isinstance(action, tuple):
|
||||||
@@ -102,7 +112,9 @@ class ActionPlanner:
|
|||||||
action_reason = action[1] if len(action) > 1 else "未知原因"
|
action_reason = action[1] if len(action) > 1 else "未知原因"
|
||||||
action_status = action[2] if len(action) > 2 else "done"
|
action_status = action[2] if len(action) > 2 else "done"
|
||||||
if action_status == "recall":
|
if action_status == "recall":
|
||||||
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
|
action_history_text += (
|
||||||
|
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
|
||||||
|
)
|
||||||
elif action_status == "done":
|
elif action_status == "done":
|
||||||
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
|
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
|
||||||
|
|
||||||
@@ -147,7 +159,14 @@ end_conversation: 结束对话,长时间没回复或者当你觉得谈话暂
|
|||||||
reason = result["reason"]
|
reason = result["reason"]
|
||||||
|
|
||||||
# 验证action类型
|
# 验证action类型
|
||||||
if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal", "end_conversation"]:
|
if action not in [
|
||||||
|
"direct_reply",
|
||||||
|
"fetch_knowledge",
|
||||||
|
"wait",
|
||||||
|
"listening",
|
||||||
|
"rethink_goal",
|
||||||
|
"end_conversation",
|
||||||
|
]:
|
||||||
logger.warning(f"未知的行动类型: {action},默认使用listening")
|
logger.warning(f"未知的行动类型: {action},默认使用listening")
|
||||||
action = "listening"
|
action = "listening"
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from ..message.message_base import UserInfo
|
from ..message.message_base import UserInfo
|
||||||
from ..config.config import global_config
|
from ..config.config import global_config
|
||||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||||
from .message_storage import MongoDBMessageStorage
|
from .message_storage import MongoDBMessageStorage
|
||||||
|
|
||||||
logger = get_module_logger("chat_observer")
|
logger = get_module_logger("chat_observer")
|
||||||
|
|
||||||
@@ -51,7 +51,6 @@ class ChatObserver:
|
|||||||
|
|
||||||
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
|
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
|
||||||
|
|
||||||
|
|
||||||
# 运行状态
|
# 运行状态
|
||||||
self._running: bool = False
|
self._running: bool = False
|
||||||
self._task: Optional[asyncio.Task] = None
|
self._task: Optional[asyncio.Task] = None
|
||||||
@@ -94,10 +93,11 @@ class ChatObserver:
|
|||||||
message: 消息数据
|
message: 消息数据
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# 发送新消息通知
|
# 发送新消息通知
|
||||||
# logger.info(f"发送新ccchandleer消息通知: {message}")
|
# logger.info(f"发送新ccchandleer消息通知: {message}")
|
||||||
notification = create_new_message_notification(sender="chat_observer", target="observation_info", message=message)
|
notification = create_new_message_notification(
|
||||||
|
sender="chat_observer", target="observation_info", message=message
|
||||||
|
)
|
||||||
# logger.info(f"发送新消ddddd息通知: {notification}")
|
# logger.info(f"发送新消ddddd息通知: {notification}")
|
||||||
# print(self.notification_manager)
|
# print(self.notification_manager)
|
||||||
await self.notification_manager.send_notification(notification)
|
await self.notification_manager.send_notification(notification)
|
||||||
@@ -131,7 +131,6 @@ class ChatObserver:
|
|||||||
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
|
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
|
||||||
await self.notification_manager.send_notification(notification)
|
await self.notification_manager.send_notification(notification)
|
||||||
|
|
||||||
|
|
||||||
def new_message_after(self, time_point: float) -> bool:
|
def new_message_after(self, time_point: float) -> bool:
|
||||||
"""判断是否在指定时间点后有新消息
|
"""判断是否在指定时间点后有新消息
|
||||||
|
|
||||||
@@ -197,7 +196,7 @@ class ChatObserver:
|
|||||||
if new_messages:
|
if new_messages:
|
||||||
self.last_message_read = new_messages[-1]
|
self.last_message_read = new_messages[-1]
|
||||||
self.last_message_time = new_messages[-1]["time"]
|
self.last_message_time = new_messages[-1]["time"]
|
||||||
|
|
||||||
# print(f"获取数据库中找到的新消息: {new_messages}")
|
# print(f"获取数据库中找到的新消息: {new_messages}")
|
||||||
|
|
||||||
return new_messages
|
return new_messages
|
||||||
@@ -215,7 +214,7 @@ class ChatObserver:
|
|||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
self.last_message_read = new_messages[-1]["message_id"]
|
self.last_message_read = new_messages[-1]["message_id"]
|
||||||
|
|
||||||
logger.debug(f"获取指定时间点111之前的消息: {new_messages}")
|
logger.debug(f"获取指定时间点111之前的消息: {new_messages}")
|
||||||
|
|
||||||
return new_messages
|
return new_messages
|
||||||
@@ -239,7 +238,7 @@ class ChatObserver:
|
|||||||
try:
|
try:
|
||||||
# print("等待事件")
|
# print("等待事件")
|
||||||
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
await asyncio.wait_for(self._update_event.wait(), timeout=1)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# print("超时")
|
# print("超时")
|
||||||
pass # 超时后也执行一次检查
|
pass # 超时后也执行一次检查
|
||||||
@@ -347,7 +346,6 @@ class ChatObserver:
|
|||||||
|
|
||||||
return time_info
|
return time_info
|
||||||
|
|
||||||
|
|
||||||
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
|
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||||
"""获取缓存的消息历史
|
"""获取缓存的消息历史
|
||||||
|
|
||||||
@@ -368,6 +366,6 @@ class ChatObserver:
|
|||||||
if not self.message_cache:
|
if not self.message_cache:
|
||||||
return None
|
return None
|
||||||
return self.message_cache[0]
|
return self.message_cache[0]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"ChatObserver for {self.stream_id}"
|
return f"ChatObserver for {self.stream_id}"
|
||||||
|
|||||||
@@ -140,7 +140,6 @@ class NotificationManager:
|
|||||||
self._active_states.add(notification.type)
|
self._active_states.add(notification.type)
|
||||||
else:
|
else:
|
||||||
self._active_states.discard(notification.type)
|
self._active_states.discard(notification.type)
|
||||||
|
|
||||||
|
|
||||||
# 调用目标接收者的处理器
|
# 调用目标接收者的处理器
|
||||||
target = notification.target
|
target = notification.target
|
||||||
@@ -181,7 +180,7 @@ class NotificationManager:
|
|||||||
history = history[-limit:]
|
history = history[-limit:]
|
||||||
|
|
||||||
return history
|
return history
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
str = ""
|
str = ""
|
||||||
for target, handlers in self._handlers.items():
|
for target, handlers in self._handlers.items():
|
||||||
@@ -295,5 +294,3 @@ class ChatStateManager:
|
|||||||
|
|
||||||
current_time = datetime.now().timestamp()
|
current_time = datetime.now().timestamp()
|
||||||
return (current_time - self.state_info.last_message_time) <= threshold
|
return (current_time - self.state_info.last_message_time) <= threshold
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ class Conversation:
|
|||||||
self.observation_info.bind_to_chat_observer(self.chat_observer)
|
self.observation_info.bind_to_chat_observer(self.chat_observer)
|
||||||
# print(self.chat_observer.get_cached_messages(limit=)
|
# print(self.chat_observer.get_cached_messages(limit=)
|
||||||
|
|
||||||
|
|
||||||
self.conversation_info = ConversationInfo()
|
self.conversation_info = ConversationInfo()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"初始化对话实例:注册信息组件失败: {e}")
|
logger.error(f"初始化对话实例:注册信息组件失败: {e}")
|
||||||
@@ -96,7 +95,7 @@ class Conversation:
|
|||||||
|
|
||||||
# 执行行动
|
# 执行行动
|
||||||
await self._handle_action(action, reason, self.observation_info, self.conversation_info)
|
await self._handle_action(action, reason, self.observation_info, self.conversation_info)
|
||||||
|
|
||||||
for goal in self.conversation_info.goal_list:
|
for goal in self.conversation_info.goal_list:
|
||||||
# 检查goal是否为元组类型,如果是元组则使用索引访问,如果是字典则使用get方法
|
# 检查goal是否为元组类型,如果是元组则使用索引访问,如果是字典则使用get方法
|
||||||
if isinstance(goal, tuple):
|
if isinstance(goal, tuple):
|
||||||
@@ -151,7 +150,7 @@ class Conversation:
|
|||||||
|
|
||||||
if action == "direct_reply":
|
if action == "direct_reply":
|
||||||
self.waiter.wait_accumulated_time = 0
|
self.waiter.wait_accumulated_time = 0
|
||||||
|
|
||||||
self.state = ConversationState.GENERATING
|
self.state = ConversationState.GENERATING
|
||||||
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
|
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
|
||||||
print(f"生成回复: {self.generated_reply}")
|
print(f"生成回复: {self.generated_reply}")
|
||||||
@@ -174,7 +173,6 @@ class Conversation:
|
|||||||
|
|
||||||
await self._send_reply()
|
await self._send_reply()
|
||||||
|
|
||||||
|
|
||||||
conversation_info.done_action[-1].update(
|
conversation_info.done_action[-1].update(
|
||||||
{
|
{
|
||||||
"status": "done",
|
"status": "done",
|
||||||
@@ -184,7 +182,7 @@ class Conversation:
|
|||||||
|
|
||||||
elif action == "fetch_knowledge":
|
elif action == "fetch_knowledge":
|
||||||
self.waiter.wait_accumulated_time = 0
|
self.waiter.wait_accumulated_time = 0
|
||||||
|
|
||||||
self.state = ConversationState.FETCHING
|
self.state = ConversationState.FETCHING
|
||||||
knowledge = "TODO:知识"
|
knowledge = "TODO:知识"
|
||||||
topic = "TODO:关键词"
|
topic = "TODO:关键词"
|
||||||
@@ -199,7 +197,7 @@ class Conversation:
|
|||||||
|
|
||||||
elif action == "rethink_goal":
|
elif action == "rethink_goal":
|
||||||
self.waiter.wait_accumulated_time = 0
|
self.waiter.wait_accumulated_time = 0
|
||||||
|
|
||||||
self.state = ConversationState.RETHINKING
|
self.state = ConversationState.RETHINKING
|
||||||
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
|
||||||
|
|
||||||
@@ -208,7 +206,6 @@ class Conversation:
|
|||||||
logger.info("倾听对方发言...")
|
logger.info("倾听对方发言...")
|
||||||
await self.waiter.wait_listening(conversation_info)
|
await self.waiter.wait_listening(conversation_info)
|
||||||
|
|
||||||
|
|
||||||
elif action == "end_conversation":
|
elif action == "end_conversation":
|
||||||
self.should_continue = False
|
self.should_continue = False
|
||||||
logger.info("决定结束对话...")
|
logger.info("决定结束对话...")
|
||||||
@@ -239,9 +236,7 @@ class Conversation:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.direct_sender.send_message(
|
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=self.generated_reply)
|
||||||
chat_stream=self.chat_stream, content=self.generated_reply
|
|
||||||
)
|
|
||||||
self.chat_observer.trigger_update() # 触发立即更新
|
self.chat_observer.trigger_update() # 触发立即更新
|
||||||
if not await self.chat_observer.wait_for_update():
|
if not await self.chat_observer.wait_for_update():
|
||||||
logger.warning("等待消息更新超时")
|
logger.warning("等待消息更新超时")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from src.common.database import db
|
from src.common.database import db
|
||||||
|
|
||||||
|
|
||||||
class MessageStorage(ABC):
|
class MessageStorage(ABC):
|
||||||
"""消息存储接口"""
|
"""消息存储接口"""
|
||||||
|
|
||||||
|
|||||||
@@ -26,24 +26,24 @@ class ObservationInfoHandler(NotificationHandler):
|
|||||||
# 获取通知类型和数据
|
# 获取通知类型和数据
|
||||||
notification_type = notification.type
|
notification_type = notification.type
|
||||||
data = notification.data
|
data = notification.data
|
||||||
|
|
||||||
if notification_type == NotificationType.NEW_MESSAGE:
|
if notification_type == NotificationType.NEW_MESSAGE:
|
||||||
# 处理新消息通知
|
# 处理新消息通知
|
||||||
logger.debug(f"收到新消息通知data: {data}")
|
logger.debug(f"收到新消息通知data: {data}")
|
||||||
message_id = data.get("message_id")
|
message_id = data.get("message_id")
|
||||||
processed_plain_text = data.get("processed_plain_text")
|
processed_plain_text = data.get("processed_plain_text")
|
||||||
detailed_plain_text = data.get("detailed_plain_text")
|
detailed_plain_text = data.get("detailed_plain_text")
|
||||||
user_info = data.get("user_info")
|
user_info = data.get("user_info")
|
||||||
time_value = data.get("time")
|
time_value = data.get("time")
|
||||||
|
|
||||||
message = {
|
message = {
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
"processed_plain_text": processed_plain_text,
|
"processed_plain_text": processed_plain_text,
|
||||||
"detailed_plain_text": detailed_plain_text,
|
"detailed_plain_text": detailed_plain_text,
|
||||||
"user_info": user_info,
|
"user_info": user_info,
|
||||||
"time": time_value
|
"time": time_value,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.observation_info.update_from_message(message)
|
self.observation_info.update_from_message(message)
|
||||||
|
|
||||||
elif notification_type == NotificationType.COLD_CHAT:
|
elif notification_type == NotificationType.COLD_CHAT:
|
||||||
@@ -161,7 +161,7 @@ class ObservationInfo:
|
|||||||
# logger.debug(f"更新信息from_message: {message}")
|
# logger.debug(f"更新信息from_message: {message}")
|
||||||
self.last_message_time = message["time"]
|
self.last_message_time = message["time"]
|
||||||
self.last_message_id = message["message_id"]
|
self.last_message_id = message["message_id"]
|
||||||
|
|
||||||
self.last_message_content = message.get("processed_plain_text", "")
|
self.last_message_content = message.get("processed_plain_text", "")
|
||||||
|
|
||||||
user_info = UserInfo.from_dict(message.get("user_info", {}))
|
user_info = UserInfo.from_dict(message.get("user_info", {}))
|
||||||
@@ -233,4 +233,3 @@ class ObservationInfo:
|
|||||||
self.unprocessed_messages.clear()
|
self.unprocessed_messages.clear()
|
||||||
self.chat_history_count = len(self.chat_history)
|
self.chat_history_count = len(self.chat_history)
|
||||||
self.new_messages_count = 0
|
self.new_messages_count = 0
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# Programmable Friendly Conversationalist
|
# Programmable Friendly Conversationalist
|
||||||
# Prefrontal cortex
|
# Prefrontal cortex
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
# import asyncio
|
# import asyncio
|
||||||
from typing import List, Optional, Tuple, TYPE_CHECKING
|
from typing import List, Optional, Tuple, TYPE_CHECKING
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
@@ -63,13 +64,13 @@ class GoalAnalyzer:
|
|||||||
goal = goal_reason[0]
|
goal = goal_reason[0]
|
||||||
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
|
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
|
||||||
elif isinstance(goal_reason, dict):
|
elif isinstance(goal_reason, dict):
|
||||||
goal = goal_reason.get('goal')
|
goal = goal_reason.get("goal")
|
||||||
reasoning = goal_reason.get('reasoning', "没有明确原因")
|
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||||
else:
|
else:
|
||||||
# 如果是其他类型,尝试转为字符串
|
# 如果是其他类型,尝试转为字符串
|
||||||
goal = str(goal_reason)
|
goal = str(goal_reason)
|
||||||
reasoning = "没有明确原因"
|
reasoning = "没有明确原因"
|
||||||
|
|
||||||
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||||
goals_str += goal_str
|
goals_str += goal_str
|
||||||
else:
|
else:
|
||||||
@@ -140,14 +141,12 @@ class GoalAnalyzer:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"分析对话目标时出错: {str(e)}")
|
logger.error(f"分析对话目标时出错: {str(e)}")
|
||||||
content = ""
|
content = ""
|
||||||
|
|
||||||
# 使用改进后的get_items_from_json函数处理JSON数组
|
# 使用改进后的get_items_from_json函数处理JSON数组
|
||||||
success, result = get_items_from_json(
|
success, result = get_items_from_json(
|
||||||
content, "goal", "reasoning",
|
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}, allow_array=True
|
||||||
required_types={"goal": str, "reasoning": str},
|
|
||||||
allow_array=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
# 判断结果是单个字典还是字典列表
|
# 判断结果是单个字典还是字典列表
|
||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
@@ -157,7 +156,7 @@ class GoalAnalyzer:
|
|||||||
goal = item.get("goal", "")
|
goal = item.get("goal", "")
|
||||||
reasoning = item.get("reasoning", "")
|
reasoning = item.get("reasoning", "")
|
||||||
conversation_info.goal_list.append((goal, reasoning))
|
conversation_info.goal_list.append((goal, reasoning))
|
||||||
|
|
||||||
# 返回第一个目标作为当前主要目标(如果有)
|
# 返回第一个目标作为当前主要目标(如果有)
|
||||||
if result:
|
if result:
|
||||||
first_goal = result[0]
|
first_goal = result[0]
|
||||||
@@ -168,7 +167,7 @@ class GoalAnalyzer:
|
|||||||
reasoning = result.get("reasoning", "")
|
reasoning = result.get("reasoning", "")
|
||||||
conversation_info.goal_list.append((goal, reasoning))
|
conversation_info.goal_list.append((goal, reasoning))
|
||||||
return (goal, "", reasoning)
|
return (goal, "", reasoning)
|
||||||
|
|
||||||
# 如果解析失败,返回默认值
|
# 如果解析失败,返回默认值
|
||||||
return ("", "", "")
|
return ("", "", "")
|
||||||
|
|
||||||
@@ -293,7 +292,6 @@ class GoalAnalyzer:
|
|||||||
return False, False, f"分析出错: {str(e)}"
|
return False, False, f"分析出错: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DirectMessageSender:
|
class DirectMessageSender:
|
||||||
"""直接发送消息到平台的发送器"""
|
"""直接发送消息到平台的发送器"""
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def get_items_from_json(
|
|||||||
"""
|
"""
|
||||||
content = content.strip()
|
content = content.strip()
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
# 设置默认值
|
# 设置默认值
|
||||||
if default_values:
|
if default_values:
|
||||||
result.update(default_values)
|
result.update(default_values)
|
||||||
@@ -41,7 +41,7 @@ def get_items_from_json(
|
|||||||
if array_match:
|
if array_match:
|
||||||
array_content = array_match.group()
|
array_content = array_match.group()
|
||||||
json_array = json.loads(array_content)
|
json_array = json.loads(array_content)
|
||||||
|
|
||||||
# 确认是数组类型
|
# 确认是数组类型
|
||||||
if isinstance(json_array, list):
|
if isinstance(json_array, list):
|
||||||
# 验证数组中的每个项目是否包含所有必需字段
|
# 验证数组中的每个项目是否包含所有必需字段
|
||||||
@@ -49,7 +49,7 @@ def get_items_from_json(
|
|||||||
for item in json_array:
|
for item in json_array:
|
||||||
if not isinstance(item, dict):
|
if not isinstance(item, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查是否有所有必需字段
|
# 检查是否有所有必需字段
|
||||||
if all(field in item for field in items):
|
if all(field in item for field in items):
|
||||||
# 验证字段类型
|
# 验证字段类型
|
||||||
@@ -59,22 +59,22 @@ def get_items_from_json(
|
|||||||
if field in item and not isinstance(item[field], expected_type):
|
if field in item and not isinstance(item[field], expected_type):
|
||||||
type_valid = False
|
type_valid = False
|
||||||
break
|
break
|
||||||
|
|
||||||
if not type_valid:
|
if not type_valid:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 验证字符串字段不为空
|
# 验证字符串字段不为空
|
||||||
string_valid = True
|
string_valid = True
|
||||||
for field in items:
|
for field in items:
|
||||||
if isinstance(item[field], str) and not item[field].strip():
|
if isinstance(item[field], str) and not item[field].strip():
|
||||||
string_valid = False
|
string_valid = False
|
||||||
break
|
break
|
||||||
|
|
||||||
if not string_valid:
|
if not string_valid:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
valid_items.append(item)
|
valid_items.append(item)
|
||||||
|
|
||||||
if valid_items:
|
if valid_items:
|
||||||
return True, valid_items
|
return True, valid_items
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
|||||||
@@ -49,22 +49,26 @@ class ReplyGenerator:
|
|||||||
goal = goal_reason[0]
|
goal = goal_reason[0]
|
||||||
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
|
reasoning = goal_reason[1] if len(goal_reason) > 1 else "没有明确原因"
|
||||||
elif isinstance(goal_reason, dict):
|
elif isinstance(goal_reason, dict):
|
||||||
goal = goal_reason.get('goal')
|
goal = goal_reason.get("goal")
|
||||||
reasoning = goal_reason.get('reasoning', "没有明确原因")
|
reasoning = goal_reason.get("reasoning", "没有明确原因")
|
||||||
else:
|
else:
|
||||||
# 如果是其他类型,尝试转为字符串
|
# 如果是其他类型,尝试转为字符串
|
||||||
goal = str(goal_reason)
|
goal = str(goal_reason)
|
||||||
reasoning = "没有明确原因"
|
reasoning = "没有明确原因"
|
||||||
|
|
||||||
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||||
goals_str += goal_str
|
goals_str += goal_str
|
||||||
else:
|
else:
|
||||||
goal = "目前没有明确对话目标"
|
goal = "目前没有明确对话目标"
|
||||||
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
|
||||||
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
|
||||||
|
|
||||||
# 获取聊天历史记录
|
# 获取聊天历史记录
|
||||||
chat_history_list = observation_info.chat_history[-20:] if len(observation_info.chat_history) >= 20 else observation_info.chat_history
|
chat_history_list = (
|
||||||
|
observation_info.chat_history[-20:]
|
||||||
|
if len(observation_info.chat_history) >= 20
|
||||||
|
else observation_info.chat_history
|
||||||
|
)
|
||||||
chat_history_text = ""
|
chat_history_text = ""
|
||||||
for msg in chat_history_list:
|
for msg in chat_history_list:
|
||||||
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
|
chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
|
||||||
@@ -81,15 +85,21 @@ class ReplyGenerator:
|
|||||||
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
personality_text = f"你的名字是{self.name},{self.personality_info}"
|
||||||
|
|
||||||
# 构建action历史文本
|
# 构建action历史文本
|
||||||
action_history_list = conversation_info.done_action[-10:] if len(conversation_info.done_action) >= 10 else conversation_info.done_action
|
action_history_list = (
|
||||||
|
conversation_info.done_action[-10:]
|
||||||
|
if len(conversation_info.done_action) >= 10
|
||||||
|
else conversation_info.done_action
|
||||||
|
)
|
||||||
action_history_text = "你之前做的事情是:"
|
action_history_text = "你之前做的事情是:"
|
||||||
for action in action_history_list:
|
for action in action_history_list:
|
||||||
if isinstance(action, dict):
|
if isinstance(action, dict):
|
||||||
action_type = action.get('action')
|
action_type = action.get("action")
|
||||||
action_reason = action.get('reason')
|
action_reason = action.get("reason")
|
||||||
action_status = action.get('status')
|
action_status = action.get("status")
|
||||||
if action_status == "recall":
|
if action_status == "recall":
|
||||||
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
|
action_history_text += (
|
||||||
|
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
|
||||||
|
)
|
||||||
elif action_status == "done":
|
elif action_status == "done":
|
||||||
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
|
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
|
||||||
elif isinstance(action, tuple):
|
elif isinstance(action, tuple):
|
||||||
@@ -98,7 +108,9 @@ class ReplyGenerator:
|
|||||||
action_reason = action[1] if len(action) > 1 else "未知原因"
|
action_reason = action[1] if len(action) > 1 else "未知原因"
|
||||||
action_status = action[2] if len(action) > 2 else "done"
|
action_status = action[2] if len(action) > 2 else "done"
|
||||||
if action_status == "recall":
|
if action_status == "recall":
|
||||||
action_history_text += f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
|
action_history_text += (
|
||||||
|
f"原本打算:{action_type},但是因为有新消息,你发现这个行动不合适,所以你没做\n"
|
||||||
|
)
|
||||||
elif action_status == "done":
|
elif action_status == "done":
|
||||||
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
|
action_history_text += f"你之前做了:{action_type},原因:{action_reason}\n"
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class Waiter:
|
|||||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
self.chat_observer = ChatObserver.get_instance(stream_id)
|
||||||
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
|
||||||
self.name = global_config.BOT_NICKNAME
|
self.name = global_config.BOT_NICKNAME
|
||||||
|
|
||||||
self.wait_accumulated_time = 0
|
self.wait_accumulated_time = 0
|
||||||
|
|
||||||
async def wait(self, conversation_info: ConversationInfo) -> bool:
|
async def wait(self, conversation_info: ConversationInfo) -> bool:
|
||||||
@@ -38,20 +38,20 @@ class Waiter:
|
|||||||
# 检查是否超时
|
# 检查是否超时
|
||||||
if time.time() - wait_start_time > 300:
|
if time.time() - wait_start_time > 300:
|
||||||
self.wait_accumulated_time += 300
|
self.wait_accumulated_time += 300
|
||||||
|
|
||||||
logger.info("等待超过300秒,结束对话")
|
logger.info("等待超过300秒,结束对话")
|
||||||
wait_goal = {
|
wait_goal = {
|
||||||
"goal": f"你等待了{self.wait_accumulated_time/60}分钟,思考接下来要做什么",
|
"goal": f"你等待了{self.wait_accumulated_time / 60}分钟,思考接下来要做什么",
|
||||||
"reason": "对方很久没有回复你的消息了"
|
"reason": "对方很久没有回复你的消息了",
|
||||||
}
|
}
|
||||||
conversation_info.goal_list.append(wait_goal)
|
conversation_info.goal_list.append(wait_goal)
|
||||||
print(f"添加目标: {wait_goal}")
|
print(f"添加目标: {wait_goal}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
logger.info("等待中...")
|
logger.info("等待中...")
|
||||||
|
|
||||||
async def wait_listening(self, conversation_info: ConversationInfo) -> bool:
|
async def wait_listening(self, conversation_info: ConversationInfo) -> bool:
|
||||||
"""等待倾听
|
"""等待倾听
|
||||||
|
|
||||||
@@ -73,14 +73,13 @@ class Waiter:
|
|||||||
self.wait_accumulated_time += 300
|
self.wait_accumulated_time += 300
|
||||||
logger.info("等待超过300秒,结束对话")
|
logger.info("等待超过300秒,结束对话")
|
||||||
wait_goal = {
|
wait_goal = {
|
||||||
"goal": f"你等待了{self.wait_accumulated_time/60}分钟,思考接下来要做什么",
|
"goal": f"你等待了{self.wait_accumulated_time / 60}分钟,思考接下来要做什么",
|
||||||
"reason": "对方话说一半消失了,很久没有回复"
|
"reason": "对方话说一半消失了,很久没有回复",
|
||||||
}
|
}
|
||||||
conversation_info.goal_list.append(wait_goal)
|
conversation_info.goal_list.append(wait_goal)
|
||||||
print(f"添加目标: {wait_goal}")
|
print(f"添加目标: {wait_goal}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
logger.info("等待中...")
|
logger.info("等待中...")
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from ..chat_module.only_process.only_message_process import MessageProcessor
|
|||||||
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
||||||
from ..chat_module.think_flow_chat.think_flow_chat import ThinkFlowChat
|
from ..chat_module.think_flow_chat.think_flow_chat import ThinkFlowChat
|
||||||
from ..chat_module.reasoning_chat.reasoning_chat import ReasoningChat
|
from ..chat_module.reasoning_chat.reasoning_chat import ReasoningChat
|
||||||
from ..utils.prompt_builder import Prompt,global_prompt_manager
|
from ..utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
@@ -89,17 +89,17 @@ class ChatBot:
|
|||||||
if userinfo.user_id in global_config.ban_user_id:
|
if userinfo.user_id in global_config.ban_user_id:
|
||||||
logger.debug(f"用户{userinfo.user_id}被禁止回复")
|
logger.debug(f"用户{userinfo.user_id}被禁止回复")
|
||||||
return
|
return
|
||||||
|
|
||||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||||
template_group_name=message.message_info.template_info.template_name
|
template_group_name = message.message_info.template_info.template_name
|
||||||
template_items=message.message_info.template_info.template_items
|
template_items = message.message_info.template_info.template_items
|
||||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||||
if isinstance(template_items,dict):
|
if isinstance(template_items, dict):
|
||||||
for k in template_items.keys():
|
for k in template_items.keys():
|
||||||
await Prompt.create_async(template_items[k],k)
|
await Prompt.create_async(template_items[k], k)
|
||||||
print(f"注册{template_items[k]},{k}")
|
print(f"注册{template_items[k]},{k}")
|
||||||
else:
|
else:
|
||||||
template_group_name=None
|
template_group_name = None
|
||||||
|
|
||||||
async def preprocess():
|
async def preprocess():
|
||||||
if global_config.enable_pfc_chatting:
|
if global_config.enable_pfc_chatting:
|
||||||
|
|||||||
@@ -87,7 +87,6 @@ async def get_embedding(text, request_type="embedding"):
|
|||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
|
async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
|
||||||
"""从数据库获取群组最近的消息记录
|
"""从数据库获取群组最近的消息记录
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class ResponseGenerator:
|
|||||||
self.current_model_type = "r1" # 默认使用 R1
|
self.current_model_type = "r1" # 默认使用 R1
|
||||||
self.current_model_name = "unknown model"
|
self.current_model_name = "unknown model"
|
||||||
|
|
||||||
async def generate_response(self, message: MessageThinking,thinking_id:str) -> Optional[Union[str, List[str]]]:
|
async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]:
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
# 从global_config中获取模型概率值并选择模型
|
# 从global_config中获取模型概率值并选择模型
|
||||||
if random.random() < global_config.MODEL_R1_PROBABILITY:
|
if random.random() < global_config.MODEL_R1_PROBABILITY:
|
||||||
@@ -52,7 +52,7 @@ class ResponseGenerator:
|
|||||||
f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
|
f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
|
||||||
) # noqa: E501
|
) # noqa: E501
|
||||||
|
|
||||||
model_response = await self._generate_response_with_model(message, current_model,thinking_id)
|
model_response = await self._generate_response_with_model(message, current_model, thinking_id)
|
||||||
|
|
||||||
# print(f"raw_content: {model_response}")
|
# print(f"raw_content: {model_response}")
|
||||||
|
|
||||||
@@ -65,11 +65,11 @@ class ResponseGenerator:
|
|||||||
logger.info(f"{self.current_model_type}思考,失败")
|
logger.info(f"{self.current_model_type}思考,失败")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request,thinking_id:str):
|
async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request, thinking_id: str):
|
||||||
sender_name = ""
|
sender_name = ""
|
||||||
|
|
||||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||||
|
|
||||||
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
||||||
sender_name = (
|
sender_name = (
|
||||||
f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]"
|
f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]"
|
||||||
@@ -94,14 +94,11 @@ class ResponseGenerator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
|
content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
|
||||||
|
|
||||||
info_catcher.catch_after_llm_generated(
|
info_catcher.catch_after_llm_generated(
|
||||||
prompt=prompt,
|
prompt=prompt, response=content, reasoning_content=reasoning_content, model_name=self.current_model_name
|
||||||
response=content,
|
)
|
||||||
reasoning_content=reasoning_content,
|
|
||||||
model_name=self.current_model_name)
|
|
||||||
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("生成回复时出错")
|
logger.exception("生成回复时出错")
|
||||||
return None
|
return None
|
||||||
@@ -118,7 +115,6 @@ class ResponseGenerator:
|
|||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
# def _save_to_db(
|
# def _save_to_db(
|
||||||
# self,
|
# self,
|
||||||
# message: MessageRecv,
|
# message: MessageRecv,
|
||||||
|
|||||||
@@ -144,12 +144,10 @@ class PromptBuilder:
|
|||||||
for pattern in rule.get("regex", []):
|
for pattern in rule.get("regex", []):
|
||||||
result = pattern.search(message_txt)
|
result = pattern.search(message_txt)
|
||||||
if result:
|
if result:
|
||||||
reaction = rule.get('reaction', '')
|
reaction = rule.get("reaction", "")
|
||||||
for name, content in result.groupdict().items():
|
for name, content in result.groupdict().items():
|
||||||
reaction = reaction.replace(f'[{name}]', content)
|
reaction = reaction.replace(f"[{name}]", content)
|
||||||
logger.info(
|
logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
|
||||||
f"匹配到以下正则表达式:{pattern},触发反应:{reaction}"
|
|
||||||
)
|
|
||||||
keywords_reaction_prompt += reaction + ","
|
keywords_reaction_prompt += reaction + ","
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -197,7 +195,7 @@ class PromptBuilder:
|
|||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"reasoning_prompt_main",
|
"reasoning_prompt_main",
|
||||||
relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"),
|
relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"),
|
||||||
replation_prompt=relation_prompt,
|
relation_prompt=relation_prompt,
|
||||||
sender_name=sender_name,
|
sender_name=sender_name,
|
||||||
memory_prompt=memory_prompt,
|
memory_prompt=memory_prompt,
|
||||||
prompt_info=prompt_info,
|
prompt_info=prompt_info,
|
||||||
|
|||||||
@@ -59,11 +59,7 @@ class ThinkFlowChat:
|
|||||||
|
|
||||||
return thinking_id
|
return thinking_id
|
||||||
|
|
||||||
async def _send_response_messages(self,
|
async def _send_response_messages(self, message, chat, response_set: List[str], thinking_id) -> MessageSending:
|
||||||
message,
|
|
||||||
chat,
|
|
||||||
response_set:List[str],
|
|
||||||
thinking_id) -> MessageSending:
|
|
||||||
"""发送回复消息"""
|
"""发送回复消息"""
|
||||||
container = message_manager.get_container(chat.stream_id)
|
container = message_manager.get_container(chat.stream_id)
|
||||||
thinking_message = None
|
thinking_message = None
|
||||||
@@ -260,8 +256,6 @@ class ThinkFlowChat:
|
|||||||
if random() < reply_probability:
|
if random() < reply_probability:
|
||||||
try:
|
try:
|
||||||
do_reply = True
|
do_reply = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 回复前处理
|
# 回复前处理
|
||||||
await willing_manager.before_generate_reply_handle(message.message_info.message_id)
|
await willing_manager.before_generate_reply_handle(message.message_info.message_id)
|
||||||
@@ -274,9 +268,9 @@ class ThinkFlowChat:
|
|||||||
timing_results["创建思考消息"] = timer2 - timer1
|
timing_results["创建思考消息"] = timer2 - timer1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"心流创建思考消息失败: {e}")
|
logger.error(f"心流创建思考消息失败: {e}")
|
||||||
|
|
||||||
logger.debug(f"创建捕捉器,thinking_id:{thinking_id}")
|
logger.debug(f"创建捕捉器,thinking_id:{thinking_id}")
|
||||||
|
|
||||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||||
info_catcher.catch_decide_to_response(message)
|
info_catcher.catch_decide_to_response(message)
|
||||||
|
|
||||||
@@ -288,32 +282,32 @@ class ThinkFlowChat:
|
|||||||
timing_results["观察"] = timer2 - timer1
|
timing_results["观察"] = timer2 - timer1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"心流观察失败: {e}")
|
logger.error(f"心流观察失败: {e}")
|
||||||
|
|
||||||
info_catcher.catch_after_observe(timing_results["观察"])
|
info_catcher.catch_after_observe(timing_results["观察"])
|
||||||
|
|
||||||
# 思考前脑内状态
|
# 思考前脑内状态
|
||||||
try:
|
try:
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
current_mind,past_mind = await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(
|
current_mind, past_mind = await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(
|
||||||
message_txt = message.processed_plain_text,
|
message_txt=message.processed_plain_text,
|
||||||
sender_name = message.message_info.user_info.user_nickname,
|
sender_name=message.message_info.user_info.user_nickname,
|
||||||
chat_stream = chat
|
chat_stream=chat,
|
||||||
)
|
)
|
||||||
timer2 = time.time()
|
timer2 = time.time()
|
||||||
timing_results["思考前脑内状态"] = timer2 - timer1
|
timing_results["思考前脑内状态"] = timer2 - timer1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"心流思考前脑内状态失败: {e}")
|
logger.error(f"心流思考前脑内状态失败: {e}")
|
||||||
|
|
||||||
info_catcher.catch_afer_shf_step(timing_results["思考前脑内状态"],past_mind,current_mind)
|
info_catcher.catch_afer_shf_step(timing_results["思考前脑内状态"], past_mind, current_mind)
|
||||||
|
|
||||||
# 生成回复
|
# 生成回复
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
response_set = await self.gpt.generate_response(message,thinking_id)
|
response_set = await self.gpt.generate_response(message, thinking_id)
|
||||||
timer2 = time.time()
|
timer2 = time.time()
|
||||||
timing_results["生成回复"] = timer2 - timer1
|
timing_results["生成回复"] = timer2 - timer1
|
||||||
|
|
||||||
info_catcher.catch_after_generate_response(timing_results["生成回复"])
|
info_catcher.catch_after_generate_response(timing_results["生成回复"])
|
||||||
|
|
||||||
if not response_set:
|
if not response_set:
|
||||||
logger.info("回复生成失败,返回为空")
|
logger.info("回复生成失败,返回为空")
|
||||||
return
|
return
|
||||||
@@ -326,11 +320,9 @@ class ThinkFlowChat:
|
|||||||
timing_results["发送消息"] = timer2 - timer1
|
timing_results["发送消息"] = timer2 - timer1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"心流发送消息失败: {e}")
|
logger.error(f"心流发送消息失败: {e}")
|
||||||
|
|
||||||
|
info_catcher.catch_after_response(timing_results["发送消息"], response_set, first_bot_msg)
|
||||||
info_catcher.catch_after_response(timing_results["发送消息"],response_set,first_bot_msg)
|
|
||||||
|
|
||||||
|
|
||||||
info_catcher.done_catch()
|
info_catcher.done_catch()
|
||||||
|
|
||||||
# 处理表情包
|
# 处理表情包
|
||||||
|
|||||||
@@ -35,44 +35,51 @@ class ResponseGenerator:
|
|||||||
self.current_model_type = "r1" # 默认使用 R1
|
self.current_model_type = "r1" # 默认使用 R1
|
||||||
self.current_model_name = "unknown model"
|
self.current_model_name = "unknown model"
|
||||||
|
|
||||||
async def generate_response(self, message: MessageRecv,thinking_id:str) -> Optional[List[str]]:
|
async def generate_response(self, message: MessageRecv, thinking_id: str) -> Optional[List[str]]:
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
|
f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier()
|
arousal_multiplier = MoodManager.get_instance().get_arousal_multiplier()
|
||||||
|
|
||||||
time1 = time.time()
|
time1 = time.time()
|
||||||
|
|
||||||
checked = False
|
checked = False
|
||||||
if random.random() > 0:
|
if random.random() > 0:
|
||||||
checked = False
|
checked = False
|
||||||
current_model = self.model_normal
|
current_model = self.model_normal
|
||||||
current_model.temperature = 0.3 * arousal_multiplier #激活度越高,温度越高
|
current_model.temperature = 0.3 * arousal_multiplier # 激活度越高,温度越高
|
||||||
model_response = await self._generate_response_with_model(message, current_model,thinking_id,mode="normal")
|
model_response = await self._generate_response_with_model(
|
||||||
|
message, current_model, thinking_id, mode="normal"
|
||||||
|
)
|
||||||
|
|
||||||
model_checked_response = model_response
|
model_checked_response = model_response
|
||||||
else:
|
else:
|
||||||
checked = True
|
checked = True
|
||||||
current_model = self.model_normal
|
current_model = self.model_normal
|
||||||
current_model.temperature = 0.3 * arousal_multiplier #激活度越高,温度越高
|
current_model.temperature = 0.3 * arousal_multiplier # 激活度越高,温度越高
|
||||||
print(f"生成{message.processed_plain_text}回复温度是:{current_model.temperature}")
|
print(f"生成{message.processed_plain_text}回复温度是:{current_model.temperature}")
|
||||||
model_response = await self._generate_response_with_model(message, current_model,thinking_id,mode="simple")
|
model_response = await self._generate_response_with_model(
|
||||||
|
message, current_model, thinking_id, mode="simple"
|
||||||
|
)
|
||||||
|
|
||||||
current_model.temperature = 0.3
|
current_model.temperature = 0.3
|
||||||
model_checked_response = await self._check_response_with_model(message, model_response, current_model,thinking_id)
|
model_checked_response = await self._check_response_with_model(
|
||||||
|
message, model_response, current_model, thinking_id
|
||||||
|
)
|
||||||
|
|
||||||
time2 = time.time()
|
time2 = time.time()
|
||||||
|
|
||||||
if model_response:
|
if model_response:
|
||||||
if checked:
|
if checked:
|
||||||
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},思忖后,回复是:{model_checked_response},生成回复时间: {time2 - time1}秒")
|
logger.info(
|
||||||
|
f"{global_config.BOT_NICKNAME}的回复是:{model_response},思忖后,回复是:{model_checked_response},生成回复时间: {time2 - time1}秒"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},生成回复时间: {time2 - time1}秒")
|
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response},生成回复时间: {time2 - time1}秒")
|
||||||
|
|
||||||
model_processed_response = await self._process_response(model_checked_response)
|
model_processed_response = await self._process_response(model_checked_response)
|
||||||
|
|
||||||
return model_processed_response
|
return model_processed_response
|
||||||
@@ -80,11 +87,13 @@ class ResponseGenerator:
|
|||||||
logger.info(f"{self.current_model_type}思考,失败")
|
logger.info(f"{self.current_model_type}思考,失败")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _generate_response_with_model(self, message: MessageRecv, model: LLM_request,thinking_id:str,mode:str = "normal") -> str:
|
async def _generate_response_with_model(
|
||||||
|
self, message: MessageRecv, model: LLM_request, thinking_id: str, mode: str = "normal"
|
||||||
|
) -> str:
|
||||||
sender_name = ""
|
sender_name = ""
|
||||||
|
|
||||||
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||||
|
|
||||||
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
||||||
sender_name = (
|
sender_name = (
|
||||||
f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]"
|
f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]"
|
||||||
@@ -116,25 +125,22 @@ class ResponseGenerator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
|
content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
|
||||||
|
|
||||||
|
|
||||||
info_catcher.catch_after_llm_generated(
|
info_catcher.catch_after_llm_generated(
|
||||||
prompt=prompt,
|
prompt=prompt, response=content, reasoning_content=reasoning_content, model_name=self.current_model_name
|
||||||
response=content,
|
)
|
||||||
reasoning_content=reasoning_content,
|
|
||||||
model_name=self.current_model_name)
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("生成回复时出错")
|
logger.exception("生成回复时出错")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
async def _check_response_with_model(self, message: MessageRecv, content:str, model: LLM_request,thinking_id:str) -> str:
|
async def _check_response_with_model(
|
||||||
|
self, message: MessageRecv, content: str, model: LLM_request, thinking_id: str
|
||||||
|
) -> str:
|
||||||
_info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
_info_catcher = info_catcher_manager.get_info_catcher(thinking_id)
|
||||||
|
|
||||||
sender_name = ""
|
sender_name = ""
|
||||||
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
||||||
sender_name = (
|
sender_name = (
|
||||||
@@ -145,8 +151,7 @@ class ResponseGenerator:
|
|||||||
sender_name = f"({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}"
|
sender_name = f"({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}"
|
||||||
else:
|
else:
|
||||||
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||||
|
|
||||||
|
|
||||||
# 构建prompt
|
# 构建prompt
|
||||||
timer1 = time.time()
|
timer1 = time.time()
|
||||||
prompt = await prompt_builder._build_prompt_check_response(
|
prompt = await prompt_builder._build_prompt_check_response(
|
||||||
@@ -154,7 +159,7 @@ class ResponseGenerator:
|
|||||||
message_txt=message.processed_plain_text,
|
message_txt=message.processed_plain_text,
|
||||||
sender_name=sender_name,
|
sender_name=sender_name,
|
||||||
stream_id=message.chat_stream.stream_id,
|
stream_id=message.chat_stream.stream_id,
|
||||||
content=content
|
content=content,
|
||||||
)
|
)
|
||||||
timer2 = time.time()
|
timer2 = time.time()
|
||||||
logger.info(f"构建check_prompt: {prompt}")
|
logger.info(f"构建check_prompt: {prompt}")
|
||||||
@@ -162,19 +167,17 @@ class ResponseGenerator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
checked_content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
|
checked_content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
|
||||||
|
|
||||||
|
|
||||||
# info_catcher.catch_after_llm_generated(
|
# info_catcher.catch_after_llm_generated(
|
||||||
# prompt=prompt,
|
# prompt=prompt,
|
||||||
# response=content,
|
# response=content,
|
||||||
# reasoning_content=reasoning_content,
|
# reasoning_content=reasoning_content,
|
||||||
# model_name=self.current_model_name)
|
# model_name=self.current_model_name)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("检查回复时出错")
|
logger.exception("检查回复时出错")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
return checked_content
|
return checked_content
|
||||||
|
|
||||||
async def _get_emotion_tags(self, content: str, processed_plain_text: str):
|
async def _get_emotion_tags(self, content: str, processed_plain_text: str):
|
||||||
|
|||||||
@@ -110,12 +110,10 @@ class PromptBuilder:
|
|||||||
for pattern in rule.get("regex", []):
|
for pattern in rule.get("regex", []):
|
||||||
result = pattern.search(message_txt)
|
result = pattern.search(message_txt)
|
||||||
if result:
|
if result:
|
||||||
reaction = rule.get('reaction', '')
|
reaction = rule.get("reaction", "")
|
||||||
for name, content in result.groupdict().items():
|
for name, content in result.groupdict().items():
|
||||||
reaction = reaction.replace(f'[{name}]', content)
|
reaction = reaction.replace(f"[{name}]", content)
|
||||||
logger.info(
|
logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}")
|
||||||
f"匹配到以下正则表达式:{pattern},触发反应:{reaction}"
|
|
||||||
)
|
|
||||||
keywords_reaction_prompt += reaction + ","
|
keywords_reaction_prompt += reaction + ","
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ class Memory_graph:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# 海马体
|
# 海马体
|
||||||
class Hippocampus:
|
class Hippocampus:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -653,7 +654,6 @@ class Hippocampus:
|
|||||||
return activation_ratio
|
return activation_ratio
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 负责海马体与其他部分的交互
|
# 负责海马体与其他部分的交互
|
||||||
class EntorhinalCortex:
|
class EntorhinalCortex:
|
||||||
def __init__(self, hippocampus: Hippocampus):
|
def __init__(self, hippocampus: Hippocampus):
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ async def test_memory_system():
|
|||||||
# 测试记忆检索
|
# 测试记忆检索
|
||||||
test_text = "千石可乐在群里聊天"
|
test_text = "千石可乐在群里聊天"
|
||||||
|
|
||||||
|
|
||||||
# test_text = '''千石可乐:分不清AI的陪伴和人类的陪伴,是这样吗?'''
|
# test_text = '''千石可乐:分不清AI的陪伴和人类的陪伴,是这样吗?'''
|
||||||
print(f"开始测试记忆检索,测试文本: {test_text}\n")
|
print(f"开始测试记忆检索,测试文本: {test_text}\n")
|
||||||
memories = await hippocampus_manager.get_memory_from_text(
|
memories = await hippocampus_manager.get_memory_from_text(
|
||||||
|
|||||||
@@ -574,7 +574,7 @@ class LLM_request:
|
|||||||
reasoning_content = message.get("reasoning_content", "")
|
reasoning_content = message.get("reasoning_content", "")
|
||||||
if not reasoning_content:
|
if not reasoning_content:
|
||||||
reasoning_content = reasoning
|
reasoning_content = reasoning
|
||||||
|
|
||||||
# 提取工具调用信息
|
# 提取工具调用信息
|
||||||
tool_calls = message.get("tool_calls", None)
|
tool_calls = message.get("tool_calls", None)
|
||||||
|
|
||||||
@@ -592,7 +592,7 @@ class LLM_request:
|
|||||||
request_type=request_type if request_type is not None else self.request_type,
|
request_type=request_type if request_type is not None else self.request_type,
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 只有当tool_calls存在且不为空时才返回
|
# 只有当tool_calls存在且不为空时才返回
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
return content, reasoning_content, tool_calls
|
return content, reasoning_content, tool_calls
|
||||||
@@ -657,9 +657,7 @@ class LLM_request:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = await self._execute_request(
|
response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
|
||||||
endpoint="/chat/completions", payload=data, prompt=prompt
|
|
||||||
)
|
|
||||||
# 原样返回响应,不做处理
|
# 原样返回响应,不做处理
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@@ -238,14 +238,14 @@ class MoodManager:
|
|||||||
base_prompt += "情绪比较平静。"
|
base_prompt += "情绪比较平静。"
|
||||||
|
|
||||||
return base_prompt
|
return base_prompt
|
||||||
|
|
||||||
def get_arousal_multiplier(self) -> float:
|
def get_arousal_multiplier(self) -> float:
|
||||||
"""根据当前情绪状态返回唤醒度乘数"""
|
"""根据当前情绪状态返回唤醒度乘数"""
|
||||||
if self.current_mood.arousal > 0.4:
|
if self.current_mood.arousal > 0.4:
|
||||||
multiplier = 1 + min(0.15,(self.current_mood.arousal - 0.4)/3)
|
multiplier = 1 + min(0.15, (self.current_mood.arousal - 0.4) / 3)
|
||||||
return multiplier
|
return multiplier
|
||||||
elif self.current_mood.arousal < -0.4:
|
elif self.current_mood.arousal < -0.4:
|
||||||
multiplier = 1 - min(0.15,((0 - self.current_mood.arousal) - 0.4)/3)
|
multiplier = 1 - min(0.15, ((0 - self.current_mood.arousal) - 0.4) / 3)
|
||||||
return multiplier
|
return multiplier
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
|
|||||||
@@ -1,28 +1,29 @@
|
|||||||
from src.plugins.config.config import global_config
|
from src.plugins.config.config import global_config
|
||||||
from src.plugins.chat.message import MessageRecv,MessageSending,Message
|
from src.plugins.chat.message import MessageRecv, MessageSending, Message
|
||||||
from src.common.database import db
|
from src.common.database import db
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class InfoCatcher:
|
class InfoCatcher:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文
|
self.chat_history = [] # 聊天历史,长度为三倍使用的上下文
|
||||||
self.context_length = global_config.MAX_CONTEXT_SIZE
|
self.context_length = global_config.MAX_CONTEXT_SIZE
|
||||||
self.chat_history_in_thinking = [] # 思考期间的聊天内容
|
self.chat_history_in_thinking = [] # 思考期间的聊天内容
|
||||||
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文
|
self.chat_history_after_response = [] # 回复后的聊天内容,长度为一倍上下文
|
||||||
|
|
||||||
self.chat_id = ""
|
self.chat_id = ""
|
||||||
self.response_mode = global_config.response_mode
|
self.response_mode = global_config.response_mode
|
||||||
self.trigger_response_text = ""
|
self.trigger_response_text = ""
|
||||||
self.response_text = ""
|
self.response_text = ""
|
||||||
|
|
||||||
self.trigger_response_time = 0
|
self.trigger_response_time = 0
|
||||||
self.trigger_response_message = None
|
self.trigger_response_message = None
|
||||||
|
|
||||||
self.response_time = 0
|
self.response_time = 0
|
||||||
self.response_messages = []
|
self.response_messages = []
|
||||||
|
|
||||||
# 使用字典来存储 heartflow 模式的数据
|
# 使用字典来存储 heartflow 模式的数据
|
||||||
self.heartflow_data = {
|
self.heartflow_data = {
|
||||||
"heart_flow_prompt": "",
|
"heart_flow_prompt": "",
|
||||||
@@ -32,17 +33,12 @@ class InfoCatcher:
|
|||||||
"sub_heartflow_model": "",
|
"sub_heartflow_model": "",
|
||||||
"prompt": "",
|
"prompt": "",
|
||||||
"response": "",
|
"response": "",
|
||||||
"model": ""
|
"model": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 使用字典来存储 reasoning 模式的数据
|
# 使用字典来存储 reasoning 模式的数据
|
||||||
self.reasoning_data = {
|
self.reasoning_data = {"thinking_log": "", "prompt": "", "response": "", "model": ""}
|
||||||
"thinking_log": "",
|
|
||||||
"prompt": "",
|
|
||||||
"response": "",
|
|
||||||
"model": ""
|
|
||||||
}
|
|
||||||
|
|
||||||
# 耗时
|
# 耗时
|
||||||
self.timing_results = {
|
self.timing_results = {
|
||||||
"interested_rate_time": 0,
|
"interested_rate_time": 0,
|
||||||
@@ -50,24 +46,24 @@ class InfoCatcher:
|
|||||||
"sub_heartflow_step_time": 0,
|
"sub_heartflow_step_time": 0,
|
||||||
"make_response_time": 0,
|
"make_response_time": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
def catch_decide_to_response(self,message:MessageRecv):
|
def catch_decide_to_response(self, message: MessageRecv):
|
||||||
# 搜集决定回复时的信息
|
# 搜集决定回复时的信息
|
||||||
self.trigger_response_message = message
|
self.trigger_response_message = message
|
||||||
self.trigger_response_text = message.detailed_plain_text
|
self.trigger_response_text = message.detailed_plain_text
|
||||||
|
|
||||||
self.trigger_response_time = time.time()
|
self.trigger_response_time = time.time()
|
||||||
|
|
||||||
self.chat_id = message.chat_stream.stream_id
|
self.chat_id = message.chat_stream.stream_id
|
||||||
|
|
||||||
self.chat_history = self.get_message_from_db_before_msg(message)
|
self.chat_history = self.get_message_from_db_before_msg(message)
|
||||||
|
|
||||||
def catch_after_observe(self,obs_duration:float):#这里可以有更多信息
|
def catch_after_observe(self, obs_duration: float): # 这里可以有更多信息
|
||||||
self.timing_results["sub_heartflow_observe_time"] = obs_duration
|
self.timing_results["sub_heartflow_observe_time"] = obs_duration
|
||||||
|
|
||||||
# def catch_shf
|
# def catch_shf
|
||||||
|
|
||||||
def catch_afer_shf_step(self,step_duration:float,past_mind:str,current_mind:str):
|
def catch_afer_shf_step(self, step_duration: float, past_mind: str, current_mind: str):
|
||||||
self.timing_results["sub_heartflow_step_time"] = step_duration
|
self.timing_results["sub_heartflow_step_time"] = step_duration
|
||||||
if len(past_mind) > 1:
|
if len(past_mind) > 1:
|
||||||
self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
|
self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
|
||||||
@@ -75,11 +71,8 @@ class InfoCatcher:
|
|||||||
else:
|
else:
|
||||||
self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
|
self.heartflow_data["sub_heartflow_before"] = past_mind[-1]
|
||||||
self.heartflow_data["sub_heartflow_now"] = current_mind
|
self.heartflow_data["sub_heartflow_now"] = current_mind
|
||||||
|
|
||||||
def catch_after_llm_generated(self,prompt:str,
|
def catch_after_llm_generated(self, prompt: str, response: str, reasoning_content: str = "", model_name: str = ""):
|
||||||
response:str,
|
|
||||||
reasoning_content:str = "",
|
|
||||||
model_name:str = ""):
|
|
||||||
if self.response_mode == "heart_flow":
|
if self.response_mode == "heart_flow":
|
||||||
self.heartflow_data["prompt"] = prompt
|
self.heartflow_data["prompt"] = prompt
|
||||||
self.heartflow_data["response"] = response
|
self.heartflow_data["response"] = response
|
||||||
@@ -89,41 +82,38 @@ class InfoCatcher:
|
|||||||
self.reasoning_data["prompt"] = prompt
|
self.reasoning_data["prompt"] = prompt
|
||||||
self.reasoning_data["response"] = response
|
self.reasoning_data["response"] = response
|
||||||
self.reasoning_data["model"] = model_name
|
self.reasoning_data["model"] = model_name
|
||||||
|
|
||||||
self.response_text = response
|
self.response_text = response
|
||||||
|
|
||||||
def catch_after_generate_response(self,response_duration:float):
|
def catch_after_generate_response(self, response_duration: float):
|
||||||
self.timing_results["make_response_time"] = response_duration
|
self.timing_results["make_response_time"] = response_duration
|
||||||
|
|
||||||
|
def catch_after_response(
|
||||||
|
self, response_duration: float, response_message: List[str], first_bot_msg: MessageSending
|
||||||
def catch_after_response(self,response_duration:float,
|
):
|
||||||
response_message:List[str],
|
|
||||||
first_bot_msg:MessageSending):
|
|
||||||
self.timing_results["make_response_time"] = response_duration
|
self.timing_results["make_response_time"] = response_duration
|
||||||
self.response_time = time.time()
|
self.response_time = time.time()
|
||||||
for msg in response_message:
|
for msg in response_message:
|
||||||
self.response_messages.append(msg)
|
self.response_messages.append(msg)
|
||||||
|
|
||||||
self.chat_history_in_thinking = self.get_message_from_db_between_msgs(self.trigger_response_message,first_bot_msg)
|
self.chat_history_in_thinking = self.get_message_from_db_between_msgs(
|
||||||
|
self.trigger_response_message, first_bot_msg
|
||||||
|
)
|
||||||
|
|
||||||
def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message):
|
def get_message_from_db_between_msgs(self, message_start: Message, message_end: Message):
|
||||||
try:
|
try:
|
||||||
# 从数据库中获取消息的时间戳
|
# 从数据库中获取消息的时间戳
|
||||||
time_start = message_start.message_info.time
|
time_start = message_start.message_info.time
|
||||||
time_end = message_end.message_info.time
|
time_end = message_end.message_info.time
|
||||||
chat_id = message_start.chat_stream.stream_id
|
chat_id = message_start.chat_stream.stream_id
|
||||||
|
|
||||||
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
print(f"查询参数: time_start={time_start}, time_end={time_end}, chat_id={chat_id}")
|
||||||
|
|
||||||
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
|
# 查询数据库,获取 chat_id 相同且时间在 start 和 end 之间的数据
|
||||||
messages_between = db.messages.find(
|
messages_between = db.messages.find(
|
||||||
{
|
{"chat_id": chat_id, "time": {"$gt": time_start, "$lt": time_end}}
|
||||||
"chat_id": chat_id,
|
|
||||||
"time": {"$gt": time_start, "$lt": time_end}
|
|
||||||
}
|
|
||||||
).sort("time", -1)
|
).sort("time", -1)
|
||||||
|
|
||||||
result = list(messages_between)
|
result = list(messages_between)
|
||||||
print(f"查询结果数量: {len(result)}")
|
print(f"查询结果数量: {len(result)}")
|
||||||
if result:
|
if result:
|
||||||
@@ -133,21 +123,23 @@ class InfoCatcher:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取消息时出错: {str(e)}")
|
print(f"获取消息时出错: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_message_from_db_before_msg(self, message: MessageRecv):
|
def get_message_from_db_before_msg(self, message: MessageRecv):
|
||||||
# 从数据库中获取消息
|
# 从数据库中获取消息
|
||||||
message_id = message.message_info.message_id
|
message_id = message.message_info.message_id
|
||||||
chat_id = message.chat_stream.stream_id
|
chat_id = message.chat_stream.stream_id
|
||||||
|
|
||||||
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
|
# 查询数据库,获取 chat_id 相同且 message_id 小于当前消息的 30 条数据
|
||||||
messages_before = db.messages.find(
|
messages_before = (
|
||||||
{"chat_id": chat_id, "message_id": {"$lt": message_id}}
|
db.messages.find({"chat_id": chat_id, "message_id": {"$lt": message_id}})
|
||||||
).sort("time", -1).limit(self.context_length*3) #获取更多历史信息
|
.sort("time", -1)
|
||||||
|
.limit(self.context_length * 3)
|
||||||
|
) # 获取更多历史信息
|
||||||
|
|
||||||
return list(messages_before)
|
return list(messages_before)
|
||||||
|
|
||||||
def message_list_to_dict(self, message_list):
|
def message_list_to_dict(self, message_list):
|
||||||
#存储简化的聊天记录
|
# 存储简化的聊天记录
|
||||||
result = []
|
result = []
|
||||||
for message in message_list:
|
for message in message_list:
|
||||||
if not isinstance(message, dict):
|
if not isinstance(message, dict):
|
||||||
@@ -160,7 +152,7 @@ class InfoCatcher:
|
|||||||
"processed_plain_text": message["processed_plain_text"],
|
"processed_plain_text": message["processed_plain_text"],
|
||||||
}
|
}
|
||||||
result.append(lite_message)
|
result.append(lite_message)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def message_to_dict(self, message):
|
def message_to_dict(self, message):
|
||||||
@@ -176,12 +168,12 @@ class InfoCatcher:
|
|||||||
"processed_plain_text": message.processed_plain_text,
|
"processed_plain_text": message.processed_plain_text,
|
||||||
# "detailed_plain_text": message.detailed_plain_text
|
# "detailed_plain_text": message.detailed_plain_text
|
||||||
}
|
}
|
||||||
|
|
||||||
def done_catch(self):
|
def done_catch(self):
|
||||||
"""将收集到的信息存储到数据库的 thinking_log 集合中"""
|
"""将收集到的信息存储到数据库的 thinking_log 集合中"""
|
||||||
try:
|
try:
|
||||||
# 将消息对象转换为可序列化的字典
|
# 将消息对象转换为可序列化的字典
|
||||||
|
|
||||||
thinking_log_data = {
|
thinking_log_data = {
|
||||||
"chat_id": self.chat_id,
|
"chat_id": self.chat_id,
|
||||||
"response_mode": self.response_mode,
|
"response_mode": self.response_mode,
|
||||||
@@ -198,7 +190,7 @@ class InfoCatcher:
|
|||||||
"timing_results": self.timing_results,
|
"timing_results": self.timing_results,
|
||||||
"chat_history": self.message_list_to_dict(self.chat_history),
|
"chat_history": self.message_list_to_dict(self.chat_history),
|
||||||
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
|
"chat_history_in_thinking": self.message_list_to_dict(self.chat_history_in_thinking),
|
||||||
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response)
|
"chat_history_after_response": self.message_list_to_dict(self.chat_history_after_response),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 根据不同的响应模式添加相应的数据
|
# 根据不同的响应模式添加相应的数据
|
||||||
@@ -209,20 +201,22 @@ class InfoCatcher:
|
|||||||
|
|
||||||
# 将数据插入到 thinking_log 集合中
|
# 将数据插入到 thinking_log 集合中
|
||||||
db.thinking_log.insert_one(thinking_log_data)
|
db.thinking_log.insert_one(thinking_log_data)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"存储思考日志时出错: {str(e)}")
|
print(f"存储思考日志时出错: {str(e)}")
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class InfoCatcherManager:
|
class InfoCatcherManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.info_catchers = {}
|
self.info_catchers = {}
|
||||||
|
|
||||||
def get_info_catcher(self,thinking_id:str) -> InfoCatcher:
|
def get_info_catcher(self, thinking_id: str) -> InfoCatcher:
|
||||||
if thinking_id not in self.info_catchers:
|
if thinking_id not in self.info_catchers:
|
||||||
self.info_catchers[thinking_id] = InfoCatcher()
|
self.info_catchers[thinking_id] = InfoCatcher()
|
||||||
return self.info_catchers[thinking_id]
|
return self.info_catchers[thinking_id]
|
||||||
|
|
||||||
info_catcher_manager = InfoCatcherManager()
|
|
||||||
|
info_catcher_manager = InfoCatcherManager()
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class ScheduleGenerator:
|
|||||||
# 使用离线LLM模型
|
# 使用离线LLM模型
|
||||||
self.llm_scheduler_all = LLM_request(
|
self.llm_scheduler_all = LLM_request(
|
||||||
model=global_config.llm_reasoning,
|
model=global_config.llm_reasoning,
|
||||||
temperature=global_config.SCHEDULE_TEMPERATURE+0.3,
|
temperature=global_config.SCHEDULE_TEMPERATURE + 0.3,
|
||||||
max_tokens=7000,
|
max_tokens=7000,
|
||||||
request_type="schedule",
|
request_type="schedule",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from src.common.logger import get_module_logger
|
|||||||
|
|
||||||
logger = get_module_logger("message_storage")
|
logger = get_module_logger("message_storage")
|
||||||
|
|
||||||
|
|
||||||
class MessageStorage:
|
class MessageStorage:
|
||||||
async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||||
"""存储消息到数据库"""
|
"""存储消息到数据库"""
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
# import re
|
|
||||||
import ast
|
|
||||||
from typing import Dict, Any, Optional, List, Union
|
from typing import Dict, Any, Optional, List, Union
|
||||||
|
import re
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from src.common.logger import get_module_logger
|
||||||
|
# import traceback
|
||||||
|
|
||||||
|
logger = get_module_logger("prompt_build")
|
||||||
|
|
||||||
|
|
||||||
class PromptContext:
|
class PromptContext:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -95,15 +98,13 @@ class Prompt(str):
|
|||||||
# 如果传入的是元组,转换为列表
|
# 如果传入的是元组,转换为列表
|
||||||
if isinstance(args, tuple):
|
if isinstance(args, tuple):
|
||||||
args = list(args)
|
args = list(args)
|
||||||
|
should_register = kwargs.pop("_should_register", True)
|
||||||
# 解析模板
|
# 解析模板
|
||||||
tree = ast.parse(f"f'''{fstr}'''", mode="eval")
|
template_args = []
|
||||||
template_args = set()
|
result = re.findall(r"\{(.*?)\}", fstr)
|
||||||
for node in ast.walk(tree):
|
for expr in result:
|
||||||
if isinstance(node, ast.FormattedValue):
|
if expr and expr not in template_args:
|
||||||
expr = ast.get_source_segment(fstr, node.value)
|
template_args.append(expr)
|
||||||
if expr:
|
|
||||||
template_args.add(expr)
|
|
||||||
|
|
||||||
# 如果提供了初始参数,立即格式化
|
# 如果提供了初始参数,立即格式化
|
||||||
if kwargs or args:
|
if kwargs or args:
|
||||||
@@ -119,17 +120,20 @@ class Prompt(str):
|
|||||||
obj._kwargs = kwargs
|
obj._kwargs = kwargs
|
||||||
|
|
||||||
# 修改自动注册逻辑
|
# 修改自动注册逻辑
|
||||||
if global_prompt_manager._context._current_context:
|
if should_register:
|
||||||
# 如果存在当前上下文,则注册到上下文中
|
if global_prompt_manager._context._current_context:
|
||||||
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
|
# 如果存在当前上下文,则注册到上下文中
|
||||||
pass
|
# asyncio.create_task(global_prompt_manager._context.register_async(obj))
|
||||||
else:
|
pass
|
||||||
# 否则注册到全局管理器
|
else:
|
||||||
global_prompt_manager.register(obj)
|
# 否则注册到全局管理器
|
||||||
|
global_prompt_manager.register(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_async(cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs):
|
async def create_async(
|
||||||
|
cls, fstr: str, name: Optional[str] = None, args: Union[List[Any], tuple[Any, ...]] = None, **kwargs
|
||||||
|
):
|
||||||
"""异步创建Prompt实例"""
|
"""异步创建Prompt实例"""
|
||||||
prompt = cls(fstr, name, args, **kwargs)
|
prompt = cls(fstr, name, args, **kwargs)
|
||||||
if global_prompt_manager._context._current_context:
|
if global_prompt_manager._context._current_context:
|
||||||
@@ -138,25 +142,29 @@ class Prompt(str):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
def _format_template(cls, template: str, args: List[Any] = None, kwargs: Dict[str, Any] = None) -> str:
|
||||||
fmt_str = f"f'''{template}'''"
|
|
||||||
tree = ast.parse(fmt_str, mode="eval")
|
|
||||||
template_args = []
|
template_args = []
|
||||||
for node in ast.walk(tree):
|
result = re.findall(r"\{(.*?)\}", template)
|
||||||
if isinstance(node, ast.FormattedValue):
|
for expr in result:
|
||||||
expr = ast.get_source_segment(fmt_str, node.value)
|
if expr and expr not in template_args:
|
||||||
if expr and expr not in template_args:
|
template_args.append(expr)
|
||||||
template_args.append(expr)
|
|
||||||
formatted_args = {}
|
formatted_args = {}
|
||||||
formatted_kwargs = {}
|
formatted_kwargs = {}
|
||||||
|
|
||||||
# 处理位置参数
|
# 处理位置参数
|
||||||
if args:
|
if args:
|
||||||
|
# print(len(template_args), len(args), template_args, args)
|
||||||
for i in range(len(args)):
|
for i in range(len(args)):
|
||||||
arg = args[i]
|
if i < len(template_args):
|
||||||
if isinstance(arg, Prompt):
|
arg = args[i]
|
||||||
formatted_args[template_args[i]] = arg.format(**kwargs)
|
if isinstance(arg, Prompt):
|
||||||
|
formatted_args[template_args[i]] = arg.format(**kwargs)
|
||||||
|
else:
|
||||||
|
formatted_args[template_args[i]] = arg
|
||||||
else:
|
else:
|
||||||
formatted_args[template_args[i]] = arg
|
logger.error(
|
||||||
|
f"构建提示词模板失败,解析到的参数列表{template_args},长度为{len(template_args)},输入的参数列表为{args},提示词模板为{template}"
|
||||||
|
)
|
||||||
|
raise ValueError("格式化模板失败")
|
||||||
|
|
||||||
# 处理关键字参数
|
# 处理关键字参数
|
||||||
if kwargs:
|
if kwargs:
|
||||||
@@ -177,15 +185,21 @@ class Prompt(str):
|
|||||||
template = template.format(**formatted_kwargs)
|
template = template.format(**formatted_kwargs)
|
||||||
return template
|
return template
|
||||||
except (IndexError, KeyError) as e:
|
except (IndexError, KeyError) as e:
|
||||||
raise ValueError(f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs}") from e
|
raise ValueError(
|
||||||
|
f"格式化模板失败: {template}, args={formatted_args}, kwargs={formatted_kwargs} {str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
def format(self, *args, **kwargs) -> "Prompt":
|
def format(self, *args, **kwargs) -> "str":
|
||||||
"""支持位置参数和关键字参数的格式化,使用"""
|
"""支持位置参数和关键字参数的格式化,使用"""
|
||||||
ret = type(self)(
|
ret = type(self)(
|
||||||
self.template, self.name, args=list(args) if args else self._args, **kwargs if kwargs else self._kwargs
|
self.template,
|
||||||
|
self.name,
|
||||||
|
args=list(args) if args else self._args,
|
||||||
|
_should_register=False,
|
||||||
|
**kwargs if kwargs else self._kwargs,
|
||||||
)
|
)
|
||||||
# print(f"prompt build result: {ret} name: {ret.name} ")
|
# print(f"prompt build result: {ret} name: {ret.name} ")
|
||||||
return ret
|
return str(ret)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
if self._kwargs or self._args:
|
if self._kwargs or self._args:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from .willing_manager import BaseWillingManager
|
from .willing_manager import BaseWillingManager
|
||||||
|
|
||||||
|
|
||||||
class ClassicalWillingManager(BaseWillingManager):
|
class ClassicalWillingManager(BaseWillingManager):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -41,17 +42,22 @@ class ClassicalWillingManager(BaseWillingManager):
|
|||||||
|
|
||||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||||
|
|
||||||
reply_probability = min(max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1)
|
reply_probability = min(
|
||||||
|
max((current_willing - 0.5), 0.01) * self.global_config.response_willing_amplifier * 2, 1
|
||||||
|
)
|
||||||
|
|
||||||
# 检查群组权限(如果是群聊)
|
# 检查群组权限(如果是群聊)
|
||||||
if willing_info.group_info and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups:
|
if (
|
||||||
|
willing_info.group_info
|
||||||
|
and willing_info.group_info.group_id in self.global_config.talk_frequency_down_groups
|
||||||
|
):
|
||||||
reply_probability = reply_probability / self.global_config.down_frequency_rate
|
reply_probability = reply_probability / self.global_config.down_frequency_rate
|
||||||
|
|
||||||
if is_emoji_not_reply:
|
if is_emoji_not_reply:
|
||||||
reply_probability = 0
|
reply_probability = 0
|
||||||
|
|
||||||
return reply_probability
|
return reply_probability
|
||||||
|
|
||||||
async def before_generate_reply_handle(self, message_id):
|
async def before_generate_reply_handle(self, message_id):
|
||||||
chat_id = self.ongoing_messages[message_id].chat_id
|
chat_id = self.ongoing_messages[message_id].chat_id
|
||||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||||
@@ -71,8 +77,6 @@ class ClassicalWillingManager(BaseWillingManager):
|
|||||||
|
|
||||||
async def get_variable_parameters(self):
|
async def get_variable_parameters(self):
|
||||||
return await super().get_variable_parameters()
|
return await super().get_variable_parameters()
|
||||||
|
|
||||||
async def set_variable_parameters(self, parameters):
|
async def set_variable_parameters(self, parameters):
|
||||||
return await super().set_variable_parameters(parameters)
|
return await super().set_variable_parameters(parameters)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,4 +4,3 @@ from .willing_manager import BaseWillingManager
|
|||||||
class CustomWillingManager(BaseWillingManager):
|
class CustomWillingManager(BaseWillingManager):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ class DynamicWillingManager(BaseWillingManager):
|
|||||||
self._decay_task = None
|
self._decay_task = None
|
||||||
self._mode_switch_task = None
|
self._mode_switch_task = None
|
||||||
|
|
||||||
|
|
||||||
async def async_task_starter(self):
|
async def async_task_starter(self):
|
||||||
if self._decay_task is None:
|
if self._decay_task is None:
|
||||||
self._decay_task = asyncio.create_task(self._decay_reply_willing())
|
self._decay_task = asyncio.create_task(self._decay_reply_willing())
|
||||||
@@ -84,7 +83,9 @@ class DynamicWillingManager(BaseWillingManager):
|
|||||||
self.chat_high_willing_mode[chat_id] = True
|
self.chat_high_willing_mode[chat_id] = True
|
||||||
self.chat_reply_willing[chat_id] = 1.0 # 设置为较高回复意愿
|
self.chat_reply_willing[chat_id] = 1.0 # 设置为较高回复意愿
|
||||||
self.chat_high_willing_duration[chat_id] = random.randint(180, 240) # 3-4分钟
|
self.chat_high_willing_duration[chat_id] = random.randint(180, 240) # 3-4分钟
|
||||||
self.logger.debug(f"聊天流 {chat_id} 切换到高回复意愿期,持续 {self.chat_high_willing_duration[chat_id]} 秒")
|
self.logger.debug(
|
||||||
|
f"聊天流 {chat_id} 切换到高回复意愿期,持续 {self.chat_high_willing_duration[chat_id]} 秒"
|
||||||
|
)
|
||||||
|
|
||||||
self.chat_last_mode_change[chat_id] = time.time()
|
self.chat_last_mode_change[chat_id] = time.time()
|
||||||
self.chat_msg_count[chat_id] = 0 # 重置消息计数
|
self.chat_msg_count[chat_id] = 0 # 重置消息计数
|
||||||
@@ -148,7 +149,9 @@ class DynamicWillingManager(BaseWillingManager):
|
|||||||
|
|
||||||
# 根据话题兴趣度适当调整
|
# 根据话题兴趣度适当调整
|
||||||
if willing_info.interested_rate > 0.5:
|
if willing_info.interested_rate > 0.5:
|
||||||
current_willing += (willing_info.interested_rate - 0.5) * 0.5 * self.global_config.response_interested_rate_amplifier
|
current_willing += (
|
||||||
|
(willing_info.interested_rate - 0.5) * 0.5 * self.global_config.response_interested_rate_amplifier
|
||||||
|
)
|
||||||
|
|
||||||
# 根据当前模式计算回复概率
|
# 根据当前模式计算回复概率
|
||||||
base_probability = 0.0
|
base_probability = 0.0
|
||||||
@@ -228,12 +231,12 @@ class DynamicWillingManager(BaseWillingManager):
|
|||||||
|
|
||||||
async def bombing_buffer_message_handle(self, message_id):
|
async def bombing_buffer_message_handle(self, message_id):
|
||||||
return await super().bombing_buffer_message_handle(message_id)
|
return await super().bombing_buffer_message_handle(message_id)
|
||||||
|
|
||||||
async def after_generate_reply_handle(self, message_id):
|
async def after_generate_reply_handle(self, message_id):
|
||||||
return await super().after_generate_reply_handle(message_id)
|
return await super().after_generate_reply_handle(message_id)
|
||||||
|
|
||||||
async def get_variable_parameters(self):
|
async def get_variable_parameters(self):
|
||||||
return await super().get_variable_parameters()
|
return await super().get_variable_parameters()
|
||||||
|
|
||||||
async def set_variable_parameters(self, parameters):
|
async def set_variable_parameters(self, parameters):
|
||||||
return await super().set_variable_parameters(parameters)
|
return await super().set_variable_parameters(parameters)
|
||||||
|
|||||||
@@ -17,19 +17,22 @@ Mxp 模式:梦溪畔独家赞助
|
|||||||
中策是发issue
|
中策是发issue
|
||||||
下下策是询问一个菜鸟(@梦溪畔)
|
下下策是询问一个菜鸟(@梦溪畔)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .willing_manager import BaseWillingManager
|
from .willing_manager import BaseWillingManager
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
||||||
class MxpWillingManager(BaseWillingManager):
|
class MxpWillingManager(BaseWillingManager):
|
||||||
"""Mxp意愿管理器"""
|
"""Mxp意愿管理器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.chat_person_reply_willing: Dict[str, Dict[str, float]] = {} # chat_id: {person_id: 意愿值}
|
self.chat_person_reply_willing: Dict[str, Dict[str, float]] = {} # chat_id: {person_id: 意愿值}
|
||||||
self.chat_new_message_time: Dict[str, list[float]] = {} # 聊天流ID: 消息时间
|
self.chat_new_message_time: Dict[str, list[float]] = {} # 聊天流ID: 消息时间
|
||||||
self.last_response_person: Dict[str, tuple[str, int]] = {} # 上次回复的用户信息
|
self.last_response_person: Dict[str, tuple[str, int]] = {} # 上次回复的用户信息
|
||||||
self.temporary_willing: float = 0 # 临时意愿值
|
self.temporary_willing: float = 0 # 临时意愿值
|
||||||
|
|
||||||
# 可变参数
|
# 可变参数
|
||||||
@@ -39,8 +42,8 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
self.basic_maximum_willing = 0.5 # 基础最大意愿值
|
self.basic_maximum_willing = 0.5 # 基础最大意愿值
|
||||||
self.mention_willing_gain = 0.6 # 提及意愿增益
|
self.mention_willing_gain = 0.6 # 提及意愿增益
|
||||||
self.interest_willing_gain = 0.3 # 兴趣意愿增益
|
self.interest_willing_gain = 0.3 # 兴趣意愿增益
|
||||||
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
|
self.emoji_response_penalty = self.global_config.emoji_response_penalty # 表情包回复惩罚
|
||||||
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
|
self.down_frequency_rate = self.global_config.down_frequency_rate # 降低回复频率的群组惩罚系数
|
||||||
self.single_chat_gain = 0.12 # 单聊增益
|
self.single_chat_gain = 0.12 # 单聊增益
|
||||||
|
|
||||||
async def async_task_starter(self) -> None:
|
async def async_task_starter(self) -> None:
|
||||||
@@ -73,9 +76,13 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
w_info = self.ongoing_messages[message_id]
|
w_info = self.ongoing_messages[message_id]
|
||||||
if w_info.is_mentioned_bot:
|
if w_info.is_mentioned_bot:
|
||||||
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += 0.2
|
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += 0.2
|
||||||
if w_info.chat_id in self.last_response_person and self.last_response_person[w_info.chat_id][0] == w_info.person_id:
|
if (
|
||||||
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] +=\
|
w_info.chat_id in self.last_response_person
|
||||||
self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1)
|
and self.last_response_person[w_info.chat_id][0] == w_info.person_id
|
||||||
|
):
|
||||||
|
self.chat_person_reply_willing[w_info.chat_id][w_info.person_id] += self.single_chat_gain * (
|
||||||
|
2 * self.last_response_person[w_info.chat_id][1] + 1
|
||||||
|
)
|
||||||
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0])
|
now_chat_new_person = self.last_response_person.get(w_info.chat_id, ["", 0])
|
||||||
if now_chat_new_person[0] != w_info.person_id:
|
if now_chat_new_person[0] != w_info.person_id:
|
||||||
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0]
|
self.last_response_person[w_info.chat_id] = [w_info.person_id, 0]
|
||||||
@@ -98,7 +105,10 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
rel_level = self._get_relationship_level_num(rel_value)
|
rel_level = self._get_relationship_level_num(rel_value)
|
||||||
current_willing += rel_level * 0.1
|
current_willing += rel_level * 0.1
|
||||||
|
|
||||||
if w_info.chat_id in self.last_response_person and self.last_response_person[w_info.chat_id][0] == w_info.person_id:
|
if (
|
||||||
|
w_info.chat_id in self.last_response_person
|
||||||
|
and self.last_response_person[w_info.chat_id][0] == w_info.person_id
|
||||||
|
):
|
||||||
current_willing += self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1)
|
current_willing += self.single_chat_gain * (2 * self.last_response_person[w_info.chat_id][1] + 1)
|
||||||
|
|
||||||
chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id]
|
chat_ongoing_messages = [msg for msg in self.ongoing_messages.values() if msg.chat_id == w_info.chat_id]
|
||||||
@@ -141,16 +151,22 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
self.logger.debug(f"聊天流{chat_id}不存在,错误")
|
self.logger.debug(f"聊天流{chat_id}不存在,错误")
|
||||||
continue
|
continue
|
||||||
basic_willing = self.chat_reply_willing[chat_id]
|
basic_willing = self.chat_reply_willing[chat_id]
|
||||||
person_willing[person_id] = basic_willing + (willing - basic_willing) * self.intention_decay_rate
|
person_willing[person_id] = (
|
||||||
|
basic_willing + (willing - basic_willing) * self.intention_decay_rate
|
||||||
|
)
|
||||||
|
|
||||||
def setup(self, message, chat, is_mentioned_bot, interested_rate):
|
def setup(self, message, chat, is_mentioned_bot, interested_rate):
|
||||||
super().setup(message, chat, is_mentioned_bot, interested_rate)
|
super().setup(message, chat, is_mentioned_bot, interested_rate)
|
||||||
|
|
||||||
self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get(chat.stream_id, self.basic_maximum_willing)
|
self.chat_reply_willing[chat.stream_id] = self.chat_reply_willing.get(
|
||||||
|
chat.stream_id, self.basic_maximum_willing
|
||||||
|
)
|
||||||
self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {})
|
self.chat_person_reply_willing[chat.stream_id] = self.chat_person_reply_willing.get(chat.stream_id, {})
|
||||||
self.chat_person_reply_willing[chat.stream_id][self.ongoing_messages[message.message_info.message_id].person_id] = \
|
self.chat_person_reply_willing[chat.stream_id][
|
||||||
self.chat_person_reply_willing[chat.stream_id].get(self.ongoing_messages[message.message_info.message_id].person_id,
|
self.ongoing_messages[message.message_info.message_id].person_id
|
||||||
self.chat_reply_willing[chat.stream_id])
|
] = self.chat_person_reply_willing[chat.stream_id].get(
|
||||||
|
self.ongoing_messages[message.message_info.message_id].person_id, self.chat_reply_willing[chat.stream_id]
|
||||||
|
)
|
||||||
|
|
||||||
if chat.stream_id not in self.chat_new_message_time:
|
if chat.stream_id not in self.chat_new_message_time:
|
||||||
self.chat_new_message_time[chat.stream_id] = []
|
self.chat_new_message_time[chat.stream_id] = []
|
||||||
@@ -166,7 +182,7 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
else:
|
else:
|
||||||
probability = math.atan(willing * 4) / math.pi * 2
|
probability = math.atan(willing * 4) / math.pi * 2
|
||||||
return probability
|
return probability
|
||||||
|
|
||||||
async def _chat_new_message_to_change_basic_willing(self):
|
async def _chat_new_message_to_change_basic_willing(self):
|
||||||
"""聊天流新消息改变基础意愿"""
|
"""聊天流新消息改变基础意愿"""
|
||||||
while True:
|
while True:
|
||||||
@@ -174,10 +190,11 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
await asyncio.sleep(update_time)
|
await asyncio.sleep(update_time)
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
for chat_id, message_times in self.chat_new_message_time.items():
|
for chat_id, message_times in self.chat_new_message_time.items():
|
||||||
|
|
||||||
# 清理过期消息
|
# 清理过期消息
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
message_times = [msg_time for msg_time in message_times if current_time - msg_time < self.message_expiration_time]
|
message_times = [
|
||||||
|
msg_time for msg_time in message_times if current_time - msg_time < self.message_expiration_time
|
||||||
|
]
|
||||||
self.chat_new_message_time[chat_id] = message_times
|
self.chat_new_message_time[chat_id] = message_times
|
||||||
|
|
||||||
if len(message_times) < self.number_of_message_storage:
|
if len(message_times) < self.number_of_message_storage:
|
||||||
@@ -185,7 +202,9 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
update_time = 20
|
update_time = 20
|
||||||
elif len(message_times) == self.number_of_message_storage:
|
elif len(message_times) == self.number_of_message_storage:
|
||||||
time_interval = current_time - message_times[0]
|
time_interval = current_time - message_times[0]
|
||||||
basic_willing = self.basic_maximum_willing * math.sqrt(time_interval / self.message_expiration_time)
|
basic_willing = self.basic_maximum_willing * math.sqrt(
|
||||||
|
time_interval / self.message_expiration_time
|
||||||
|
)
|
||||||
self.chat_reply_willing[chat_id] = basic_willing
|
self.chat_reply_willing[chat_id] = basic_willing
|
||||||
update_time = 17 * math.sqrt(time_interval / self.message_expiration_time) + 3
|
update_time = 17 * math.sqrt(time_interval / self.message_expiration_time) + 3
|
||||||
else:
|
else:
|
||||||
@@ -203,7 +222,7 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
"interest_willing_gain": "兴趣意愿增益",
|
"interest_willing_gain": "兴趣意愿增益",
|
||||||
"emoji_response_penalty": "表情包回复惩罚",
|
"emoji_response_penalty": "表情包回复惩罚",
|
||||||
"down_frequency_rate": "降低回复频率的群组惩罚系数",
|
"down_frequency_rate": "降低回复频率的群组惩罚系数",
|
||||||
"single_chat_gain": "单聊增益(不仅是私聊)"
|
"single_chat_gain": "单聊增益(不仅是私聊)",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def set_variable_parameters(self, parameters: Dict[str, any]):
|
async def set_variable_parameters(self, parameters: Dict[str, any]):
|
||||||
@@ -215,7 +234,7 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
self.logger.debug(f"参数 {key} 已更新为 {value}")
|
self.logger.debug(f"参数 {key} 已更新为 {value}")
|
||||||
else:
|
else:
|
||||||
self.logger.debug(f"尝试设置未知参数 {key}")
|
self.logger.debug(f"尝试设置未知参数 {key}")
|
||||||
|
|
||||||
def _get_relationship_level_num(self, relationship_value) -> int:
|
def _get_relationship_level_num(self, relationship_value) -> int:
|
||||||
"""关系等级计算"""
|
"""关系等级计算"""
|
||||||
if -1000 <= relationship_value < -227:
|
if -1000 <= relationship_value < -227:
|
||||||
@@ -235,4 +254,4 @@ class MxpWillingManager(BaseWillingManager):
|
|||||||
return level_num - 2
|
return level_num - 2
|
||||||
|
|
||||||
async def get_willing(self, chat_id):
|
async def get_willing(self, chat_id):
|
||||||
return self.temporary_willing
|
return self.temporary_willing
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
|
from src.common.logger import LogConfig, WILLING_STYLE_CONFIG, LoguruLogger, get_module_logger
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from ..config.config import global_config, BotConfig
|
from ..config.config import global_config, BotConfig
|
||||||
@@ -38,10 +37,11 @@ willing_config = LogConfig(
|
|||||||
)
|
)
|
||||||
logger = get_module_logger("willing", config=willing_config)
|
logger = get_module_logger("willing", config=willing_config)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WillingInfo:
|
class WillingInfo:
|
||||||
"""此类保存意愿模块常用的参数
|
"""此类保存意愿模块常用的参数
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
message (MessageRecv): 原始消息对象
|
message (MessageRecv): 原始消息对象
|
||||||
chat (ChatStream): 聊天流对象
|
chat (ChatStream): 聊天流对象
|
||||||
@@ -53,6 +53,7 @@ class WillingInfo:
|
|||||||
is_emoji (bool): 是否为表情包
|
is_emoji (bool): 是否为表情包
|
||||||
interested_rate (float): 兴趣度
|
interested_rate (float): 兴趣度
|
||||||
"""
|
"""
|
||||||
|
|
||||||
message: MessageRecv
|
message: MessageRecv
|
||||||
chat: ChatStream
|
chat: ChatStream
|
||||||
person_info_manager: PersonInfoManager
|
person_info_manager: PersonInfoManager
|
||||||
@@ -60,22 +61,21 @@ class WillingInfo:
|
|||||||
person_id: str
|
person_id: str
|
||||||
group_info: Optional[GroupInfo]
|
group_info: Optional[GroupInfo]
|
||||||
is_mentioned_bot: bool
|
is_mentioned_bot: bool
|
||||||
is_emoji: bool
|
is_emoji: bool
|
||||||
interested_rate: float
|
interested_rate: float
|
||||||
# current_mood: float 当前心情?
|
# current_mood: float 当前心情?
|
||||||
|
|
||||||
|
|
||||||
class BaseWillingManager(ABC):
|
class BaseWillingManager(ABC):
|
||||||
"""回复意愿管理基类"""
|
"""回复意愿管理基类"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, manager_type: str) -> 'BaseWillingManager':
|
def create(cls, manager_type: str) -> "BaseWillingManager":
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(f".mode_{manager_type}", __package__)
|
module = importlib.import_module(f".mode_{manager_type}", __package__)
|
||||||
manager_class = getattr(module, f"{manager_type.capitalize()}WillingManager")
|
manager_class = getattr(module, f"{manager_type.capitalize()}WillingManager")
|
||||||
if not issubclass(manager_class, cls):
|
if not issubclass(manager_class, cls):
|
||||||
raise TypeError(
|
raise TypeError(f"Manager class {manager_class.__name__} is not a subclass of {cls.__name__}")
|
||||||
f"Manager class {manager_class.__name__} is not a subclass of {cls.__name__}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"成功载入willing模式:{manager_type}")
|
logger.info(f"成功载入willing模式:{manager_type}")
|
||||||
return manager_class()
|
return manager_class()
|
||||||
@@ -85,7 +85,7 @@ class BaseWillingManager(ABC):
|
|||||||
logger.info(f"载入当前意愿模式{manager_type}失败,使用经典配方~~~~")
|
logger.info(f"载入当前意愿模式{manager_type}失败,使用经典配方~~~~")
|
||||||
logger.debug(f"加载willing模式{manager_type}失败,原因: {str(e)}。")
|
logger.debug(f"加载willing模式{manager_type}失败,原因: {str(e)}。")
|
||||||
return manager_class()
|
return manager_class()
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
|
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿(chat_id)
|
||||||
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
|
self.ongoing_messages: Dict[str, WillingInfo] = {} # 当前正在进行的消息(message_id)
|
||||||
@@ -136,17 +136,17 @@ class BaseWillingManager(ABC):
|
|||||||
async def get_reply_probability(self, message_id: str):
|
async def get_reply_probability(self, message_id: str):
|
||||||
"""抽象方法:获取回复概率"""
|
"""抽象方法:获取回复概率"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def bombing_buffer_message_handle(self, message_id: str):
|
async def bombing_buffer_message_handle(self, message_id: str):
|
||||||
"""抽象方法:炸飞消息处理"""
|
"""抽象方法:炸飞消息处理"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def get_willing(self, chat_id: str):
|
async def get_willing(self, chat_id: str):
|
||||||
"""获取指定聊天流的回复意愿"""
|
"""获取指定聊天流的回复意愿"""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
return self.chat_reply_willing.get(chat_id, 0)
|
return self.chat_reply_willing.get(chat_id, 0)
|
||||||
|
|
||||||
async def set_willing(self, chat_id: str, willing: float):
|
async def set_willing(self, chat_id: str, willing: float):
|
||||||
"""设置指定聊天流的回复意愿"""
|
"""设置指定聊天流的回复意愿"""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
@@ -173,5 +173,6 @@ def init_willing_manager() -> BaseWillingManager:
|
|||||||
mode = global_config.willing_mode.lower()
|
mode = global_config.willing_mode.lower()
|
||||||
return BaseWillingManager.create(mode)
|
return BaseWillingManager.create(mode)
|
||||||
|
|
||||||
|
|
||||||
# 全局willing_manager对象
|
# 全局willing_manager对象
|
||||||
willing_manager = init_willing_manager()
|
willing_manager = init_willing_manager()
|
||||||
|
|||||||
Reference in New Issue
Block a user