feat: 保留原来的修复逻辑
This commit is contained in:
@@ -6,7 +6,7 @@ from .global_logger import logger
|
|||||||
from . import prompt_template
|
from . import prompt_template
|
||||||
from .lpmmconfig import global_config, INVALID_ENTITY
|
from .lpmmconfig import global_config, INVALID_ENTITY
|
||||||
from .llm_client import LLMClient
|
from .llm_client import LLMClient
|
||||||
from .utils.json_fix import fix_broken_generated_json
|
from .utils.json_fix import new_fix_broken_generated_json
|
||||||
|
|
||||||
|
|
||||||
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
||||||
@@ -24,7 +24,7 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
|||||||
if "]" in request_result:
|
if "]" in request_result:
|
||||||
request_result = request_result[: request_result.rindex("]") + 1]
|
request_result = request_result[: request_result.rindex("]") + 1]
|
||||||
|
|
||||||
entity_extract_result = json.loads(fix_broken_generated_json(request_result))
|
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
||||||
|
|
||||||
entity_extract_result = [
|
entity_extract_result = [
|
||||||
entity
|
entity
|
||||||
@@ -53,7 +53,7 @@ def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -
|
|||||||
if "]" in request_result:
|
if "]" in request_result:
|
||||||
request_result = request_result[: request_result.rindex("]") + 1]
|
request_result = request_result[: request_result.rindex("]") + 1]
|
||||||
|
|
||||||
entity_extract_result = json.loads(fix_broken_generated_json(request_result))
|
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
||||||
|
|
||||||
for triple in entity_extract_result:
|
for triple in entity_extract_result:
|
||||||
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||||
|
|||||||
@@ -1,10 +1,82 @@
|
|||||||
import json
|
import json
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
# 以下代码用于修复损坏的 JSON 字符串。
|
def _find_unclosed(json_str):
|
||||||
|
"""
|
||||||
|
Identifies the unclosed braces and brackets in the JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_str (str): The JSON string to analyze.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of unclosed elements in the order they were opened.
|
||||||
|
"""
|
||||||
|
unclosed = []
|
||||||
|
inside_string = False
|
||||||
|
escape_next = False
|
||||||
|
|
||||||
|
for char in json_str:
|
||||||
|
if inside_string:
|
||||||
|
if escape_next:
|
||||||
|
escape_next = False
|
||||||
|
elif char == "\\":
|
||||||
|
escape_next = True
|
||||||
|
elif char == '"':
|
||||||
|
inside_string = False
|
||||||
|
else:
|
||||||
|
if char == '"':
|
||||||
|
inside_string = True
|
||||||
|
elif char in "{[":
|
||||||
|
unclosed.append(char)
|
||||||
|
elif char in "}]":
|
||||||
|
if unclosed and ((char == "}" and unclosed[-1] == "{") or (char == "]" and unclosed[-1] == "[")):
|
||||||
|
unclosed.pop()
|
||||||
|
|
||||||
|
return unclosed
|
||||||
|
|
||||||
|
|
||||||
|
# The following code is used to fix a broken JSON string.
|
||||||
|
# From HippoRAG2 (GitHub: OSU-NLP-Group/HippoRAG)
|
||||||
def fix_broken_generated_json(json_str: str) -> str:
|
def fix_broken_generated_json(json_str: str) -> str:
|
||||||
|
"""
|
||||||
|
Fixes a malformed JSON string by:
|
||||||
|
- Removing the last comma and any trailing content.
|
||||||
|
- Iterating over the JSON string once to determine and fix unclosed braces or brackets.
|
||||||
|
- Ensuring braces and brackets inside string literals are not considered.
|
||||||
|
|
||||||
|
If the original json_str string can be successfully loaded by json.loads(), will directly return it without any modification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_str (str): The malformed JSON string to be fixed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The corrected JSON string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to load the JSON to see if it is valid
|
||||||
|
json.loads(json_str)
|
||||||
|
return json_str # Return as-is if valid
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Step 1: Remove trailing content after the last comma.
|
||||||
|
last_comma_index = json_str.rfind(",")
|
||||||
|
if last_comma_index != -1:
|
||||||
|
json_str = json_str[:last_comma_index]
|
||||||
|
|
||||||
|
# Step 2: Identify unclosed braces and brackets.
|
||||||
|
unclosed_elements = _find_unclosed(json_str)
|
||||||
|
|
||||||
|
# Step 3: Append the necessary closing elements in reverse order of opening.
|
||||||
|
closing_map = {"{": "}", "[": "]"}
|
||||||
|
for open_char in reversed(unclosed_elements):
|
||||||
|
json_str += closing_map[open_char]
|
||||||
|
|
||||||
|
return json_str
|
||||||
|
|
||||||
|
|
||||||
|
def new_fix_broken_generated_json(json_str: str) -> str:
|
||||||
"""
|
"""
|
||||||
使用 json-repair 库修复格式错误的 JSON 字符串。
|
使用 json-repair 库修复格式错误的 JSON 字符串。
|
||||||
|
|
||||||
@@ -22,4 +94,4 @@ def fix_broken_generated_json(json_str: str) -> str:
|
|||||||
return json_str # 如果有效则按原样返回
|
return json_str # 如果有效则按原样返回
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# 如果无效,则尝试修复它
|
# 如果无效,则尝试修复它
|
||||||
return repair_json(json_str)
|
return repair_json(json_str)
|
||||||
Reference in New Issue
Block a user