Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -1,7 +1,58 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "MaiMaiBot"
|
name = "MaiBot"
|
||||||
version = "0.1.0"
|
version = "0.8.1"
|
||||||
description = "MaiMaiBot"
|
description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = [
|
||||||
|
"aiohttp>=3.12.14",
|
||||||
|
"apscheduler>=3.11.0",
|
||||||
|
"colorama>=0.4.6",
|
||||||
|
"cryptography>=45.0.5",
|
||||||
|
"customtkinter>=5.2.2",
|
||||||
|
"dotenv>=0.9.9",
|
||||||
|
"faiss-cpu>=1.11.0",
|
||||||
|
"fastapi>=0.116.0",
|
||||||
|
"jieba>=0.42.1",
|
||||||
|
"json-repair>=0.47.6",
|
||||||
|
"jsonlines>=4.0.0",
|
||||||
|
"maim-message>=0.3.8",
|
||||||
|
"matplotlib>=3.10.3",
|
||||||
|
"networkx>=3.4.2",
|
||||||
|
"numpy>=2.2.6",
|
||||||
|
"openai>=1.95.0",
|
||||||
|
"packaging>=25.0",
|
||||||
|
"pandas>=2.3.1",
|
||||||
|
"peewee>=3.18.2",
|
||||||
|
"pillow>=11.3.0",
|
||||||
|
"psutil>=7.0.0",
|
||||||
|
"pyarrow>=20.0.0",
|
||||||
|
"pydantic>=2.11.7",
|
||||||
|
"pymongo>=4.13.2",
|
||||||
|
"pypinyin>=0.54.0",
|
||||||
|
"python-dateutil>=2.9.0.post0",
|
||||||
|
"python-dotenv>=1.1.1",
|
||||||
|
"python-igraph>=0.11.9",
|
||||||
|
"quick-algo>=0.1.3",
|
||||||
|
"reportportal-client>=5.6.5",
|
||||||
|
"requests>=2.32.4",
|
||||||
|
"rich>=14.0.0",
|
||||||
|
"ruff>=0.12.2",
|
||||||
|
"scikit-learn>=1.7.0",
|
||||||
|
"scipy>=1.15.3",
|
||||||
|
"seaborn>=0.13.2",
|
||||||
|
"setuptools>=80.9.0",
|
||||||
|
"strawberry-graphql[fastapi]>=0.275.5",
|
||||||
|
"structlog>=25.4.0",
|
||||||
|
"toml>=0.10.2",
|
||||||
|
"tomli>=2.2.1",
|
||||||
|
"tomli-w>=1.2.0",
|
||||||
|
"tomlkit>=0.13.3",
|
||||||
|
"tqdm>=4.67.1",
|
||||||
|
"urllib3>=2.5.0",
|
||||||
|
"uvicorn>=0.35.0",
|
||||||
|
"websockets>=15.0.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
|
||||||
|
|||||||
271
requirements.lock
Normal file
271
requirements.lock
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
# This file was autogenerated by uv via the following command:
|
||||||
|
# uv pip compile requirements.txt -o requirements.lock
|
||||||
|
aenum==3.1.16
|
||||||
|
# via reportportal-client
|
||||||
|
aiohappyeyeballs==2.6.1
|
||||||
|
# via aiohttp
|
||||||
|
aiohttp==3.12.14
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# maim-message
|
||||||
|
# reportportal-client
|
||||||
|
aiosignal==1.4.0
|
||||||
|
# via aiohttp
|
||||||
|
annotated-types==0.7.0
|
||||||
|
# via pydantic
|
||||||
|
anyio==4.9.0
|
||||||
|
# via
|
||||||
|
# httpx
|
||||||
|
# openai
|
||||||
|
# starlette
|
||||||
|
apscheduler==3.11.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
attrs==25.3.0
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# jsonlines
|
||||||
|
certifi==2025.7.9
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# httpx
|
||||||
|
# reportportal-client
|
||||||
|
# requests
|
||||||
|
cffi==1.17.1
|
||||||
|
# via cryptography
|
||||||
|
charset-normalizer==3.4.2
|
||||||
|
# via requests
|
||||||
|
click==8.2.1
|
||||||
|
# via uvicorn
|
||||||
|
colorama==0.4.6
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# click
|
||||||
|
# tqdm
|
||||||
|
contourpy==1.3.2
|
||||||
|
# via matplotlib
|
||||||
|
cryptography==45.0.5
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# maim-message
|
||||||
|
customtkinter==5.2.2
|
||||||
|
# via -r requirements.txt
|
||||||
|
cycler==0.12.1
|
||||||
|
# via matplotlib
|
||||||
|
darkdetect==0.8.0
|
||||||
|
# via customtkinter
|
||||||
|
distro==1.9.0
|
||||||
|
# via openai
|
||||||
|
dnspython==2.7.0
|
||||||
|
# via pymongo
|
||||||
|
dotenv==0.9.9
|
||||||
|
# via -r requirements.txt
|
||||||
|
faiss-cpu==1.11.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
fastapi==0.116.0
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# maim-message
|
||||||
|
# strawberry-graphql
|
||||||
|
fonttools==4.58.5
|
||||||
|
# via matplotlib
|
||||||
|
frozenlist==1.7.0
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# aiosignal
|
||||||
|
graphql-core==3.2.6
|
||||||
|
# via strawberry-graphql
|
||||||
|
h11==0.16.0
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# uvicorn
|
||||||
|
httpcore==1.0.9
|
||||||
|
# via httpx
|
||||||
|
httpx==0.28.1
|
||||||
|
# via openai
|
||||||
|
idna==3.10
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# httpx
|
||||||
|
# requests
|
||||||
|
# yarl
|
||||||
|
igraph==0.11.9
|
||||||
|
# via python-igraph
|
||||||
|
jieba==0.42.1
|
||||||
|
# via -r requirements.txt
|
||||||
|
jiter==0.10.0
|
||||||
|
# via openai
|
||||||
|
joblib==1.5.1
|
||||||
|
# via scikit-learn
|
||||||
|
json-repair==0.47.6
|
||||||
|
# via -r requirements.txt
|
||||||
|
jsonlines==4.0.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
kiwisolver==1.4.8
|
||||||
|
# via matplotlib
|
||||||
|
maim-message==0.3.8
|
||||||
|
# via -r requirements.txt
|
||||||
|
markdown-it-py==3.0.0
|
||||||
|
# via rich
|
||||||
|
matplotlib==3.10.3
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# seaborn
|
||||||
|
mdurl==0.1.2
|
||||||
|
# via markdown-it-py
|
||||||
|
multidict==6.6.3
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
|
networkx==3.5
|
||||||
|
# via -r requirements.txt
|
||||||
|
numpy==2.3.1
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# contourpy
|
||||||
|
# faiss-cpu
|
||||||
|
# matplotlib
|
||||||
|
# pandas
|
||||||
|
# scikit-learn
|
||||||
|
# scipy
|
||||||
|
# seaborn
|
||||||
|
openai==1.95.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
packaging==25.0
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# customtkinter
|
||||||
|
# faiss-cpu
|
||||||
|
# matplotlib
|
||||||
|
# strawberry-graphql
|
||||||
|
pandas==2.3.1
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# seaborn
|
||||||
|
peewee==3.18.2
|
||||||
|
# via -r requirements.txt
|
||||||
|
pillow==11.3.0
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# matplotlib
|
||||||
|
propcache==0.3.2
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
|
psutil==7.0.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
pyarrow==20.0.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
pycparser==2.22
|
||||||
|
# via cffi
|
||||||
|
pydantic==2.11.7
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# fastapi
|
||||||
|
# maim-message
|
||||||
|
# openai
|
||||||
|
pydantic-core==2.33.2
|
||||||
|
# via pydantic
|
||||||
|
pygments==2.19.2
|
||||||
|
# via rich
|
||||||
|
pymongo==4.13.2
|
||||||
|
# via -r requirements.txt
|
||||||
|
pyparsing==3.2.3
|
||||||
|
# via matplotlib
|
||||||
|
pypinyin==0.54.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# matplotlib
|
||||||
|
# pandas
|
||||||
|
# strawberry-graphql
|
||||||
|
python-dotenv==1.1.1
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# dotenv
|
||||||
|
python-igraph==0.11.9
|
||||||
|
# via -r requirements.txt
|
||||||
|
python-multipart==0.0.20
|
||||||
|
# via strawberry-graphql
|
||||||
|
pytz==2025.2
|
||||||
|
# via pandas
|
||||||
|
quick-algo==0.1.3
|
||||||
|
# via -r requirements.txt
|
||||||
|
reportportal-client==5.6.5
|
||||||
|
# via -r requirements.txt
|
||||||
|
requests==2.32.4
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# reportportal-client
|
||||||
|
rich==14.0.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
ruff==0.12.2
|
||||||
|
# via -r requirements.txt
|
||||||
|
scikit-learn==1.7.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
scipy==1.16.0
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# scikit-learn
|
||||||
|
seaborn==0.13.2
|
||||||
|
# via -r requirements.txt
|
||||||
|
setuptools==80.9.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
six==1.17.0
|
||||||
|
# via python-dateutil
|
||||||
|
sniffio==1.3.1
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# openai
|
||||||
|
starlette==0.46.2
|
||||||
|
# via fastapi
|
||||||
|
strawberry-graphql==0.275.5
|
||||||
|
# via -r requirements.txt
|
||||||
|
structlog==25.4.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
texttable==1.7.0
|
||||||
|
# via igraph
|
||||||
|
threadpoolctl==3.6.0
|
||||||
|
# via scikit-learn
|
||||||
|
toml==0.10.2
|
||||||
|
# via -r requirements.txt
|
||||||
|
tomli==2.2.1
|
||||||
|
# via -r requirements.txt
|
||||||
|
tomli-w==1.2.0
|
||||||
|
# via -r requirements.txt
|
||||||
|
tomlkit==0.13.3
|
||||||
|
# via -r requirements.txt
|
||||||
|
tqdm==4.67.1
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# openai
|
||||||
|
typing-extensions==4.14.1
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# openai
|
||||||
|
# pydantic
|
||||||
|
# pydantic-core
|
||||||
|
# strawberry-graphql
|
||||||
|
# typing-inspection
|
||||||
|
typing-inspection==0.4.1
|
||||||
|
# via pydantic
|
||||||
|
tzdata==2025.2
|
||||||
|
# via
|
||||||
|
# pandas
|
||||||
|
# tzlocal
|
||||||
|
tzlocal==5.3.1
|
||||||
|
# via apscheduler
|
||||||
|
urllib3==2.5.0
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# requests
|
||||||
|
uvicorn==0.35.0
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# maim-message
|
||||||
|
websockets==15.0.1
|
||||||
|
# via
|
||||||
|
# -r requirements.txt
|
||||||
|
# maim-message
|
||||||
|
yarl==1.20.1
|
||||||
|
# via aiohttp
|
||||||
@@ -58,7 +58,9 @@ def hash_deduplicate(
|
|||||||
# 保存去重后的三元组
|
# 保存去重后的三元组
|
||||||
new_triple_list_data = {}
|
new_triple_list_data = {}
|
||||||
|
|
||||||
for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())):
|
for _, (raw_paragraph, triple_list) in enumerate(
|
||||||
|
zip(raw_paragraphs.values(), triple_list_data.values(), strict=False)
|
||||||
|
):
|
||||||
# 段落hash
|
# 段落hash
|
||||||
paragraph_hash = get_sha256(raw_paragraph)
|
paragraph_hash = get_sha256(raw_paragraph)
|
||||||
if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
|
|||||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||||
future_to_hash = {
|
future_to_hash = {
|
||||||
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
|
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
|
||||||
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas)
|
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
|
||||||
}
|
}
|
||||||
|
|
||||||
with Progress(
|
with Progress(
|
||||||
|
|||||||
@@ -354,7 +354,7 @@ class VirtualLogDisplay:
|
|||||||
|
|
||||||
# 为每个部分应用正确的标签
|
# 为每个部分应用正确的标签
|
||||||
current_len = 0
|
current_len = 0
|
||||||
for part, tag_name in zip(parts, tags):
|
for part, tag_name in zip(parts, tags, strict=False):
|
||||||
start_index = f"{start_pos}+{current_len}c"
|
start_index = f"{start_pos}+{current_len}c"
|
||||||
end_index = f"{start_pos}+{current_len + len(part)}c"
|
end_index = f"{start_pos}+{current_len + len(part)}c"
|
||||||
self.text_widget.tag_add(tag_name, start_index, end_index)
|
self.text_widget.tag_add(tag_name, start_index, end_index)
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class ExpressionLearner:
|
|||||||
min_len = min(len(s1), len(s2))
|
min_len = min(len(s1), len(s2))
|
||||||
if min_len < 5:
|
if min_len < 5:
|
||||||
return False
|
return False
|
||||||
same = sum(1 for a, b in zip(s1, s2) if a == b)
|
same = sum(1 for a, b in zip(s1, s2, strict=False) if a == b)
|
||||||
return same / min_len > 0.8
|
return same / min_len > 0.8
|
||||||
|
|
||||||
async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]:
|
async def learn_and_store_expression(self) -> List[Tuple[str, str, str]]:
|
||||||
|
|||||||
@@ -120,7 +120,6 @@ class HeartFCMessageReceiver:
|
|||||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
||||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||||
|
|
||||||
|
|
||||||
# 7. 日志记录
|
# 7. 日志记录
|
||||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||||
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ EMBEDDING_SIM_THRESHOLD = 0.99
|
|||||||
|
|
||||||
def cosine_similarity(a, b):
|
def cosine_similarity(a, b):
|
||||||
# 计算余弦相似度
|
# 计算余弦相似度
|
||||||
dot = sum(x * y for x, y in zip(a, b))
|
dot = sum(x * y for x, y in zip(a, b, strict=False))
|
||||||
norm_a = math.sqrt(sum(x * x for x in a))
|
norm_a = math.sqrt(sum(x * x for x in a))
|
||||||
norm_b = math.sqrt(sum(x * x for x in b))
|
norm_b = math.sqrt(sum(x * x for x in b))
|
||||||
if norm_a == 0 or norm_b == 0:
|
if norm_a == 0 or norm_b == 0:
|
||||||
@@ -285,7 +285,7 @@ class EmbeddingStore:
|
|||||||
distances = list(distances.flatten())
|
distances = list(distances.flatten())
|
||||||
result = [
|
result = [
|
||||||
(self.idx2hash[str(int(idx))], float(sim))
|
(self.idx2hash[str(int(idx))], float(sim))
|
||||||
for (idx, sim) in zip(indices, distances)
|
for (idx, sim) in zip(indices, distances, strict=False)
|
||||||
if idx in range(len(self.idx2hash))
|
if idx in range(len(self.idx2hash))
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -819,7 +819,7 @@ class EntorhinalCortex:
|
|||||||
timestamps = sample_scheduler.get_timestamp_array()
|
timestamps = sample_scheduler.get_timestamp_array()
|
||||||
# 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
|
# 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
|
||||||
readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
|
readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
|
||||||
for _, readable_timestamp in zip(timestamps, readable_timestamps):
|
for _, readable_timestamp in zip(timestamps, readable_timestamps, strict=False):
|
||||||
logger.debug(f"回忆往事: {readable_timestamp}")
|
logger.debug(f"回忆往事: {readable_timestamp}")
|
||||||
chat_samples = []
|
chat_samples = []
|
||||||
for timestamp in timestamps:
|
for timestamp in timestamps:
|
||||||
|
|||||||
@@ -299,7 +299,7 @@ class ActionModifier:
|
|||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
# 处理结果并更新缓存
|
# 处理结果并更新缓存
|
||||||
for _, (action_name, result) in enumerate(zip(task_names, task_results)):
|
for _, (action_name, result) in enumerate(zip(task_names, task_results, strict=False)):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
logger.error(f"{self.log_prefix}LLM判定action {action_name} 时出错: {result}")
|
||||||
results[action_name] = False
|
results[action_name] = False
|
||||||
|
|||||||
@@ -845,7 +845,7 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
|||||||
2. 不会重复选中同一个元素
|
2. 不会重复选中同一个元素
|
||||||
"""
|
"""
|
||||||
selected = []
|
selected = []
|
||||||
pool = list(zip(items, weights))
|
pool = list(zip(items, weights, strict=False))
|
||||||
for _ in range(min(k, len(pool))):
|
for _ in range(min(k, len(pool))):
|
||||||
total = sum(w for _, w in pool)
|
total = sum(w for _, w in pool)
|
||||||
r = random.uniform(0, total)
|
r = random.uniform(0, total)
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ class ChineseTypoGenerator:
|
|||||||
else:
|
else:
|
||||||
# 处理多字词的单字替换
|
# 处理多字词的单字替换
|
||||||
word_result = []
|
word_result = []
|
||||||
for _, (char, py) in enumerate(zip(word, word_pinyin)):
|
for _, (char, py) in enumerate(zip(word, word_pinyin, strict=False)):
|
||||||
# 词中的字替换概率降低
|
# 词中的字替换概率降低
|
||||||
word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
|
word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ class ConfigBase:
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
|
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
|
||||||
)
|
)
|
||||||
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args))
|
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
|
||||||
|
|
||||||
if field_origin_type is dict:
|
if field_origin_type is dict:
|
||||||
# 检查提供的value是否为dict
|
# 检查提供的value是否为dict
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
|||||||
2. 不会重复选中同一个元素
|
2. 不会重复选中同一个元素
|
||||||
"""
|
"""
|
||||||
selected = []
|
selected = []
|
||||||
pool = list(zip(items, weights))
|
pool = list(zip(items, weights, strict=False))
|
||||||
for _ in range(min(k, len(pool))):
|
for _ in range(min(k, len(pool))):
|
||||||
total = sum(w for _, w in pool)
|
total = sum(w for _, w in pool)
|
||||||
r = random.uniform(0, total)
|
r = random.uniform(0, total)
|
||||||
|
|||||||
@@ -142,7 +142,6 @@ class NoReplyAction(BaseAction):
|
|||||||
)
|
)
|
||||||
return False, f"不回复动作执行失败: {e}"
|
return False, f"不回复动作执行失败: {e}"
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def reset_consecutive_count(cls):
|
def reset_consecutive_count(cls):
|
||||||
"""重置连续计数器"""
|
"""重置连续计数器"""
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class SearchKnowledgeTool(BaseTool):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
||||||
"""计算两个向量之间的余弦相似度"""
|
"""计算两个向量之间的余弦相似度"""
|
||||||
dot_product = sum(p * q for p, q in zip(vec1, vec2))
|
dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False))
|
||||||
magnitude1 = math.sqrt(sum(p * p for p in vec1))
|
magnitude1 = math.sqrt(sum(p * p for p in vec1))
|
||||||
magnitude2 = math.sqrt(sum(q * q for q in vec2))
|
magnitude2 = math.sqrt(sum(q * q for q in vec2))
|
||||||
if magnitude1 == 0 or magnitude2 == 0:
|
if magnitude1 == 0 or magnitude2 == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user