Merge remote-tracking branch 'upstream/debug' into feature

This commit is contained in:
tcmofashi
2025-03-04 08:18:22 +08:00
34 changed files with 13732 additions and 413 deletions

View File

@@ -17,12 +17,12 @@ driver = get_driver()
config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source= config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source= config.MONGODB_AUTH_SOURCE
)
print("\033[1;32m[初始化数据库完成]\033[0m")

View File

@@ -97,8 +97,13 @@ class ChatBot:
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text)
print(f"\033[1;32m[主题识别]\033[0m 使用jieba主题: {topic1}")
print(f"\033[1;32m[主题识别]\033[0m 使用llm主题: {topic2}")
print(f"\033[1;32m[主题识别]\033[0m 使用snownlp主题: {topic3}")
topic = topic3
all_num = 0
interested_num = 0
@@ -166,7 +171,6 @@ class ChatBot:
group_id=event.group_id,
user_id=global_config.BOT_QQ,
message_id=think_id,
message_based_id=event.message_id,
raw_message=msg,
plain_text=msg,
processed_plain_text=msg,

View File

@@ -116,6 +116,9 @@ class BotConfig:
if "vlm" in model_config:
config.vlm = model_config["vlm"]
if "embedding" in model_config:
config.embedding = model_config["embedding"]
# 消息配置
if "message" in toml_dict:
@@ -138,7 +141,7 @@ class BotConfig:
if "others" in toml_dict:
others_config = toml_dict["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
logger.success(f"成功加载配置文件: {config_path}")
@@ -152,31 +155,13 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
if not os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
logger.info("使用默认配置文件")
logger.info("使用bot配置文件")
else:
logger.info("已找到开发环境配置文件")
logger.info("已找到开发bot配置文件")
global_config = BotConfig.load_config(config_path=bot_config_path)
@dataclass
class LLMConfig:
"""机器人配置类"""
# 基础配置
SILICONFLOW_API_KEY: str = None
SILICONFLOW_BASE_URL: str = None
DEEP_SEEK_API_KEY: str = None
DEEP_SEEK_BASE_URL: str = None
llm_config = LLMConfig()
config = get_driver().config
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
if not global_config.enable_advance_output:
# logger.remove()
pass

View File

@@ -8,7 +8,7 @@ from ...common.database import Database
from PIL import Image
from .config import global_config
import urllib3
from .utils_user import get_user_nickname,get_user_cardname
from .utils_user import get_user_nickname,get_user_cardname,get_groupname
from .utils_cq import parse_cq_code
from .cq_code import cq_code_tool,CQCode
@@ -21,50 +21,47 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class Message:
"""消息数据类"""
message_id: int = None
time: float = None
group_id: int = None
group_name: str = None # 群名称
user_id: int = None
user_nickname: str = None # 用户昵称
user_cardname: str=None # 用户群昵称
group_name: str = None # 群名称
message_id: int = None
raw_message: str = None
plain_text: str = None
message_based_id: int = None
reply_message: Dict = None # 存储回复消息
raw_message: str = None # 原始消息包含未解析的cq码
plain_text: str = None # 纯文本
message_segments: List[Dict] = None # 存储解析后的消息片段
processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本
time: float = None
reply_message: Dict = None # 存储 回复的 源消息
is_emoji: bool = False # 是否是表情包
has_emoji: bool = False # 是否包含表情包
translate_cq: bool = True # 是否翻译cq码
reply_benefits: float = 0.0
type: str = 'received' # 消息类型可以是received或者send
def __post_init__(self):
if self.time is None:
self.time = int(time.time())
if not self.group_name:
self.group_name = get_groupname(self.group_id)
if not self.user_nickname:
self.user_nickname = get_user_nickname(self.user_id)
if not self.user_cardname:
self.user_cardname=get_user_cardname(self.user_id)
if not self.group_name:
self.group_name = self.get_groupname(self.group_id)
if not self.processed_plain_text:
if self.raw_message:
self.message_segments = self.parse_message_segments(str(self.raw_message))
@@ -244,6 +241,38 @@ class MessageSet:
return len(self.messages)
@dataclass
class Message_Sending(Message):
"""发送消息数据类继承自Message类"""
priority: int = 0 # 发送优先级,数字越大优先级越高
wait_until: float = None # 等待发送的时间戳
continue_thinking: bool = False # 是否继续思考
def __post_init__(self):
super().__post_init__()
if self.wait_until is None:
self.wait_until = self.time
@property
def can_send(self) -> bool:
"""检查是否可以发送消息"""
return time.time() >= self.wait_until
def set_wait_time(self, seconds: float) -> None:
"""设置等待发送时间"""
self.wait_until = time.time() + seconds
def set_priority(self, priority: int) -> None:
"""设置发送优先级"""
self.priority = priority
def __lt__(self, other):
"""重写小于比较,用于优先级排序"""
if not isinstance(other, Message_Sending):
return NotImplemented
return (self.priority, -self.wait_until) < (other.priority, -other.wait_until)

View File

@@ -201,7 +201,7 @@ class MessageSendControl:
print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}")
cost_time = round(time.time(), 2) - message.time
if cost_time > 40:
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_based_id) + message.processed_plain_text
message.processed_plain_text = cq_code_tool.create_reply_cq(message.message_id) + message.processed_plain_text
cur_time = time.time()
await self._current_bot.send_group_msg(
group_id=group_id,

View File

View File

@@ -127,15 +127,15 @@ class MessageStream:
# 从数据库中查询最近的消息
recent_messages = list(db.db.messages.find(
{"group_id": self.group_id},
{
"time": 1,
"user_id": 1,
"user_nickname": 1,
# "user_cardname": 1,
"message_id": 1,
"raw_message": 1,
"processed_text": 1
}
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# # "user_cardname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(count))
if not recent_messages:
@@ -145,17 +145,21 @@ class MessageStream:
from .message import Message
messages = []
for msg_data in recent_messages:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
user_cardname=msg_data.get("user_cardname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=self.group_id
)
messages.append(msg)
try:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
user_cardname=msg_data.get("user_cardname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=self.group_id
)
messages.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
continue
return list(reversed(messages)) # 返回按时间正序的消息

View File

@@ -118,7 +118,7 @@ class PromptBuilder:
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
if prompt_info:
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
promt_info_prompt = '你有一些[知识],在上面可以参考。'
# promt_info_prompt = '你有一些[知识],在上面可以参考。'
end_time = time.time()
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}")

View File

@@ -0,0 +1,14 @@
#Broca's Area
# 功能:语言产生、语法处理和言语运动控制。
# 损伤后果:布洛卡失语症(表达困难,但理解保留)。
import time
class Thinking_Idea:
def __init__(self, message_id: str):
self.messages = [] # 消息列表集合
self.current_thoughts = [] # 当前思考内容列表
self.time = time.time() # 创建时间
self.id = str(int(time.time() * 1000)) # 使用时间戳生成唯一标识ID

View File

@@ -4,6 +4,8 @@ from .message import Message
import jieba
from nonebot import get_driver
from .config import global_config
from snownlp import SnowNLP
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
@@ -11,12 +13,10 @@ config = driver.config
class TopicIdentifier:
def __init__(self):
self.client = OpenAI(
api_key=config.siliconflow_key, base_url=config.siliconflow_base_url
)
def identify_topic_llm(self, text: str) -> Optional[str]:
"""识别消息主题"""
self.llm_client = LLM_request(model=global_config.llm_normal)
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
"""识别消息主题,返回主题列表"""
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:\
1. 主题通常2-4个字必须简短要求精准概括不要太具体。\
@@ -24,36 +24,20 @@ class TopicIdentifier:
3. 这里是
消息内容:{text}"""
response = self.client.chat.completions.create(
model=global_config.SILICONFLOW_MODEL_V3,
messages=[{"role": "user", "content": prompt}],
temperature=0.8,
max_tokens=10,
)
if not response or not response.choices:
print(f"\033[1;31m[错误]\033[0m OpenAI API 返回为空")
# 使用 LLM_request 类进行请求
topic, _ = await self.llm_client.generate_response(prompt)
if not topic:
print(f"\033[1;31m[错误]\033[0m LLM API 返回为空")
return None
# 从 OpenAI API 响应中获取第一个选项的消息内容,并去除首尾空白字符
topic = (
response.choices[0].message.content.strip()
if response.choices[0].message.content
else None
)
if topic == "无主题":
return None
else:
# print(f"[主题分析结果]{text[:20]}... : {topic}")
split_topic = self.parse_topic(topic)
return split_topic
def parse_topic(self, topic: str) -> List[str]:
"""解析主题,返回主题列表"""
# 直接在这里处理主题解析
if not topic or topic == "无主题":
return []
return [t.strip() for t in topic.split(",") if t.strip()]
return None
# 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
return topic_list if topic_list else None
def identify_topic_jieba(self, text: str) -> Optional[str]:
"""使用jieba识别主题"""
@@ -239,33 +223,12 @@ class TopicIdentifier:
filtered_words = []
for word in words:
if word not in stop_words and not word.strip() in {
"",
"",
"",
"",
"",
"",
"",
'"',
'"',
""", """,
"",
"",
"",
"",
"",
"",
"",
"",
"·",
"",
"~",
"",
"+",
"=",
"-",
"[",
"]",
'', '', '', '', '', '', '', '"', '"', ''', ''',
'', '', '', '', '', '', '', '', '·', '', '~',
'', '+', '=', '-', '/', '\\', '|', '*', '#', '@', '$', '%',
'^', '&', '[', ']', '{', '}', '<', '>', '`', '_', '.', ',',
';', ':', '\'', '"', '(', ')', '?', '!', '±', '×', '÷', '',
'', '', '', '', '', '', '', '', '', '', ''
}:
filtered_words.append(word)
@@ -280,5 +243,25 @@ class TopicIdentifier:
return top_words if top_words else None
def identify_topic_snownlp(self, text: str) -> Optional[List[str]]:
"""使用 SnowNLP 进行主题识别
Args:
text (str): 需要识别主题的文本
Returns:
Optional[List[str]]: 返回识别出的主题关键词列表,如果无法识别则返回 None
"""
if not text or len(text.strip()) == 0:
return None
try:
s = SnowNLP(text)
# 提取前3个关键词作为主题
keywords = s.keywords(3)
return keywords if keywords else None
except Exception as e:
print(f"\033[1;31m[错误]\033[0m SnowNLP 处理失败: {str(e)}")
return None
topic_identifier = TopicIdentifier()

View File

@@ -10,6 +10,7 @@ from typing import Dict
from collections import Counter
import math
from nonebot import get_driver
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
@@ -64,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
return False
def get_embedding(text):
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {config.siliconflow_key}",
"Content-Type": "application/json"
}
response = requests.request("POST", url, json=payload, headers=headers)
if response.status_code != 200:
print(f"API请求失败: {response.status_code}")
print(f"错误信息: {response.text}")
return None
return response.json()['data'][0]['embedding']
"""获取文本的embedding向量"""
llm = LLM_request(model=global_config.embedding)
return llm.get_embedding_sync(text)
def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)
@@ -142,14 +127,14 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
{
"time": 1,
"user_id": 1,
"user_nickname": 1,
"message_id": 1,
"raw_message": 1,
"processed_text": 1
}
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(limit))
if not recent_messages:
@@ -159,16 +144,20 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
from .message import Message
message_objects = []
for msg_data in recent_messages:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
)
message_objects.append(msg)
try:
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
)
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
continue
# 按时间正序排列
message_objects.reverse()
@@ -181,7 +170,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
"time": 1, # 返回时间字段
"user_id": 1, # 返回用户ID字段
"user_nickname": 1, # 返回用户昵称字段
"user_cardname": 1, #返回用户群昵称
"message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1 # 返回处理后的文本字段
}
@@ -193,6 +181,8 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
message_detailed_plain_text = ''
message_detailed_plain_text_list = []
# 反转消息列表,使最新的消息在最后
recent_messages.reverse()
if combine:
for msg_db_data in recent_messages:

View File

@@ -6,8 +6,12 @@ def get_user_nickname(user_id: int) -> str:
return global_config.BOT_NICKNAME
# print(user_id)
return relationship_manager.get_name(user_id)
def get_user_cardname(user_id: int) -> str:
if int(user_id) == int(global_config.BOT_QQ):
return global_config.BOT_NICKNAME
# print(user_id)
return ''
return ''
def get_groupname(group_id: int) -> str:
return f"{group_id}"

View File

@@ -17,12 +17,12 @@ from src.plugins.chat.config import llm_config
# 直接配置数据库连接信息
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
class KnowledgeLibrary:

View File

@@ -2,7 +2,6 @@
import os
import sys
import jieba
from llm_module import LLMModel
import networkx as nx
import matplotlib.pyplot as plt
import math
@@ -10,10 +9,76 @@ from collections import Counter
import datetime
import random
import time
# from chat.config import global_config
from dotenv import load_dotenv
import sys
import asyncio
import aiohttp
from typing import Tuple
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
# 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)
class LLMModel:
def __init__(self, model_name=os.getenv("SILICONFLOW_MODEL_V3"), **kwargs):
self.model_name = model_name
self.params = kwargs
self.api_key = os.getenv("SILICONFLOW_KEY")
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求体
data = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
content = result["choices"][0]["message"]["content"]
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
class Memory_graph:
def __init__(self):
@@ -158,12 +223,12 @@ class Memory_graph:
def main():
# 初始化数据库
Database.initialize(
host= os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""),
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
)
memory_graph = Memory_graph()
@@ -185,11 +250,14 @@ def main():
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
# print(items_list)
for memory_item in items_list:
print(memory_item)
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
print("\n第一层记忆:")
for item in first_layer_items:
print(item)
print("\n第二层记忆:")
for item in second_layer_items:
print(item)
else:
print("未找到相关记忆。")

View File

@@ -66,7 +66,7 @@ class LLMModel:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""

View File

@@ -259,12 +259,12 @@ config = driver.config
start_time = time.time()
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= config.MONGODB_PORT,
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
#创建记忆图
memory_graph = Memory_graph()

View File

@@ -9,7 +9,7 @@ driver = get_driver()
config = driver.config
class LLM_request:
def __init__(self, model = global_config.llm_normal,**kwargs):
def __init__(self, model ,**kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = getattr(config, model["key"])
@@ -61,7 +61,7 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
@@ -126,7 +126,7 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
@@ -166,8 +166,8 @@ class LLM_request:
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
max_retries = 3
base_wait_time = 15
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
@@ -191,9 +191,119 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
"""同步方法获取文本的embedding向量
Args:
text: 需要获取embedding的文本
model: 使用的模型名称,默认为"BAAI/bge-m3"
Returns:
list: embedding向量如果失败则返回None
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": model,
"input": text,
"encoding_format": "float"
}
api_url = f"{self.base_url.rstrip('/')}/embeddings"
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data, timeout=30)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry)
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status()
result = response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
print(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
print(f"embedding请求失败: {str(e)}")
return None
print("达到最大重试次数embedding请求仍然失败")
return None
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
"""异步方法获取文本的embedding向量
Args:
text: 需要获取embedding的文本
model: 使用的模型名称,默认为"BAAI/bge-m3"
Returns:
list: embedding向量如果失败则返回None
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": model,
"input": text,
"encoding_format": "float"
}
api_url = f"{self.base_url.rstrip('/')}/embeddings"
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry)
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status()
result = await response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
print(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
print(f"embedding请求失败: {str(e)}")
return None
print("达到最大重试次数embedding请求仍然失败")
return None

View File

@@ -11,12 +11,12 @@ config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
class ScheduleGenerator:
@@ -128,6 +128,10 @@ class ScheduleGenerator:
def _time_diff(self, time1: str, time2: str) -> int:
"""计算两个时间字符串之间的分钟差"""
if time1=="24:00":
time1="23:59"
if time2=="24:00":
time2="23:59"
t1 = datetime.datetime.strptime(time1, "%H:%M")
t2 = datetime.datetime.strptime(time2, "%H:%M")
diff = int((t2 - t1).total_seconds() / 60)
@@ -165,4 +169,4 @@ class ScheduleGenerator:
# if __name__ == "__main__":
# main()
bot_schedule = ScheduleGenerator()
bot_schedule = ScheduleGenerator()

View File

@@ -0,0 +1,53 @@
from snownlp import SnowNLP
def analyze_emotion_snownlp(text):
"""
使用SnowNLP进行中文情感分析
:param text: 输入文本
:return: 情感得分(0-1之间越接近1越积极)
"""
try:
s = SnowNLP(text)
sentiment_score = s.sentiments
# 获取文本的关键词
keywords = s.keywords(3)
return {
'sentiment_score': sentiment_score,
'keywords': keywords,
'summary': s.summary(1) # 生成文本摘要
}
except Exception as e:
print(f"分析过程中出现错误: {str(e)}")
return None
def get_emotion_description_snownlp(score):
"""
将情感得分转换为描述性文字
"""
if score is None:
return "无法分析情感"
if score > 0.8:
return "非常积极"
elif score > 0.6:
return "较为积极"
elif score > 0.4:
return "中性偏积极"
elif score > 0.2:
return "中性偏消极"
else:
return "消极"
if __name__ == "__main__":
# 测试样例
test_text = "我们学校有免费的gpt4用"
result = analyze_emotion_snownlp(test_text)
if result:
print(f"测试文本: {test_text}")
print(f"情感得分: {result['sentiment_score']:.2f}")
print(f"情感倾向: {get_emotion_description_snownlp(result['sentiment_score'])}")
print(f"关键词: {', '.join(result['keywords'])}")
print(f"文本摘要: {result['summary'][0]}")

54
src/test/snownlp_demo.py Normal file
View File

@@ -0,0 +1,54 @@
from snownlp import SnowNLP
def demo_snownlp_features(text):
"""
展示SnowNLP的主要功能
:param text: 输入文本
"""
print(f"\n=== SnowNLP功能演示 ===")
print(f"输入文本: {text}")
# 创建SnowNLP对象
s = SnowNLP(text)
# 1. 分词
print(f"\n1. 分词结果:")
print(f" {' | '.join(s.words)}")
# 2. 情感分析
print(f"\n2. 情感分析:")
sentiment = s.sentiments
print(f" 情感得分: {sentiment:.2f}")
print(f" 情感倾向: {'积极' if sentiment > 0.5 else '消极' if sentiment < 0.5 else '中性'}")
# 3. 关键词提取
print(f"\n3. 关键词提取:")
print(f" {', '.join(s.keywords(3))}")
# 4. 词性标注
print(f"\n4. 词性标注:")
print(f" {' '.join([f'{word}/{tag}' for word, tag in s.tags])}")
# 5. 拼音转换
print(f"\n5. 拼音:")
print(f" {' '.join(s.pinyin)}")
# 6. 文本摘要
if len(text) > 100: # 只对较长文本生成摘要
print(f"\n6. 文本摘要:")
print(f" {' '.join(s.summary(3))}")
if __name__ == "__main__":
# 测试用例
test_texts = [
"这家新开的餐厅很不错,菜品种类丰富,味道可口,服务态度也很好,价格实惠,强烈推荐大家来尝试!",
"这部电影剧情混乱,演技浮夸,特效粗糙,配乐难听,完全浪费了我的时间和票价。",
"""人工智能正在改变我们的生活方式。它能够帮助我们完成复杂的计算任务,
提供个性化的服务推荐,优化交通路线,辅助医疗诊断。但同时我们也要警惕
人工智能带来的问题,比如隐私安全、就业变化等。如何正确认识和利用人工智能,
是我们每个人都需要思考的问题。"""
]
for text in test_texts:
demo_snownlp_features(text)
print("\n" + "="*50)

488
src/test/typo.py Normal file
View File

@@ -0,0 +1,488 @@
"""
错别字生成器 - 流程说明
整体替换逻辑:
1. 数据准备
- 加载字频词典使用jieba词典计算汉字使用频率
- 创建拼音映射:建立拼音到汉字的映射关系
- 加载词频信息从jieba词典获取词语使用频率
2. 分词处理
- 使用jieba将输入句子分词
- 区分单字词和多字词
- 保留标点符号和空格
3. 词语级别替换(针对多字词)
- 触发条件:词长>1 且 随机概率<0.3
- 替换流程:
a. 获取词语拼音
b. 生成所有可能的同音字组合
c. 过滤条件:
- 必须是jieba词典中的有效词
- 词频必须达到原词频的10%以上
- 综合评分(词频70%+字频30%)必须达到阈值
d. 按综合评分排序,选择最合适的替换词
4. 字级别替换(针对单字词或未进行整词替换的多字词)
- 单字替换概率0.3
- 多字词中的单字替换概率0.3 * (0.7 ^ (词长-1))
- 替换流程:
a. 获取字的拼音
b. 声调错误处理20%概率)
c. 获取同音字列表
d. 过滤条件:
- 字频必须达到最小阈值
- 频率差异不能过大(指数衰减计算)
e. 按频率排序选择替换字
5. 频率控制机制
- 字频控制使用归一化的字频0-1000范围
- 词频控制使用jieba词典中的词频
- 频率差异计算:使用指数衰减函数
- 最小频率阈值:确保替换字/词不会太生僻
6. 输出信息
- 原文和错字版本的对照
- 每个替换的详细信息(原字/词、替换后字/词、拼音、频率)
- 替换类型说明(整词替换/声调错误/同音字替换)
- 词语分析和完整拼音
注意事项:
1. 所有替换都必须使用有意义的词语
2. 替换词的使用频率不能过低
3. 多字词优先考虑整词替换
4. 考虑声调变化的情况
5. 保持标点符号和空格不变
"""
from pypinyin import pinyin, Style
from collections import defaultdict
import json
import os
import unicodedata
import jieba
import jieba.posseg as pseg
from pathlib import Path
import random
import math
import time
def load_or_create_char_frequency():
"""
加载或创建汉字频率字典
"""
cache_file = Path("char_frequency.json")
# 如果缓存文件存在,直接加载
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f)
# 使用内置的词频文件
char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# 读取jieba的词典文件
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
word, freq = line.strip().split()[:2]
# 对词中的每个字进行频率累加
for char in word:
if is_chinese_char(char):
char_freq[char] += int(freq)
# 归一化频率值
max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
# 保存到缓存文件
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq
# 创建拼音到汉字的映射字典
def create_pinyin_dict():
"""
创建拼音到汉字的映射字典
"""
# 常用汉字范围
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list)
# 为每个汉字建立拼音映射
for char in chars:
try:
py = pinyin(char, style=Style.TONE3)[0][0]
pinyin_dict[py].append(char)
except Exception:
continue
return pinyin_dict
def is_chinese_char(char):
"""
判断是否为汉字
"""
try:
return '\u4e00' <= char <= '\u9fff'
except:
return False
def get_pinyin(sentence):
"""
将中文句子拆分成单个汉字并获取其拼音
:param sentence: 输入的中文句子
:return: 每个汉字及其拼音的列表
"""
# 将句子拆分成单个字符
characters = list(sentence)
# 获取每个字符的拼音
result = []
for char in characters:
# 跳过空格和非汉字字符
if char.isspace() or not is_chinese_char(char):
continue
# 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py))
return result
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
"""
获取同音字,按照使用频率排序
"""
homophones = pinyin_dict[py]
# 移除原字并过滤低频字
if char in homophones:
homophones.remove(char)
# 过滤掉低频字
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
# 按照字频排序
sorted_homophones = sorted(homophones,
key=lambda x: char_frequency.get(x, 0),
reverse=True)
# 只返回前10个同音字避免输出过多
return sorted_homophones[:10]
def get_similar_tone_pinyin(py):
"""
获取相似声调的拼音
例如:'ni3' 可能返回 'ni2''ni4'
处理特殊情况:
1. 轻声(如 'de5''le'
2. 非数字结尾的拼音
"""
# 检查拼音是否为空或无效
if not py or len(py) < 1:
return py
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1
return py + '1'
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
# 处理轻声通常用5表示或无效声调
if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4]))
# 正常处理声调
possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调
new_tone = random.choice(possible_tones) # 随机选择一个新声调
return base + str(new_tone)
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
"""
根据频率差计算替换概率
频率差越大,概率越低
:param orig_freq: 原字频率
:param target_freq: 目标字频率
:param max_freq_diff: 最大允许的频率差
:return: 0-1之间的概率值
"""
if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq
if freq_diff > max_freq_diff:
return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / max_freq_diff)
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
"""
获取与给定字频率相近的同音字,可能包含声调错误
"""
homophones = []
# 有20%的概率使用错误声调
if random.random() < tone_error_rate:
wrong_tone_py = get_similar_tone_pinyin(py)
homophones.extend(pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字
homophones.extend(pinyin_dict[py])
if not homophones:
return None
# 获取原字的频率
orig_freq = char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [(h, char_frequency.get(h, 0))
for h in homophones
if h != char and char_frequency.get(h, 0) >= min_freq]
if not freq_diff:
return None
# 计算每个候选字的替换概率
candidates_with_prob = []
for h, freq in freq_diff:
prob = calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob))
if not candidates_with_prob:
return None
# 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]]
def get_word_pinyin(word):
"""
获取词语的拼音列表
"""
return [py[0] for py in pinyin(word, style=Style.TONE3)]
def segment_sentence(sentence):
"""
使用jieba分词返回词语列表
"""
return list(jieba.cut(sentence))
def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5):
"""
获取整个词的同音词,只返回高频的有意义词语
:param word: 输入词语
:param pinyin_dict: 拼音字典
:param char_frequency: 字频字典
:param min_freq: 最小频率阈值
:return: 同音词列表
"""
if len(word) == 1:
return []
# 获取词的拼音
word_pinyin = get_word_pinyin(word)
word_pinyin_str = ''.join(word_pinyin)
# 创建词语频率字典
word_freq = defaultdict(float)
# 遍历所有可能的同音字组合
candidates = []
for py in word_pinyin:
chars = pinyin_dict.get(py, [])
if not chars:
return []
candidates.append(chars)
# 生成所有可能的组合
import itertools
all_combinations = itertools.product(*candidates)
# 获取jieba词典和词频信息
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
valid_words = {} # 改用字典存储词语及其频率
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
word_text = parts[0]
word_freq = float(parts[1]) # 获取词频
valid_words[word_text] = word_freq
# 获取原词的词频作为参考
original_word_freq = valid_words.get(word, 0)
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
# 过滤和计算频率
homophones = []
for combo in all_combinations:
new_word = ''.join(combo)
if new_word != word and new_word in valid_words:
new_word_freq = valid_words[new_word]
# 只保留词频达到阈值的词
if new_word_freq >= min_word_freq:
# 计算词的平均字频(考虑字频和词频)
char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word)
# 综合评分:结合词频和字频
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
if combined_score >= min_freq:
homophones.append((new_word, combined_score))
# 按综合分数排序并限制返回数量
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3):
"""
创建包含同音字错误的句子,支持词语级别和字级别的替换
只使用高频的有意义词语进行替换
"""
result = []
typo_info = []
# 分词
words = segment_sentence(sentence)
for word in words:
# 如果是标点符号或空格,直接添加
if all(not is_chinese_char(c) for c in word):
result.append(word)
continue
# 获取词语的拼音
word_pinyin = get_word_pinyin(word)
# 尝试整词替换
if len(word) > 1 and random.random() < word_replace_rate:
word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq)
if word_homophones:
typo_word = random.choice(word_homophones)
# 计算词的平均频率
orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word)
typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
# 添加到结果中
result.append(typo_word)
typo_info.append((word, typo_word,
' '.join(word_pinyin),
' '.join(get_word_pinyin(typo_word)),
orig_freq, typo_freq))
continue
# 如果不进行整词替换,则进行单字替换
if len(word) == 1:
char = word
py = word_pinyin[0]
if random.random() < error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
result.append(char)
else:
# 处理多字词的单字替换
word_result = []
for i, (char, py) in enumerate(zip(word, word_pinyin)):
# 词中的字替换概率降低
word_error_rate = error_rate * (0.7 ** (len(word) - 1))
if random.random() < word_error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
word_result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
word_result.append(char)
result.append(''.join(word_result))
return ''.join(result), typo_info
def format_frequency(freq):
"""
格式化频率显示
"""
return f"{freq:.2f}"
def main():
# 记录开始时间
start_time = time.time()
# 首先创建拼音字典和加载字频统计
print("正在加载汉字数据库,请稍候...")
pinyin_dict = create_pinyin_dict()
char_frequency = load_or_create_char_frequency()
# 获取用户输入
sentence = input("请输入中文句子:")
# 创建包含错别字的句子
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
error_rate=0.3, min_freq=5,
tone_error_rate=0.2, word_replace_rate=0.3)
# 打印结果
print("\n原句:", sentence)
print("错字版:", typo_sentence)
if typo_info:
print("\n错别字信息:")
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
# 判断是否为词语替换
is_word = ' ' in orig_py
if is_word:
error_type = "整词替换"
else:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换"
print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
# 获取拼音结果
result = get_pinyin(sentence)
# 打印完整拼音
print("\n完整拼音:")
print(" ".join(py for _, py in result))
# 打印词语分析
print("\n词语分析:")
words = segment_sentence(sentence)
for word in words:
if any(is_chinese_char(c) for c in word):
word_pinyin = get_word_pinyin(word)
print(f"词语:{word}")
print(f"拼音:{' '.join(word_pinyin)}")
print("---")
# 计算并打印总耗时
end_time = time.time()
total_time = end_time - start_time
print(f"\n总耗时:{total_time:.2f}")
if __name__ == "__main__":
main()

488
src/test/typo_creator.py Normal file
View File

@@ -0,0 +1,488 @@
"""
错别字生成器 - 流程说明
整体替换逻辑:
1. 数据准备
- 加载字频词典使用jieba词典计算汉字使用频率
- 创建拼音映射:建立拼音到汉字的映射关系
- 加载词频信息从jieba词典获取词语使用频率
2. 分词处理
- 使用jieba将输入句子分词
- 区分单字词和多字词
- 保留标点符号和空格
3. 词语级别替换(针对多字词)
- 触发条件:词长>1 且 随机概率<0.3
- 替换流程:
a. 获取词语拼音
b. 生成所有可能的同音字组合
c. 过滤条件:
- 必须是jieba词典中的有效词
- 词频必须达到原词频的10%以上
- 综合评分(词频70%+字频30%)必须达到阈值
d. 按综合评分排序,选择最合适的替换词
4. 字级别替换(针对单字词或未进行整词替换的多字词)
- 单字替换概率0.3
- 多字词中的单字替换概率0.3 * (0.7 ^ (词长-1))
- 替换流程:
a. 获取字的拼音
b. 声调错误处理20%概率)
c. 获取同音字列表
d. 过滤条件:
- 字频必须达到最小阈值
- 频率差异不能过大(指数衰减计算)
e. 按频率排序选择替换字
5. 频率控制机制
- 字频控制使用归一化的字频0-1000范围
- 词频控制使用jieba词典中的词频
- 频率差异计算:使用指数衰减函数
- 最小频率阈值:确保替换字/词不会太生僻
6. 输出信息
- 原文和错字版本的对照
- 每个替换的详细信息(原字/词、替换后字/词、拼音、频率)
- 替换类型说明(整词替换/声调错误/同音字替换)
- 词语分析和完整拼音
注意事项:
1. 所有替换都必须使用有意义的词语
2. 替换词的使用频率不能过低
3. 多字词优先考虑整词替换
4. 考虑声调变化的情况
5. 保持标点符号和空格不变
"""
from pypinyin import pinyin, Style
from collections import defaultdict
import json
import os
import unicodedata
import jieba
import jieba.posseg as pseg
from pathlib import Path
import random
import math
import time
def load_or_create_char_frequency():
"""
加载或创建汉字频率字典
"""
cache_file = Path("char_frequency.json")
# 如果缓存文件存在,直接加载
if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f)
# 使用内置的词频文件
char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# 读取jieba的词典文件
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
word, freq = line.strip().split()[:2]
# 对词中的每个字进行频率累加
for char in word:
if is_chinese_char(char):
char_freq[char] += int(freq)
# 归一化频率值
max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()}
# 保存到缓存文件
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq
# 创建拼音到汉字的映射字典
def create_pinyin_dict():
"""
创建拼音到汉字的映射字典
"""
# 常用汉字范围
chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list)
# 为每个汉字建立拼音映射
for char in chars:
try:
py = pinyin(char, style=Style.TONE3)[0][0]
pinyin_dict[py].append(char)
except Exception:
continue
return pinyin_dict
def is_chinese_char(char):
"""
判断是否为汉字
"""
try:
return '\u4e00' <= char <= '\u9fff'
except:
return False
def get_pinyin(sentence):
"""
将中文句子拆分成单个汉字并获取其拼音
:param sentence: 输入的中文句子
:return: 每个汉字及其拼音的列表
"""
# 将句子拆分成单个字符
characters = list(sentence)
# 获取每个字符的拼音
result = []
for char in characters:
# 跳过空格和非汉字字符
if char.isspace() or not is_chinese_char(char):
continue
# 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py))
return result
def get_homophone(char, py, pinyin_dict, char_frequency, min_freq=5):
"""
获取同音字,按照使用频率排序
"""
homophones = pinyin_dict[py]
# 移除原字并过滤低频字
if char in homophones:
homophones.remove(char)
# 过滤掉低频字
homophones = [h for h in homophones if char_frequency.get(h, 0) >= min_freq]
# 按照字频排序
sorted_homophones = sorted(homophones,
key=lambda x: char_frequency.get(x, 0),
reverse=True)
# 只返回前10个同音字避免输出过多
return sorted_homophones[:10]
def get_similar_tone_pinyin(py):
"""
获取相似声调的拼音
例如:'ni3' 可能返回 'ni2''ni4'
处理特殊情况:
1. 轻声(如 'de5''le'
2. 非数字结尾的拼音
"""
# 检查拼音是否为空或无效
if not py or len(py) < 1:
return py
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1
return py + '1'
base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调
# 处理轻声通常用5表示或无效声调
if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4]))
# 正常处理声调
possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调
new_tone = random.choice(possible_tones) # 随机选择一个新声调
return base + str(new_tone)
def calculate_replacement_probability(orig_freq, target_freq, max_freq_diff=200):
"""
根据频率差计算替换概率
频率差越大,概率越低
:param orig_freq: 原字频率
:param target_freq: 目标字频率
:param max_freq_diff: 最大允许的频率差
:return: 0-1之间的概率值
"""
if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq
if freq_diff > max_freq_diff:
return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / max_freq_diff)
def get_similar_frequency_chars(char, py, pinyin_dict, char_frequency, num_candidates=5, min_freq=5, tone_error_rate=0.2):
"""
获取与给定字频率相近的同音字,可能包含声调错误
"""
homophones = []
# 有20%的概率使用错误声调
if random.random() < tone_error_rate:
wrong_tone_py = get_similar_tone_pinyin(py)
homophones.extend(pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字
homophones.extend(pinyin_dict[py])
if not homophones:
return None
# 获取原字的频率
orig_freq = char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [(h, char_frequency.get(h, 0))
for h in homophones
if h != char and char_frequency.get(h, 0) >= min_freq]
if not freq_diff:
return None
# 计算每个候选字的替换概率
candidates_with_prob = []
for h, freq in freq_diff:
prob = calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob))
if not candidates_with_prob:
return None
# 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]]
def get_word_pinyin(word):
"""
获取词语的拼音列表
"""
return [py[0] for py in pinyin(word, style=Style.TONE3)]
def segment_sentence(sentence):
"""
使用jieba分词返回词语列表
"""
return list(jieba.cut(sentence))
def get_word_homophones(word, pinyin_dict, char_frequency, min_freq=5):
"""
获取整个词的同音词,只返回高频的有意义词语
:param word: 输入词语
:param pinyin_dict: 拼音字典
:param char_frequency: 字频字典
:param min_freq: 最小频率阈值
:return: 同音词列表
"""
if len(word) == 1:
return []
# 获取词的拼音
word_pinyin = get_word_pinyin(word)
word_pinyin_str = ''.join(word_pinyin)
# 创建词语频率字典
word_freq = defaultdict(float)
# 遍历所有可能的同音字组合
candidates = []
for py in word_pinyin:
chars = pinyin_dict.get(py, [])
if not chars:
return []
candidates.append(chars)
# 生成所有可能的组合
import itertools
all_combinations = itertools.product(*candidates)
# 获取jieba词典和词频信息
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
valid_words = {} # 改用字典存储词语及其频率
with open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
word_text = parts[0]
word_freq = float(parts[1]) # 获取词频
valid_words[word_text] = word_freq
# 获取原词的词频作为参考
original_word_freq = valid_words.get(word, 0)
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
# 过滤和计算频率
homophones = []
for combo in all_combinations:
new_word = ''.join(combo)
if new_word != word and new_word in valid_words:
new_word_freq = valid_words[new_word]
# 只保留词频达到阈值的词
if new_word_freq >= min_word_freq:
# 计算词的平均字频(考虑字频和词频)
char_avg_freq = sum(char_frequency.get(c, 0) for c in new_word) / len(new_word)
# 综合评分:结合词频和字频
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
if combined_score >= min_freq:
homophones.append((new_word, combined_score))
# 按综合分数排序并限制返回数量
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
def create_typo_sentence(sentence, pinyin_dict, char_frequency, error_rate=0.5, min_freq=5, tone_error_rate=0.2, word_replace_rate=0.3):
"""
创建包含同音字错误的句子,支持词语级别和字级别的替换
只使用高频的有意义词语进行替换
"""
result = []
typo_info = []
# 分词
words = segment_sentence(sentence)
for word in words:
# 如果是标点符号或空格,直接添加
if all(not is_chinese_char(c) for c in word):
result.append(word)
continue
# 获取词语的拼音
word_pinyin = get_word_pinyin(word)
# 尝试整词替换
if len(word) > 1 and random.random() < word_replace_rate:
word_homophones = get_word_homophones(word, pinyin_dict, char_frequency, min_freq)
if word_homophones:
typo_word = random.choice(word_homophones)
# 计算词的平均频率
orig_freq = sum(char_frequency.get(c, 0) for c in word) / len(word)
typo_freq = sum(char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
# 添加到结果中
result.append(typo_word)
typo_info.append((word, typo_word,
' '.join(word_pinyin),
' '.join(get_word_pinyin(typo_word)),
orig_freq, typo_freq))
continue
# 如果不进行整词替换,则进行单字替换
if len(word) == 1:
char = word
py = word_pinyin[0]
if random.random() < error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
result.append(char)
else:
# 处理多字词的单字替换
word_result = []
for i, (char, py) in enumerate(zip(word, word_pinyin)):
# 词中的字替换概率降低
word_error_rate = error_rate * (0.7 ** (len(word) - 1))
if random.random() < word_error_rate:
similar_chars = get_similar_frequency_chars(char, py, pinyin_dict, char_frequency,
min_freq=min_freq, tone_error_rate=tone_error_rate)
if similar_chars:
typo_char = random.choice(similar_chars)
typo_freq = char_frequency.get(typo_char, 0)
orig_freq = char_frequency.get(char, 0)
replace_prob = calculate_replacement_probability(orig_freq, typo_freq)
if random.random() < replace_prob:
word_result.append(typo_char)
typo_py = pinyin(typo_char, style=Style.TONE3)[0][0]
typo_info.append((char, typo_char, py, typo_py, orig_freq, typo_freq))
continue
word_result.append(char)
result.append(''.join(word_result))
return ''.join(result), typo_info
def format_frequency(freq):
"""
格式化频率显示
"""
return f"{freq:.2f}"
def main():
# 记录开始时间
start_time = time.time()
# 首先创建拼音字典和加载字频统计
print("正在加载汉字数据库,请稍候...")
pinyin_dict = create_pinyin_dict()
char_frequency = load_or_create_char_frequency()
# 获取用户输入
sentence = input("请输入中文句子:")
# 创建包含错别字的句子
typo_sentence, typo_info = create_typo_sentence(sentence, pinyin_dict, char_frequency,
error_rate=0.3, min_freq=5,
tone_error_rate=0.2, word_replace_rate=0.3)
# 打印结果
print("\n原句:", sentence)
print("错字版:", typo_sentence)
if typo_info:
print("\n错别字信息:")
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
# 判断是否为词语替换
is_word = ' ' in orig_py
if is_word:
error_type = "整词替换"
else:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换"
print(f"原文:{orig}({orig_py}) [频率:{format_frequency(orig_freq)}] -> "
f"替换:{typo}({typo_py}) [频率:{format_frequency(typo_freq)}] [{error_type}]")
# 获取拼音结果
result = get_pinyin(sentence)
# 打印完整拼音
print("\n完整拼音:")
print(" ".join(py for _, py in result))
# 打印词语分析
print("\n词语分析:")
words = segment_sentence(sentence)
for word in words:
if any(is_chinese_char(c) for c in word):
word_pinyin = get_word_pinyin(word)
print(f"词语:{word}")
print(f"拼音:{' '.join(word_pinyin)}")
print("---")
# 计算并打印总耗时
end_time = time.time()
total_time = end_time - start_time
print(f"\n总耗时:{total_time:.2f}")
if __name__ == "__main__":
main()