优化代码格式和异常处理
- 修复异常处理链,使用from语法保留原始异常 - 格式化代码以符合项目规范 - 优化导入模块的顺序 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -2,6 +2,7 @@ import sys
|
||||
import loguru
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogClassification(Enum):
|
||||
BASE = "base"
|
||||
MEMORY = "memory"
|
||||
@@ -9,14 +10,16 @@ class LogClassification(Enum):
|
||||
CHAT = "chat"
|
||||
PBUILDER = "promptbuilder"
|
||||
|
||||
|
||||
class LogModule:
|
||||
logger = loguru.logger.opt()
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def setup_logger(self, log_type: LogClassification):
|
||||
"""配置日志格式
|
||||
|
||||
|
||||
Args:
|
||||
log_type: 日志类型,可选值:BASE(基础日志)、MEMORY(记忆系统日志)、EMOJI(表情包系统日志)
|
||||
"""
|
||||
@@ -24,19 +27,33 @@ class LogModule:
|
||||
self.logger.remove()
|
||||
|
||||
# 基础日志格式
|
||||
base_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
|
||||
chat_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
|
||||
base_format = (
|
||||
"<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | "
|
||||
" d<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
chat_format = (
|
||||
"<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
# 记忆系统日志格式
|
||||
memory_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <light-magenta>海马体</light-magenta> | <level>{message}</level>"
|
||||
|
||||
memory_format = (
|
||||
"<green>{time:HH:mm}</green> | <level>{level: <8}</level> | "
|
||||
"<light-magenta>海马体</light-magenta> | <level>{message}</level>"
|
||||
)
|
||||
|
||||
# 表情包系统日志格式
|
||||
emoji_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>表情包</yellow> | <cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
|
||||
promptbuilder_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>Prompt</yellow> | <cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
|
||||
|
||||
emoji_format = (
|
||||
"<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>表情包</yellow> | "
|
||||
"<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
promptbuilder_format = (
|
||||
"<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>Prompt</yellow> | "
|
||||
"<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
# 根据日志类型选择日志格式和输出
|
||||
if log_type == LogClassification.CHAT:
|
||||
self.logger.add(
|
||||
@@ -51,38 +68,21 @@ class LogModule:
|
||||
# level="INFO"
|
||||
)
|
||||
elif log_type == LogClassification.MEMORY:
|
||||
|
||||
# 同时输出到控制台和文件
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=memory_format,
|
||||
# level="INFO"
|
||||
)
|
||||
self.logger.add(
|
||||
"logs/memory.log",
|
||||
format=memory_format,
|
||||
level="INFO",
|
||||
rotation="1 day",
|
||||
retention="7 days"
|
||||
)
|
||||
self.logger.add("logs/memory.log", format=memory_format, level="INFO", rotation="1 day", retention="7 days")
|
||||
elif log_type == LogClassification.EMOJI:
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=emoji_format,
|
||||
# level="INFO"
|
||||
)
|
||||
self.logger.add(
|
||||
"logs/emoji.log",
|
||||
format=emoji_format,
|
||||
level="INFO",
|
||||
rotation="1 day",
|
||||
retention="7 days"
|
||||
)
|
||||
self.logger.add("logs/emoji.log", format=emoji_format, level="INFO", rotation="1 day", retention="7 days")
|
||||
else: # BASE
|
||||
self.logger.add(
|
||||
sys.stderr,
|
||||
format=base_format,
|
||||
level="INFO"
|
||||
)
|
||||
|
||||
self.logger.add(sys.stderr, format=base_format, level="INFO")
|
||||
|
||||
return self.logger
|
||||
|
||||
@@ -9,17 +9,18 @@ from ...common.database import db
|
||||
|
||||
logger = get_module_logger("llm_statistics")
|
||||
|
||||
|
||||
class LLMStatistics:
|
||||
def __init__(self, output_file: str = "llm_statistics.txt"):
|
||||
"""初始化LLM统计类
|
||||
|
||||
|
||||
Args:
|
||||
output_file: 统计结果输出文件路径
|
||||
"""
|
||||
self.output_file = output_file
|
||||
self.running = False
|
||||
self.stats_thread = None
|
||||
|
||||
|
||||
def start(self):
|
||||
"""启动统计线程"""
|
||||
if not self.running:
|
||||
@@ -27,16 +28,16 @@ class LLMStatistics:
|
||||
self.stats_thread = threading.Thread(target=self._stats_loop)
|
||||
self.stats_thread.daemon = True
|
||||
self.stats_thread.start()
|
||||
|
||||
|
||||
def stop(self):
|
||||
"""停止统计线程"""
|
||||
self.running = False
|
||||
if self.stats_thread:
|
||||
self.stats_thread.join()
|
||||
|
||||
|
||||
def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]:
|
||||
"""收集指定时间段的LLM请求统计数据
|
||||
|
||||
|
||||
Args:
|
||||
start_time: 统计开始时间
|
||||
"""
|
||||
@@ -51,28 +52,26 @@ class LLMStatistics:
|
||||
"costs_by_user": defaultdict(float),
|
||||
"costs_by_type": defaultdict(float),
|
||||
"costs_by_model": defaultdict(float),
|
||||
#新增token统计字段
|
||||
# 新增token统计字段
|
||||
"tokens_by_type": defaultdict(int),
|
||||
"tokens_by_user": defaultdict(int),
|
||||
"tokens_by_model": defaultdict(int),
|
||||
}
|
||||
|
||||
cursor = db.llm_usage.find({
|
||||
"timestamp": {"$gte": start_time}
|
||||
})
|
||||
|
||||
|
||||
cursor = db.llm_usage.find({"timestamp": {"$gte": start_time}})
|
||||
|
||||
total_requests = 0
|
||||
|
||||
|
||||
for doc in cursor:
|
||||
stats["total_requests"] += 1
|
||||
request_type = doc.get("request_type", "unknown")
|
||||
user_id = str(doc.get("user_id", "unknown"))
|
||||
model_name = doc.get("model_name", "unknown")
|
||||
|
||||
|
||||
stats["requests_by_type"][request_type] += 1
|
||||
stats["requests_by_user"][user_id] += 1
|
||||
stats["requests_by_model"][model_name] += 1
|
||||
|
||||
|
||||
prompt_tokens = doc.get("prompt_tokens", 0)
|
||||
completion_tokens = doc.get("completion_tokens", 0)
|
||||
total_tokens = prompt_tokens + completion_tokens # 根据数据库字段调整
|
||||
@@ -80,112 +79,107 @@ class LLMStatistics:
|
||||
stats["tokens_by_user"][user_id] += total_tokens
|
||||
stats["tokens_by_model"][model_name] += total_tokens
|
||||
stats["total_tokens"] += total_tokens
|
||||
|
||||
|
||||
cost = doc.get("cost", 0.0)
|
||||
stats["total_cost"] += cost
|
||||
stats["costs_by_user"][user_id] += cost
|
||||
stats["costs_by_type"][request_type] += cost
|
||||
stats["costs_by_model"][model_name] += cost
|
||||
|
||||
|
||||
total_requests += 1
|
||||
|
||||
|
||||
if total_requests > 0:
|
||||
stats["average_tokens"] = stats["total_tokens"] / total_requests
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def _collect_all_statistics(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""收集所有时间范围的统计数据"""
|
||||
now = datetime.now()
|
||||
|
||||
|
||||
return {
|
||||
"all_time": self._collect_statistics_for_period(datetime.min),
|
||||
"last_7_days": self._collect_statistics_for_period(now - timedelta(days=7)),
|
||||
"last_24_hours": self._collect_statistics_for_period(now - timedelta(days=1)),
|
||||
"last_hour": self._collect_statistics_for_period(now - timedelta(hours=1))
|
||||
"last_hour": self._collect_statistics_for_period(now - timedelta(hours=1)),
|
||||
}
|
||||
|
||||
|
||||
def _format_stats_section(self, stats: Dict[str, Any], title: str) -> str:
|
||||
"""格式化统计部分的输出"""
|
||||
output = []
|
||||
|
||||
output.append("\n"+"-" * 84)
|
||||
output.append("\n" + "-" * 84)
|
||||
output.append(f"{title}")
|
||||
output.append("-" * 84)
|
||||
|
||||
|
||||
output.append(f"总请求数: {stats['total_requests']}")
|
||||
if stats['total_requests'] > 0:
|
||||
if stats["total_requests"] > 0:
|
||||
output.append(f"总Token数: {stats['total_tokens']}")
|
||||
output.append(f"总花费: {stats['total_cost']:.4f}¥\n")
|
||||
|
||||
|
||||
data_fmt = "{:<32} {:>10} {:>14} {:>13.4f} ¥"
|
||||
|
||||
|
||||
# 按模型统计
|
||||
output.append("按模型统计:")
|
||||
output.append(("模型名称 调用次数 Token总量 累计花费"))
|
||||
for model_name, count in sorted(stats["requests_by_model"].items()):
|
||||
tokens = stats["tokens_by_model"][model_name]
|
||||
cost = stats["costs_by_model"][model_name]
|
||||
output.append(data_fmt.format(
|
||||
model_name[:32] + ".." if len(model_name) > 32 else model_name,
|
||||
count,
|
||||
tokens,
|
||||
cost
|
||||
))
|
||||
output.append(
|
||||
data_fmt.format(model_name[:32] + ".." if len(model_name) > 32 else model_name, count, tokens, cost)
|
||||
)
|
||||
output.append("")
|
||||
|
||||
|
||||
# 按请求类型统计
|
||||
output.append("按请求类型统计:")
|
||||
output.append(("模型名称 调用次数 Token总量 累计花费"))
|
||||
for req_type, count in sorted(stats["requests_by_type"].items()):
|
||||
tokens = stats["tokens_by_type"][req_type]
|
||||
cost = stats["costs_by_type"][req_type]
|
||||
output.append(data_fmt.format(
|
||||
req_type[:22] + ".." if len(req_type) > 24 else req_type,
|
||||
count,
|
||||
tokens,
|
||||
cost
|
||||
))
|
||||
output.append(
|
||||
data_fmt.format(req_type[:22] + ".." if len(req_type) > 24 else req_type, count, tokens, cost)
|
||||
)
|
||||
output.append("")
|
||||
|
||||
|
||||
# 修正用户统计列宽
|
||||
output.append("按用户统计:")
|
||||
output.append(("模型名称 调用次数 Token总量 累计花费"))
|
||||
for user_id, count in sorted(stats["requests_by_user"].items()):
|
||||
tokens = stats["tokens_by_user"][user_id]
|
||||
cost = stats["costs_by_user"][user_id]
|
||||
output.append(data_fmt.format(
|
||||
user_id[:22], # 不再添加省略号,保持原始ID
|
||||
count,
|
||||
tokens,
|
||||
cost
|
||||
))
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
user_id[:22], # 不再添加省略号,保持原始ID
|
||||
count,
|
||||
tokens,
|
||||
cost,
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
def _save_statistics(self, all_stats: Dict[str, Dict[str, Any]]):
|
||||
"""将统计结果保存到文件"""
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
output = []
|
||||
output.append(f"LLM请求统计报告 (生成时间: {current_time})")
|
||||
|
||||
|
||||
# 添加各个时间段的统计
|
||||
sections = [
|
||||
("所有时间统计", "all_time"),
|
||||
("最近7天统计", "last_7_days"),
|
||||
("最近24小时统计", "last_24_hours"),
|
||||
("最近1小时统计", "last_hour")
|
||||
("最近1小时统计", "last_hour"),
|
||||
]
|
||||
|
||||
|
||||
for title, key in sections:
|
||||
output.append(self._format_stats_section(all_stats[key], title))
|
||||
|
||||
|
||||
# 写入文件
|
||||
with open(self.output_file, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output))
|
||||
|
||||
|
||||
def _stats_loop(self):
|
||||
"""统计循环,每1分钟运行一次"""
|
||||
while self.running:
|
||||
@@ -194,7 +188,7 @@ class LLMStatistics:
|
||||
self._save_statistics(all_stats)
|
||||
except Exception:
|
||||
logger.exception("统计数据处理失败")
|
||||
|
||||
|
||||
# 等待1分钟
|
||||
for _ in range(60):
|
||||
if not self.running:
|
||||
|
||||
@@ -17,16 +17,12 @@ from src.common.logger import get_module_logger
|
||||
|
||||
logger = get_module_logger("typo_gen")
|
||||
|
||||
|
||||
class ChineseTypoGenerator:
|
||||
def __init__(self,
|
||||
error_rate=0.3,
|
||||
min_freq=5,
|
||||
tone_error_rate=0.2,
|
||||
word_replace_rate=0.3,
|
||||
max_freq_diff=200):
|
||||
def __init__(self, error_rate=0.3, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3, max_freq_diff=200):
|
||||
"""
|
||||
初始化错别字生成器
|
||||
|
||||
|
||||
参数:
|
||||
error_rate: 单字替换概率
|
||||
min_freq: 最小字频阈值
|
||||
@@ -39,46 +35,46 @@ class ChineseTypoGenerator:
|
||||
self.tone_error_rate = tone_error_rate
|
||||
self.word_replace_rate = word_replace_rate
|
||||
self.max_freq_diff = max_freq_diff
|
||||
|
||||
|
||||
# 加载数据
|
||||
# print("正在加载汉字数据库,请稍候...")
|
||||
# logger.info("正在加载汉字数据库,请稍候...")
|
||||
|
||||
|
||||
self.pinyin_dict = self._create_pinyin_dict()
|
||||
self.char_frequency = self._load_or_create_char_frequency()
|
||||
|
||||
|
||||
def _load_or_create_char_frequency(self):
|
||||
"""
|
||||
加载或创建汉字频率字典
|
||||
"""
|
||||
cache_file = Path("char_frequency.json")
|
||||
|
||||
|
||||
# 如果缓存文件存在,直接加载
|
||||
if cache_file.exists():
|
||||
with open(cache_file, 'r', encoding='utf-8') as f:
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
# 使用内置的词频文件
|
||||
char_freq = defaultdict(int)
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
|
||||
# 读取jieba的词典文件
|
||||
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
word, freq = line.strip().split()[:2]
|
||||
# 对词中的每个字进行频率累加
|
||||
for char in word:
|
||||
if self._is_chinese_char(char):
|
||||
char_freq[char] += int(freq)
|
||||
|
||||
|
||||
# 归一化频率值
|
||||
max_freq = max(char_freq.values())
|
||||
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
|
||||
|
||||
normalized_freq = {char: freq / max_freq * 1000 for char, freq in char_freq.items()}
|
||||
|
||||
# 保存到缓存文件
|
||||
with open(cache_file, 'w', encoding='utf-8') as f:
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return normalized_freq
|
||||
|
||||
def _create_pinyin_dict(self):
|
||||
@@ -86,9 +82,9 @@ class ChineseTypoGenerator:
|
||||
创建拼音到汉字的映射字典
|
||||
"""
|
||||
# 常用汉字范围
|
||||
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
|
||||
chars = [chr(i) for i in range(0x4E00, 0x9FFF)]
|
||||
pinyin_dict = defaultdict(list)
|
||||
|
||||
|
||||
# 为每个汉字建立拼音映射
|
||||
for char in chars:
|
||||
try:
|
||||
@@ -96,7 +92,7 @@ class ChineseTypoGenerator:
|
||||
pinyin_dict[py].append(char)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
return pinyin_dict
|
||||
|
||||
def _is_chinese_char(self, char):
|
||||
@@ -104,8 +100,9 @@ class ChineseTypoGenerator:
|
||||
判断是否为汉字
|
||||
"""
|
||||
try:
|
||||
return '\u4e00' <= char <= '\u9fff'
|
||||
except:
|
||||
return "\u4e00" <= char <= "\u9fff"
|
||||
except Exception as e:
|
||||
logger.debug(e)
|
||||
return False
|
||||
|
||||
def _get_pinyin(self, sentence):
|
||||
@@ -114,7 +111,7 @@ class ChineseTypoGenerator:
|
||||
"""
|
||||
# 将句子拆分成单个字符
|
||||
characters = list(sentence)
|
||||
|
||||
|
||||
# 获取每个字符的拼音
|
||||
result = []
|
||||
for char in characters:
|
||||
@@ -124,7 +121,7 @@ class ChineseTypoGenerator:
|
||||
# 获取拼音(数字声调)
|
||||
py = pinyin(char, style=Style.TONE3)[0][0]
|
||||
result.append((char, py))
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def _get_similar_tone_pinyin(self, py):
|
||||
@@ -134,19 +131,19 @@ class ChineseTypoGenerator:
|
||||
# 检查拼音是否为空或无效
|
||||
if not py or len(py) < 1:
|
||||
return py
|
||||
|
||||
|
||||
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
|
||||
if not py[-1].isdigit():
|
||||
# 为非数字结尾的拼音添加数字声调1
|
||||
return py + '1'
|
||||
|
||||
return py + "1"
|
||||
|
||||
base = py[:-1] # 去掉声调
|
||||
tone = int(py[-1]) # 获取声调
|
||||
|
||||
|
||||
# 处理轻声(通常用5表示)或无效声调
|
||||
if tone not in [1, 2, 3, 4]:
|
||||
return base + str(random.choice([1, 2, 3, 4]))
|
||||
|
||||
|
||||
# 正常处理声调
|
||||
possible_tones = [1, 2, 3, 4]
|
||||
possible_tones.remove(tone) # 移除原声调
|
||||
@@ -159,11 +156,11 @@ class ChineseTypoGenerator:
|
||||
"""
|
||||
if target_freq > orig_freq:
|
||||
return 1.0 # 如果替换字频率更高,保持原有概率
|
||||
|
||||
|
||||
freq_diff = orig_freq - target_freq
|
||||
if freq_diff > self.max_freq_diff:
|
||||
return 0.0 # 频率差太大,不替换
|
||||
|
||||
|
||||
# 使用指数衰减函数计算概率
|
||||
# 频率差为0时概率为1,频率差为max_freq_diff时概率接近0
|
||||
return math.exp(-3 * freq_diff / self.max_freq_diff)
|
||||
@@ -173,42 +170,44 @@ class ChineseTypoGenerator:
|
||||
获取与给定字频率相近的同音字,可能包含声调错误
|
||||
"""
|
||||
homophones = []
|
||||
|
||||
|
||||
# 有一定概率使用错误声调
|
||||
if random.random() < self.tone_error_rate:
|
||||
wrong_tone_py = self._get_similar_tone_pinyin(py)
|
||||
homophones.extend(self.pinyin_dict[wrong_tone_py])
|
||||
|
||||
|
||||
# 添加正确声调的同音字
|
||||
homophones.extend(self.pinyin_dict[py])
|
||||
|
||||
|
||||
if not homophones:
|
||||
return None
|
||||
|
||||
|
||||
# 获取原字的频率
|
||||
orig_freq = self.char_frequency.get(char, 0)
|
||||
|
||||
|
||||
# 计算所有同音字与原字的频率差,并过滤掉低频字
|
||||
freq_diff = [(h, self.char_frequency.get(h, 0))
|
||||
for h in homophones
|
||||
if h != char and self.char_frequency.get(h, 0) >= self.min_freq]
|
||||
|
||||
freq_diff = [
|
||||
(h, self.char_frequency.get(h, 0))
|
||||
for h in homophones
|
||||
if h != char and self.char_frequency.get(h, 0) >= self.min_freq
|
||||
]
|
||||
|
||||
if not freq_diff:
|
||||
return None
|
||||
|
||||
|
||||
# 计算每个候选字的替换概率
|
||||
candidates_with_prob = []
|
||||
for h, freq in freq_diff:
|
||||
prob = self._calculate_replacement_probability(orig_freq, freq)
|
||||
if prob > 0: # 只保留有效概率的候选字
|
||||
candidates_with_prob.append((h, prob))
|
||||
|
||||
|
||||
if not candidates_with_prob:
|
||||
return None
|
||||
|
||||
|
||||
# 根据概率排序
|
||||
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
# 返回概率最高的几个字
|
||||
return [char for char, _ in candidates_with_prob[:num_candidates]]
|
||||
|
||||
@@ -230,10 +229,10 @@ class ChineseTypoGenerator:
|
||||
"""
|
||||
if len(word) == 1:
|
||||
return []
|
||||
|
||||
|
||||
# 获取词的拼音
|
||||
word_pinyin = self._get_word_pinyin(word)
|
||||
|
||||
|
||||
# 遍历所有可能的同音字组合
|
||||
candidates = []
|
||||
for py in word_pinyin:
|
||||
@@ -241,30 +240,31 @@ class ChineseTypoGenerator:
|
||||
if not chars:
|
||||
return []
|
||||
candidates.append(chars)
|
||||
|
||||
|
||||
# 生成所有可能的组合
|
||||
import itertools
|
||||
|
||||
all_combinations = itertools.product(*candidates)
|
||||
|
||||
|
||||
# 获取jieba词典和词频信息
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
valid_words = {} # 改用字典存储词语及其频率
|
||||
with open(dict_path, 'r', encoding='utf-8') as f:
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 2:
|
||||
word_text = parts[0]
|
||||
word_freq = float(parts[1]) # 获取词频
|
||||
valid_words[word_text] = word_freq
|
||||
|
||||
|
||||
# 获取原词的词频作为参考
|
||||
original_word_freq = valid_words.get(word, 0)
|
||||
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
|
||||
|
||||
|
||||
# 过滤和计算频率
|
||||
homophones = []
|
||||
for combo in all_combinations:
|
||||
new_word = ''.join(combo)
|
||||
new_word = "".join(combo)
|
||||
if new_word != word and new_word in valid_words:
|
||||
new_word_freq = valid_words[new_word]
|
||||
# 只保留词频达到阈值的词
|
||||
@@ -272,10 +272,10 @@ class ChineseTypoGenerator:
|
||||
# 计算词的平均字频(考虑字频和词频)
|
||||
char_avg_freq = sum(self.char_frequency.get(c, 0) for c in new_word) / len(new_word)
|
||||
# 综合评分:结合词频和字频
|
||||
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
|
||||
combined_score = new_word_freq * 0.7 + char_avg_freq * 0.3
|
||||
if combined_score >= self.min_freq:
|
||||
homophones.append((new_word, combined_score))
|
||||
|
||||
|
||||
# 按综合分数排序并限制返回数量
|
||||
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
|
||||
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
|
||||
@@ -283,10 +283,10 @@ class ChineseTypoGenerator:
|
||||
def create_typo_sentence(self, sentence):
|
||||
"""
|
||||
创建包含同音字错误的句子,支持词语级别和字级别的替换
|
||||
|
||||
|
||||
参数:
|
||||
sentence: 输入的中文句子
|
||||
|
||||
|
||||
返回:
|
||||
typo_sentence: 包含错别字的句子
|
||||
correction_suggestion: 随机选择的一个纠正建议,返回正确的字/词
|
||||
@@ -296,20 +296,20 @@ class ChineseTypoGenerator:
|
||||
word_typos = [] # 记录词语错误对(错词,正确词)
|
||||
char_typos = [] # 记录单字错误对(错字,正确字)
|
||||
current_pos = 0
|
||||
|
||||
|
||||
# 分词
|
||||
words = self._segment_sentence(sentence)
|
||||
|
||||
|
||||
for word in words:
|
||||
# 如果是标点符号或空格,直接添加
|
||||
if all(not self._is_chinese_char(c) for c in word):
|
||||
result.append(word)
|
||||
current_pos += len(word)
|
||||
continue
|
||||
|
||||
|
||||
# 获取词语的拼音
|
||||
word_pinyin = self._get_word_pinyin(word)
|
||||
|
||||
|
||||
# 尝试整词替换
|
||||
if len(word) > 1 and random.random() < self.word_replace_rate:
|
||||
word_homophones = self._get_word_homophones(word)
|
||||
@@ -318,17 +318,23 @@ class ChineseTypoGenerator:
|
||||
# 计算词的平均频率
|
||||
orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word)
|
||||
typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
|
||||
|
||||
|
||||
# 添加到结果中
|
||||
result.append(typo_word)
|
||||
typo_info.append((word, typo_word,
|
||||
' '.join(word_pinyin),
|
||||
' '.join(self._get_word_pinyin(typo_word)),
|
||||
orig_freq, typo_freq))
|
||||
typo_info.append(
|
||||
(
|
||||
word,
|
||||
typo_word,
|
||||
" ".join(word_pinyin),
|
||||
" ".join(self._get_word_pinyin(typo_word)),
|
||||
orig_freq,
|
||||
typo_freq,
|
||||
)
|
||||
)
|
||||
word_typos.append((typo_word, word)) # 记录(错词,正确词)对
|
||||
current_pos += len(typo_word)
|
||||
continue
|
||||
|
||||
|
||||
# 如果不进行整词替换,则进行单字替换
|
||||
if len(word) == 1:
|
||||
char = word
|
||||
@@ -352,11 +358,10 @@ class ChineseTypoGenerator:
|
||||
else:
|
||||
# 处理多字词的单字替换
|
||||
word_result = []
|
||||
word_start_pos = current_pos
|
||||
for i, (char, py) in enumerate(zip(word, word_pinyin)):
|
||||
for _, (char, py) in enumerate(zip(word, word_pinyin)):
|
||||
# 词中的字替换概率降低
|
||||
word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
|
||||
|
||||
|
||||
if random.random() < word_error_rate:
|
||||
similar_chars = self._get_similar_frequency_chars(char, py)
|
||||
if similar_chars:
|
||||
@@ -371,9 +376,9 @@ class ChineseTypoGenerator:
|
||||
char_typos.append((typo_char, char)) # 记录(错字,正确字)对
|
||||
continue
|
||||
word_result.append(char)
|
||||
result.append(''.join(word_result))
|
||||
result.append("".join(word_result))
|
||||
current_pos += len(word)
|
||||
|
||||
|
||||
# 优先从词语错误中选择,如果没有则从单字错误中选择
|
||||
correction_suggestion = None
|
||||
# 50%概率返回纠正建议
|
||||
@@ -384,41 +389,43 @@ class ChineseTypoGenerator:
|
||||
elif char_typos:
|
||||
wrong_char, correct_char = random.choice(char_typos)
|
||||
correction_suggestion = correct_char
|
||||
|
||||
return ''.join(result), correction_suggestion
|
||||
|
||||
return "".join(result), correction_suggestion
|
||||
|
||||
def format_typo_info(self, typo_info):
|
||||
"""
|
||||
格式化错别字信息
|
||||
|
||||
|
||||
参数:
|
||||
typo_info: 错别字信息列表
|
||||
|
||||
|
||||
返回:
|
||||
格式化后的错别字信息字符串
|
||||
"""
|
||||
if not typo_info:
|
||||
return "未生成错别字"
|
||||
|
||||
|
||||
result = []
|
||||
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
|
||||
# 判断是否为词语替换
|
||||
is_word = ' ' in orig_py
|
||||
is_word = " " in orig_py
|
||||
if is_word:
|
||||
error_type = "整词替换"
|
||||
else:
|
||||
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
|
||||
error_type = "声调错误" if tone_error else "同音字替换"
|
||||
|
||||
result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> "
|
||||
f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]")
|
||||
|
||||
|
||||
result.append(
|
||||
f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> "
|
||||
f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]"
|
||||
)
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def set_params(self, **kwargs):
|
||||
"""
|
||||
设置参数
|
||||
|
||||
|
||||
可设置参数:
|
||||
error_rate: 单字替换概率
|
||||
min_freq: 最小字频阈值
|
||||
@@ -433,35 +440,32 @@ class ChineseTypoGenerator:
|
||||
else:
|
||||
print(f"警告: 参数 {key} 不存在")
|
||||
|
||||
|
||||
def main():
|
||||
# 创建错别字生成器实例
|
||||
typo_generator = ChineseTypoGenerator(
|
||||
error_rate=0.03,
|
||||
min_freq=7,
|
||||
tone_error_rate=0.02,
|
||||
word_replace_rate=0.3
|
||||
)
|
||||
|
||||
typo_generator = ChineseTypoGenerator(error_rate=0.03, min_freq=7, tone_error_rate=0.02, word_replace_rate=0.3)
|
||||
|
||||
# 获取用户输入
|
||||
sentence = input("请输入中文句子:")
|
||||
|
||||
|
||||
# 创建包含错别字的句子
|
||||
start_time = time.time()
|
||||
typo_sentence, correction_suggestion = typo_generator.create_typo_sentence(sentence)
|
||||
|
||||
|
||||
# 打印结果
|
||||
print("\n原句:", sentence)
|
||||
print("错字版:", typo_sentence)
|
||||
|
||||
|
||||
# 打印纠正建议
|
||||
if correction_suggestion:
|
||||
print("\n随机纠正建议:")
|
||||
print(f"应该改为:{correction_suggestion}")
|
||||
|
||||
|
||||
# 计算并打印总耗时
|
||||
end_time = time.time()
|
||||
total_time = end_time - start_time
|
||||
print(f"\n总耗时:{total_time:.2f}秒")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user