feat: 知识库小重构
This commit is contained in:
0
src/chat/knowledge/utils/__init__.py
Normal file
0
src/chat/knowledge/utils/__init__.py
Normal file
47
src/chat/knowledge/utils/dyn_topk.py
Normal file
47
src/chat/knowledge/utils/dyn_topk.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import List, Any, Tuple
|
||||
|
||||
|
||||
def dyn_select_top_k(
|
||||
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
|
||||
) -> List[Tuple[Any, float, float]]:
|
||||
"""动态TopK选择"""
|
||||
# 按照分数排序(降序)
|
||||
sorted_score = sorted(score, key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 归一化
|
||||
max_score = sorted_score[0][1]
|
||||
min_score = sorted_score[-1][1]
|
||||
normalized_score = []
|
||||
for score_item in sorted_score:
|
||||
normalized_score.append(
|
||||
tuple(
|
||||
[
|
||||
score_item[0],
|
||||
score_item[1],
|
||||
(score_item[1] - min_score) / (max_score - min_score),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# 寻找跳变点:score变化最大的位置
|
||||
jump_idx = 0
|
||||
for i in range(1, len(normalized_score)):
|
||||
if abs(normalized_score[i][2] - normalized_score[i - 1][2]) > abs(
|
||||
normalized_score[jump_idx][2] - normalized_score[jump_idx - 1][2]
|
||||
):
|
||||
jump_idx = i
|
||||
# 跳变阈值
|
||||
jump_threshold = normalized_score[jump_idx][2]
|
||||
|
||||
# 计算均值
|
||||
mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score)
|
||||
# 计算方差
|
||||
var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(normalized_score)
|
||||
|
||||
# 动态阈值
|
||||
threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (mean_score + var_factor * var_score)
|
||||
|
||||
# 重新过滤
|
||||
res = [s for s in normalized_score if s[2] > threshold]
|
||||
|
||||
return res
|
||||
8
src/chat/knowledge/utils/hash.py
Normal file
8
src/chat/knowledge/utils/hash.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import hashlib
|
||||
|
||||
|
||||
def get_sha256(string: str) -> str:
|
||||
"""获取字符串的SHA256值"""
|
||||
sha256 = hashlib.sha256()
|
||||
sha256.update(string.encode("utf-8"))
|
||||
return sha256.hexdigest()
|
||||
98
src/chat/knowledge/utils/json_fix.py
Normal file
98
src/chat/knowledge/utils/json_fix.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import json
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
def _find_unclosed(json_str):
|
||||
"""
|
||||
Identifies the unclosed braces and brackets in the JSON string.
|
||||
|
||||
Args:
|
||||
json_str (str): The JSON string to analyze.
|
||||
|
||||
Returns:
|
||||
list: A list of unclosed elements in the order they were opened.
|
||||
"""
|
||||
unclosed = []
|
||||
inside_string = False
|
||||
escape_next = False
|
||||
|
||||
for char in json_str:
|
||||
if inside_string:
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
elif char == "\\":
|
||||
escape_next = True
|
||||
elif char == '"':
|
||||
inside_string = False
|
||||
else:
|
||||
if char == '"':
|
||||
inside_string = True
|
||||
elif char in "{[":
|
||||
unclosed.append(char)
|
||||
elif char in "}]":
|
||||
if unclosed and ((char == "}" and unclosed[-1] == "{") or (char == "]" and unclosed[-1] == "[")):
|
||||
unclosed.pop()
|
||||
|
||||
return unclosed
|
||||
|
||||
|
||||
# The following code is used to fix a broken JSON string.
|
||||
# From HippoRAG2 (GitHub: OSU-NLP-Group/HippoRAG)
|
||||
def fix_broken_generated_json(json_str: str) -> str:
|
||||
"""
|
||||
Fixes a malformed JSON string by:
|
||||
- Removing the last comma and any trailing content.
|
||||
- Iterating over the JSON string once to determine and fix unclosed braces or brackets.
|
||||
- Ensuring braces and brackets inside string literals are not considered.
|
||||
|
||||
If the original json_str string can be successfully loaded by json.loads(), will directly return it without any modification.
|
||||
|
||||
Args:
|
||||
json_str (str): The malformed JSON string to be fixed.
|
||||
|
||||
Returns:
|
||||
str: The corrected JSON string.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Try to load the JSON to see if it is valid
|
||||
json.loads(json_str)
|
||||
return json_str # Return as-is if valid
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Step 1: Remove trailing content after the last comma.
|
||||
last_comma_index = json_str.rfind(",")
|
||||
if last_comma_index != -1:
|
||||
json_str = json_str[:last_comma_index]
|
||||
|
||||
# Step 2: Identify unclosed braces and brackets.
|
||||
unclosed_elements = _find_unclosed(json_str)
|
||||
|
||||
# Step 3: Append the necessary closing elements in reverse order of opening.
|
||||
closing_map = {"{": "}", "[": "]"}
|
||||
for open_char in reversed(unclosed_elements):
|
||||
json_str += closing_map[open_char]
|
||||
|
||||
return json_str
|
||||
|
||||
|
||||
def new_fix_broken_generated_json(json_str: str) -> str:
|
||||
"""
|
||||
使用 json-repair 库修复格式错误的 JSON 字符串。
|
||||
|
||||
如果原始 json_str 字符串可以被 json.loads() 成功加载,则直接返回而不进行任何修改。
|
||||
|
||||
参数:
|
||||
json_str (str): 需要修复的格式错误的 JSON 字符串。
|
||||
|
||||
返回:
|
||||
str: 修复后的 JSON 字符串。
|
||||
"""
|
||||
try:
|
||||
# 尝试加载 JSON 以查看其是否有效
|
||||
json.loads(json_str)
|
||||
return json_str # 如果有效则按原样返回
|
||||
except json.JSONDecodeError:
|
||||
# 如果无效,则尝试修复它
|
||||
return repair_json(json_str)
|
||||
17
src/chat/knowledge/utils/visualize_graph.py
Normal file
17
src/chat/knowledge/utils/visualize_graph.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
def draw_graph_and_show(graph):
|
||||
"""绘制图并显示,画布大小1280*1280"""
|
||||
fig = plt.figure(1, figsize=(12.8, 12.8), dpi=100)
|
||||
nx.draw_networkx(
|
||||
graph,
|
||||
node_size=100,
|
||||
width=0.5,
|
||||
with_labels=True,
|
||||
labels=nx.get_node_attributes(graph, "content"),
|
||||
font_family="Sarasa Mono SC",
|
||||
font_size=8,
|
||||
)
|
||||
fig.show()
|
||||
Reference in New Issue
Block a user