删除无用文档和测试文件
This commit is contained in:
@@ -1,306 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
# Benchmark the distribution manager's run_chat_stream/conversation_loop behavior
|
||||
# by wiring a lightweight dummy manager and contexts. This avoids touching real DB or chat subsystems.
|
||||
|
||||
# Put the directory one level above this file's directory on sys.path so the
# file can be executed directly as a script.
ROOT_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
|
||||
|
||||
# Avoid importing the whole 'src.chat' package to prevent heavy deps (e.g., redis)
|
||||
# Local minimal implementation of loop and manager to isolate benchmark from heavy deps.
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
class ConversationTick:
    """One event emitted by ``conversation_loop`` for a stream."""

    # Identifier of the stream the tick belongs to.
    stream_id: str
    # Wall-clock timestamp taken when the tick object is created.
    tick_time: float = field(default_factory=time.time)
    # True when the loop's force-dispatch check returned a truthy value.
    force_dispatch: bool = False
    # 1-based running number of ticks emitted by the producing loop.
    tick_count: int = 0
|
||||
|
||||
|
||||
async def conversation_loop(
    stream_id: str,
    get_context_func: Callable[[str], Awaitable["DummyContext | None"]],
    calculate_interval_func: Callable[[str, bool], Awaitable[float]],
    flush_cache_func: Callable[[str], Awaitable[list[Any]]],
    check_force_dispatch_func: Callable[["DummyContext", int], bool],
    is_running_func: Callable[[], bool],
) -> AsyncIterator["ConversationTick"]:
    """Yield a ``ConversationTick`` whenever the stream has work to do.

    The loop runs while ``is_running_func()`` is truthy. On each pass it
    fetches the context, flushes the cache, counts unread messages and asks
    ``check_force_dispatch_func`` whether dispatch must be forced. A tick is
    yielded when there is at least one unread message or dispatch is forced.
    The loop then sleeps for the interval returned by
    ``calculate_interval_func``. When no context is available it sleeps 0.1 s
    and retries.
    """
    emitted = 0
    while is_running_func():
        context = await get_context_func(stream_id)
        if context:
            await flush_cache_func(stream_id)
            pending = len(context.get_unread_messages())
            forced = check_force_dispatch_func(context, pending)
            has_unread = pending > 0
            if has_unread or forced:
                emitted += 1
                yield ConversationTick(stream_id=stream_id, force_dispatch=forced, tick_count=emitted)
            await asyncio.sleep(await calculate_interval_func(stream_id, has_unread))
        else:
            # Context not available yet: back off briefly and retry.
            await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
class StreamLoopManager:
    """Minimal local stand-in for the stream loop manager used by this benchmark.

    It keeps a few counters in ``stats``, bounds concurrent processing with a
    semaphore, and starts one ``run_chat_stream`` task per stream.
    """

    def __init__(self, max_concurrent_streams: Optional[int] = None):
        self.stats: dict[str, Any] = {
            "active_streams": 0,
            "total_loops": 0,
            "total_process_cycles": 0,
            "total_failures": 0,
            "start_time": time.time(),
        }
        # None (or 0) falls back to 100 concurrent processing slots.
        self.max_concurrent_streams = max_concurrent_streams or 100
        self.force_dispatch_unread_threshold = 20
        self.chatter_manager = DummyChatterManager()
        self.is_running = False
        # Kept for attribute compatibility; not used by this class.
        self._stream_start_locks: dict[str, asyncio.Lock] = {}
        self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
        self._chat_manager: Optional["DummyChatManager"] = None

    def set_chat_manager(self, chat_manager: "DummyChatManager") -> None:
        """Attach the chat manager used to look up streams."""
        self._chat_manager = chat_manager

    async def start(self):
        """Mark the manager as running so stream loops keep iterating."""
        self.is_running = True

    async def stop(self):
        """Mark the manager as stopped; loops exit at their next running check."""
        self.is_running = False

    async def _get_stream_context(self, stream_id: str):
        """Return the context of ``stream_id``, or None if the stream is unknown.

        Raises:
            RuntimeError: if no chat manager has been attached.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self._chat_manager is None:
            raise RuntimeError("chat manager is not set; call set_chat_manager() first")
        stream = await self._chat_manager.get_stream(stream_id)
        return stream.context if stream else None

    async def _flush_cached_messages_to_unread(self, stream_id: str):
        """Flush the stream's cached messages; returns [] when the stream is unknown."""
        ctx = await self._get_stream_context(stream_id)
        return ctx.flush_cached_messages() if ctx else []

    def _needs_force_dispatch_for_context(self, context: "DummyContext", unread_count: int) -> bool:
        """Return True when ``unread_count`` exceeds the force-dispatch threshold."""
        return unread_count > (self.force_dispatch_unread_threshold or 20)

    async def _process_stream_messages(self, stream_id: str, context: "DummyContext") -> bool:
        """Delegate to the chatter manager and report its ``success`` flag."""
        res = await self.chatter_manager.process_stream_context(stream_id, context)  # type: ignore[attr-defined]
        return bool(res.get("success", False))

    async def _update_stream_energy(self, stream_id: str, context: "DummyContext") -> None:
        """No-op hook; subclasses may override."""
        pass

    async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float:
        """Return the sleep between loop passes: shorter when messages are pending."""
        return 0.005 if has_messages else 0.02

    def _on_stream_loop_done(self, _task) -> None:
        """Done-callback: a finished loop no longer counts as active."""
        self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1)

    async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool:
        """Start the driver task for ``stream_id``.

        Returns False when the stream has no context. If a loop task is
        already running for the stream, it is left alone and True is returned,
        unless ``force`` is set, in which case the old task is cancelled and
        replaced. Otherwise a new task is created and True is returned.
        """
        ctx = await self._get_stream_context(stream_id)
        if not ctx:
            return False

        existing = getattr(ctx, "stream_loop_task", None)
        if existing is not None and not existing.done():
            if not force:
                # Avoid running two loops over the same stream.
                return True
            existing.cancel()

        loop_task = asyncio.create_task(run_chat_stream(stream_id, self))
        loop_task.add_done_callback(self._on_stream_loop_done)
        ctx.stream_loop_task = loop_task
        self.stats["active_streams"] += 1
        self.stats["total_loops"] += 1
        return True
|
||||
|
||||
|
||||
async def run_chat_stream(stream_id: str, manager: "StreamLoopManager") -> None:
    """Drive one stream: process its messages once per tick from ``conversation_loop``.

    Ticks are skipped when the context is missing or already flagged as being
    processed. Every processed tick increments ``total_process_cycles``. A
    falsy result or an exception from processing also increments
    ``total_failures``. Task cancellation ends the driver silently.
    """
    ticks = conversation_loop(
        stream_id=stream_id,
        get_context_func=manager._get_stream_context,
        calculate_interval_func=manager._calculate_interval,
        flush_cache_func=manager._flush_cached_messages_to_unread,
        check_force_dispatch_func=manager._needs_force_dispatch_for_context,
        is_running_func=lambda: manager.is_running,
    )
    try:
        async for _tick in ticks:
            context = await manager._get_stream_context(stream_id)
            if not context or context.is_chatter_processing:
                continue

            try:
                # The semaphore caps how many streams are processed at once.
                async with manager._processing_semaphore:
                    succeeded = await manager._process_stream_messages(stream_id, context)
            except Exception:
                # Any processing error is recorded as a failed cycle.
                succeeded = False

            manager.stats["total_process_cycles"] += 1
            if not succeeded:
                manager.stats["total_failures"] += 1
    except asyncio.CancelledError:
        pass
|
||||
|
||||
|
||||
@dataclass
class DummyMessage:
    """Lightweight message record used by the dummy contexts."""

    # Timestamp supplied by the creator (required, no default).
    time: float
    # Text fields default to empty strings.
    processed_plain_text: str = ""
    display_message: str = ""
    # Flags default to False.
    is_at: bool = False
    is_mentioned: bool = False
|
||||
|
||||
|
||||
class DummyContext:
    """In-memory stream context holding unread and history message lists."""

    def __init__(self, stream_id: str, initial_unread: int):
        """Create a context pre-filled with ``initial_unread`` unread messages."""
        self.stream_id = stream_id
        self.unread_messages = [DummyMessage(time=time.time()) for _ in range(initial_unread)]
        self.history_messages: list[DummyMessage] = []
        self.is_chatter_processing: bool = False
        self.processing_task: Optional[asyncio.Task] = None
        self.stream_loop_task: Optional[asyncio.Task] = None
        self.triggering_user_id: Optional[str] = None

    def get_unread_messages(self) -> "list[DummyMessage]":
        """Return a shallow copy of the unread list (safe for the caller to mutate)."""
        return list(self.unread_messages)

    def flush_cached_messages(self) -> "list[DummyMessage]":
        """Return the flushed messages; this dummy has no cache, so always []."""
        return []

    def get_last_message(self) -> "Optional[DummyMessage]":
        """Return the newest unread message, or None when there is none."""
        return self.unread_messages[-1] if self.unread_messages else None

    def get_history_messages(self, limit: int = 50) -> "list[DummyMessage]":
        """Return at most the last ``limit`` history messages.

        A non-positive ``limit`` yields an empty list. Without this guard
        ``[-0:]`` would return the entire history.
        """
        if limit <= 0:
            return []
        return self.history_messages[-limit:]
|
||||
|
||||
|
||||
class DummyStream:
    """Bare stream object pairing a stream id with its context."""

    def __init__(self, stream_id: str, ctx: "DummyContext"):
        self.stream_id = stream_id
        self.context = ctx
        # No group info: the stream is treated as a private chat to accelerate.
        self.group_info = None
        self._focus_energy = 0.5
|
||||
|
||||
|
||||
class DummyChatManager:
    """Dictionary-backed stream registry used in place of the real chat manager."""

    def __init__(self, streams: "dict[str, DummyStream]"):
        # The mapping is stored as-is (not copied).
        self._streams = streams

    async def get_stream(self, stream_id: str) -> "Optional[DummyStream]":
        """Look up a stream by id; returns None when the id is unknown."""
        return self._streams.get(stream_id)

    def get_all_streams(self) -> "dict[str, DummyStream]":
        """Return the underlying id-to-stream mapping itself."""
        return self._streams
|
||||
|
||||
|
||||
class DummyChatterManager:
    """Fake chatter that consumes one unread message per call."""

    async def process_stream_context(self, stream_id: str, context: "DummyContext") -> dict[str, Any]:
        """Sleep 10 ms to simulate latency, drop the oldest unread message, report success."""
        await asyncio.sleep(0.01)
        pending = context.unread_messages
        if pending:
            del pending[0]
        return {"success": True}
|
||||
|
||||
|
||||
class BenchStreamLoopManager(StreamLoopManager):
    """Stream loop manager wired to the dummy chat and chatter managers."""

    def __init__(self, chat_manager: "DummyChatManager", max_concurrent_streams: int | None = None):
        super().__init__(max_concurrent_streams=max_concurrent_streams)
        self._chat_manager = chat_manager
        self.chatter_manager = DummyChatterManager()

    async def _get_stream_context(self, stream_id: str):  # type: ignore[override]
        """Return the context of ``stream_id``, or None if the stream is unknown."""
        stream = await self._chat_manager.get_stream(stream_id)
        if stream:
            return stream.context
        return None

    async def _flush_cached_messages_to_unread(self, stream_id: str):  # type: ignore[override]
        """Flush the stream's cached messages; [] when the stream is unknown."""
        context = await self._get_stream_context(stream_id)
        if context:
            return context.flush_cached_messages()
        return []

    def _needs_force_dispatch_for_context(self, context, unread_count: int) -> bool:  # type: ignore[override]
        """Force dispatch once the unread count exceeds the threshold."""
        threshold = self.force_dispatch_unread_threshold or 20
        return unread_count > threshold

    async def _process_stream_messages(self, stream_id: str, context):  # type: ignore[override]
        """Delegate to the chatter manager and return its ``success`` flag as a bool."""
        outcome = await self.chatter_manager.process_stream_context(stream_id, context)  # type: ignore[attr-defined]
        return bool(outcome.get("success", False))

    async def _should_skip_for_mute_group(self, stream_id: str, unread_messages: list) -> bool:
        """Never skip in the benchmark."""
        return False

    async def _update_stream_energy(self, stream_id: str, context):  # type: ignore[override]
        """Set the stream's focus energy from its unread count (capped at 1.0)."""
        unread_total = len(context.get_unread_messages())
        stream = await self._chat_manager.get_stream(stream_id)
        if stream:
            # Stored on the stream for compatibility.
            stream._focus_energy = min(1.0, 0.1 + 0.02 * unread_total)
|
||||
|
||||
|
||||
def make_streams(n_streams: int, initial_unread: int) -> "dict[str, DummyStream]":
    """Build ``n_streams`` dummy streams keyed ``s0000``, ``s0001``, ...

    Each stream gets its own context pre-filled with ``initial_unread`` unread
    messages.
    """
    stream_ids = (f"s{index:04d}" for index in range(n_streams))
    return {sid: DummyStream(sid, DummyContext(sid, initial_unread)) for sid in stream_ids}
|
||||
|
||||
|
||||
async def run_benchmark(
    n_streams: int,
    initial_unread: int,
    max_concurrent: Optional[int],
    timeout: float = 5.0,
) -> dict[str, Any]:
    """Run one benchmark case and return its measurements.

    Starts one loop per stream and polls until every unread message has been
    consumed or ``timeout`` seconds have elapsed. The returned dict has the
    input parameters plus duration, cycle and failure counts, remaining unread
    messages, and throughput in messages per second.
    """
    streams = make_streams(n_streams, initial_unread)
    chat_mgr = DummyChatManager(streams)
    mgr = BenchStreamLoopManager(chat_mgr, max_concurrent_streams=max_concurrent)
    await mgr.start()

    def unread_left() -> int:
        return sum(len(s.context.get_unread_messages()) for s in streams.values())

    # perf_counter is monotonic, so clock adjustments cannot distort the measurement.
    start_ts = time.perf_counter()
    for sid in streams:
        await mgr.start_stream_loop(sid, force=True)

    # Poll until all unread messages are consumed or the deadline passes.
    deadline = start_ts + timeout
    while time.perf_counter() < deadline and unread_left() > 0:
        await asyncio.sleep(0.02)

    duration = time.perf_counter() - start_ts
    total_cycles = mgr.stats.get("total_process_cycles", 0)
    total_failures = mgr.stats.get("total_failures", 0)
    remaining = unread_left()

    # Stop the manager, then cancel and await the loop tasks so none of them
    # keeps running into the next benchmark case.
    await mgr.stop()
    loop_tasks = [
        s.context.stream_loop_task
        for s in streams.values()
        if s.context.stream_loop_task is not None
    ]
    for task in loop_tasks:
        task.cancel()
    await asyncio.gather(*loop_tasks, return_exceptions=True)

    return {
        "n_streams": n_streams,
        "initial_unread": initial_unread,
        "max_concurrent": max_concurrent,
        "duration_sec": duration,
        "total_cycles": total_cycles,
        "total_failures": total_failures,
        "remaining_unread": remaining,
        # Clamp the divisor so an instant run cannot divide by zero.
        "throughput_msgs_per_sec": (n_streams * initial_unread - remaining) / max(0.001, duration),
    }
|
||||
|
||||
|
||||
async def main():
    """Run every benchmark case in turn and print a one-line summary for each."""
    # Each case is (n_streams, initial_unread, max_concurrent).
    cases = [
        (50, 5, None),  # baseline using configured default
        (50, 5, 5),  # constrained concurrency
        (50, 5, 10),  # moderate concurrency
        (100, 3, 10),  # scale streams
    ]

    print("Running distribution manager benchmark...\n")
    for case in cases:
        res = await run_benchmark(*case)
        summary = (
            f"streams={res['n_streams']} unread={res['initial_unread']} max_conc={res['max_concurrent']} | "
            f"dur={res['duration_sec']:.3f}s cycles={res['total_cycles']} fail={res['total_failures']} rem={res['remaining_unread']} "
            f"thr={res['throughput_msgs_per_sec']:.1f}/s"
        )
        print(summary)


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -1,276 +0,0 @@
|
||||
"""
|
||||
统一记忆管理器性能基准测试
|
||||
|
||||
对优化前后的关键操作进行性能对比测试
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class PerformanceBenchmark:
    """Micro-benchmarks comparing old and new implementations of three helpers.

    Each ``benchmark_*`` coroutine prints a table and records the same numbers
    in ``self.results`` under its own key.
    """

    def __init__(self):
        # {benchmark name: {case name: {metric: value}}}; filled by benchmark_*.
        self.results = {}

    @staticmethod
    def _mean_seconds(func, arg, repeats):
        """Return the mean time in seconds of ``func(arg)`` over ``repeats`` calls."""
        start = time.perf_counter()
        for _ in range(repeats):
            func(arg)
        return (time.perf_counter() - start) / repeats

    @staticmethod
    def _improvement_pct(old_time, new_time):
        """Return how much faster ``new_time`` is than ``old_time``, in percent."""
        # A zero baseline would divide by zero; report no change instead.
        return (old_time - new_time) / old_time * 100 if old_time else 0.0

    @staticmethod
    def _old_build_manual_queries(queries):
        """Old algorithm: deduplicate first, then assign decaying weights."""
        deduplicated = []
        seen = set()
        for raw in queries:
            text = (raw or "").strip()
            if not text or text in seen:
                continue
            deduplicated.append(text)
            seen.add(text)

        # A single distinct query needs no manual weighting.
        if len(deduplicated) <= 1:
            return []

        decay = 0.15
        return [
            {"text": text, "weight": round(max(0.3, 1.0 - idx * decay), 2)}
            for idx, text in enumerate(deduplicated)
        ]

    @staticmethod
    def _new_build_manual_queries(queries):
        """New algorithm: deduplicate and weight in a single pass."""
        seen = set()
        decay = 0.15
        manual_queries = []
        for raw in queries:
            text = (raw or "").strip()
            if text and text not in seen:
                seen.add(text)
                weight = max(0.3, 1.0 - len(manual_queries) * decay)
                manual_queries.append({"text": text, "weight": round(weight, 2)})
        return manual_queries if len(manual_queries) > 1 else []

    @staticmethod
    def _old_deduplicate(memories):
        """Old algorithm: deduplicate by the ``id`` attribute only."""
        seen_ids = set()
        unique_memories = []
        for mem in memories:
            mem_id = getattr(mem, "id", None)
            if mem_id and mem_id in seen_ids:
                continue
            unique_memories.append(mem)
            if mem_id:
                seen_ids.add(mem_id)
        return unique_memories

    @staticmethod
    def _new_deduplicate(memories):
        """New algorithm: deduplicate by ``id`` for both dicts and objects."""
        seen_ids = set()
        unique_memories = []
        for mem in memories:
            if isinstance(mem, dict):
                mem_id = mem.get("id")
            else:
                mem_id = getattr(mem, "id", None)

            if mem_id and mem_id in seen_ids:
                continue
            unique_memories.append(mem)
            if mem_id:
                seen_ids.add(mem_id)
        return unique_memories

    @staticmethod
    async def _mock_block_operation():
        """Simulate a 1 ms block operation."""
        await asyncio.sleep(0.001)
        return True

    @classmethod
    async def _serial_transfer(cls, num_blocks: int):
        """Old algorithm: run the block operations one after another."""
        return [await cls._mock_block_operation() for _ in range(num_blocks)]

    @classmethod
    async def _parallel_transfer(cls, num_blocks: int):
        """New algorithm: run the block operations concurrently."""
        return await asyncio.gather(*(cls._mock_block_operation() for _ in range(num_blocks)))

    async def benchmark_query_deduplication(self):
        """Benchmark query deduplication (old vs new) and print a table."""
        # The real manager would need to be imported here:
        # from src.memory_graph.unified_manager import UnifiedMemoryManager

        test_cases = [
            {
                "name": "small_queries",
                "queries": ["hello", "world"],
            },
            {
                "name": "medium_queries",
                "queries": ["q" + str(i % 5) for i in range(50)],  # 5 unique
            },
            {
                "name": "large_queries",
                "queries": ["q" + str(i % 100) for i in range(1000)],  # 100 unique
            },
            {
                "name": "many_duplicates",
                "queries": ["duplicate"] * 500,  # 500 duplicates
            },
        ]

        print("\n" + "=" * 70)
        print("查询去重性能基准测试")
        print("=" * 70)
        print(f"{'测试用例':<20} {'旧算法(μs)':<15} {'新算法(μs)':<15} {'提升比例':<15}")
        print("-" * 70)

        recorded = self.results.setdefault("query_deduplication", {})
        for test_case in test_cases:
            name = test_case["name"]
            queries = test_case["queries"]

            # Mean time per call in microseconds over 100 runs.
            old_time = self._mean_seconds(self._old_build_manual_queries, queries, 100) * 1e6
            new_time = self._mean_seconds(self._new_build_manual_queries, queries, 100) * 1e6
            improvement = self._improvement_pct(old_time, new_time)

            recorded[name] = {"old_us": old_time, "new_us": new_time, "improvement_pct": improvement}
            print(
                f"{name:<20} {old_time:>14.2f} {new_time:>14.2f} {improvement:>13.1f}%"
            )

        print()

    async def benchmark_transfer_parallelization(self):
        """Benchmark serial vs parallel block transfer and print a table."""
        print("\n" + "=" * 70)
        print("块转移并行化性能基准测试")
        print("=" * 70)

        block_counts = [1, 5, 10, 20, 50]

        print(f"{'块数':<10} {'串行(ms)':<15} {'并行(ms)':<15} {'加速比':<15}")
        print("-" * 70)

        recorded = self.results.setdefault("transfer_parallelization", {})
        for num_blocks in block_counts:
            # Serial: mean milliseconds per run over 10 runs.
            start = time.perf_counter()
            for _ in range(10):
                await self._serial_transfer(num_blocks)
            serial_time = (time.perf_counter() - start) / 10 * 1000

            # Parallel: mean milliseconds per run over 10 runs.
            start = time.perf_counter()
            for _ in range(10):
                await self._parallel_transfer(num_blocks)
            parallel_time = (time.perf_counter() - start) / 10 * 1000

            # Guard against a zero divisor on coarse timers.
            speedup = serial_time / parallel_time if parallel_time else float("inf")

            recorded[num_blocks] = {"serial_ms": serial_time, "parallel_ms": parallel_time, "speedup": speedup}
            print(
                f"{num_blocks:<10} {serial_time:>14.2f} {parallel_time:>14.2f} {speedup:>14.2f}x"
            )

        print()

    async def benchmark_deduplication_memory(self):
        """Benchmark memory deduplication (old vs new) and print a table."""
        print("\n" + "=" * 70)
        print("内存去重性能基准测试")
        print("=" * 70)

        # Minimal object exposing only the ``id`` attribute the algorithms read.
        class MockMemory:
            def __init__(self, mem_id: str):
                self.id = mem_id

        test_cases = [
            {
                "name": "objects_100",
                "data": [MockMemory(f"id_{i % 50}") for i in range(100)],
            },
            {
                "name": "objects_1000",
                "data": [MockMemory(f"id_{i % 500}") for i in range(1000)],
            },
            {
                "name": "dicts_100",
                "data": [{"id": f"id_{i % 50}"} for i in range(100)],
            },
            {
                "name": "dicts_1000",
                "data": [{"id": f"id_{i % 500}"} for i in range(1000)],
            },
        ]

        print(f"{'测试用例':<20} {'旧算法(μs)':<15} {'新算法(μs)':<15} {'提升比例':<15}")
        print("-" * 70)

        recorded = self.results.setdefault("deduplication_memory", {})
        for test_case in test_cases:
            name = test_case["name"]
            data = test_case["data"]

            # Mean time per call in microseconds over 100 runs.
            old_time = self._mean_seconds(self._old_deduplicate, data, 100) * 1e6
            new_time = self._mean_seconds(self._new_deduplicate, data, 100) * 1e6
            improvement = self._improvement_pct(old_time, new_time)

            recorded[name] = {"old_us": old_time, "new_us": new_time, "improvement_pct": improvement}
            print(
                f"{name:<20} {old_time:>14.2f} {new_time:>14.2f} {improvement:>13.1f}%"
            )

        print()
|
||||
|
||||
|
||||
async def run_all_benchmarks():
    """Run all benchmarks in order and print the banner and closing summary."""
    benchmark = PerformanceBenchmark()

    box_blank = "║" + " " * 68 + "║"
    rule = "=" * 70

    # Title banner.
    print("\n" + "╔" + "=" * 68 + "╗")
    print(box_blank)
    print("║" + "统一记忆管理器优化性能基准测试".center(68) + "║")
    print(box_blank)
    print("╚" + "=" * 68 + "╝")

    await benchmark.benchmark_query_deduplication()
    await benchmark.benchmark_transfer_parallelization()
    await benchmark.benchmark_deduplication_memory()

    # Closing summary: fixed text, not derived from the measurements.
    closing_lines = (
        "\n" + rule,
        "性能基准测试完成",
        rule,
        "\n📊 关键发现:",
        " 1. 查询去重:新算法在大规模查询时快 5-15%",
        " 2. 块转移:并行化在 ≥5 块时有 2-10 倍加速",
        " 3. 内存去重:新算法支持混合类型,性能相当或更优",
        "\n💡 建议:",
        " • 定期运行此基准测试监控性能",
        " • 在生产环境观察实际内存管理的转移块数",
        " • 考虑对高频操作进行更深度的优化",
    )
    for line in closing_lines:
        print(line)
    print()


if __name__ == "__main__":
    asyncio.run(run_all_benchmarks())
|
||||
@@ -1,204 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
AWS Bedrock 客户端测试脚本
|
||||
测试 BedrockClient 的基本功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add this script's directory to the front of the Python path so the `src`
# package imports below resolve.
project_root = Path(__file__).parents[0]
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.config.api_ada_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.model_client.bedrock_client import BedrockClient
|
||||
from src.llm_models.payload_content.message import MessageBuilder
|
||||
|
||||
|
||||
async def test_basic_conversation():
    """Send one prompt through BedrockClient and print the reply and token usage.

    Credentials are read from the standard ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables, so real keys never have
    to be written into this file. The original placeholders remain the
    fallback. Any exception is reported with a traceback instead of being
    raised.
    """
    import os
    import traceback

    print("=" * 60)
    print("测试 1: 基本对话功能")
    print("=" * 60)

    # API provider configuration.
    provider = APIProvider(
        name="bedrock_test",
        base_url="",  # not needed for Bedrock
        api_key=os.environ.get("AWS_ACCESS_KEY_ID", "YOUR_AWS_ACCESS_KEY_ID"),
        client_type="bedrock",
        max_retry=2,
        timeout=60,
        retry_interval=10,
        extra_params={
            "aws_secret_key": os.environ.get("AWS_SECRET_ACCESS_KEY", "YOUR_AWS_SECRET_ACCESS_KEY"),
            "region": "us-east-1",
        },
    )

    # Model configuration.
    model = ModelInfo(
        model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
        name="claude-3.5-sonnet-bedrock",
        api_provider="bedrock_test",
        price_in=3.0,
        price_out=15.0,
        force_stream_mode=False,
    )

    client = BedrockClient(provider)

    builder = MessageBuilder()
    builder.add_user_message("你好!请用一句话介绍 AWS Bedrock。")

    try:
        response = await client.get_response(
            model_info=model, message_list=[builder.build()], max_tokens=200, temperature=0.7
        )

        print(f"✅ 响应内容: {response.content}")
        if response.usage:
            print(
                f"📊 Token 使用: 输入={response.usage.prompt_tokens}, "
                f"输出={response.usage.completion_tokens}, "
                f"总计={response.usage.total_tokens}"
            )
        print("\n测试通过!✅\n")
    except Exception as e:
        # Best-effort manual check: report the failure and let the script continue.
        print(f"❌ 测试失败: {e!s}")
        traceback.print_exc()
|
||||
|
||||
|
||||
async def test_streaming():
    """Request a reply with ``force_stream_mode`` enabled and print the full text.

    Credentials are read from the standard ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables, with the original
    placeholders as the fallback. Failures are reported with a traceback
    instead of being raised.
    """
    import os
    import traceback

    print("=" * 60)
    print("测试 2: 流式输出功能")
    print("=" * 60)

    provider = APIProvider(
        name="bedrock_test",
        base_url="",
        api_key=os.environ.get("AWS_ACCESS_KEY_ID", "YOUR_AWS_ACCESS_KEY_ID"),
        client_type="bedrock",
        max_retry=2,
        timeout=60,
        extra_params={
            "aws_secret_key": os.environ.get("AWS_SECRET_ACCESS_KEY", "YOUR_AWS_SECRET_ACCESS_KEY"),
            "region": "us-east-1",
        },
    )

    model = ModelInfo(
        model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
        name="claude-3.5-sonnet-bedrock",
        api_provider="bedrock_test",
        price_in=3.0,
        price_out=15.0,
        force_stream_mode=True,  # enable streaming mode
    )

    client = BedrockClient(provider)
    builder = MessageBuilder()
    builder.add_user_message("写一个关于人工智能的三行诗。")

    try:
        print("🔄 流式响应中...")
        response = await client.get_response(
            model_info=model, message_list=[builder.build()], max_tokens=100, temperature=0.7
        )

        print(f"✅ 完整响应: {response.content}")
        print("\n测试通过!✅\n")
    except Exception as e:
        # Best-effort manual check: report the failure and let the script continue.
        print(f"❌ 测试失败: {e!s}")
        traceback.print_exc()
|
||||
|
||||
|
||||
async def test_multimodal():
    """Placeholder for the multimodal (image input) test; only prints a skip notice."""
    separator = "=" * 60
    print(separator)
    print("测试 3: 多模态功能(需要准备图片)")
    print(separator)
    # Skipped because it needs a real image file.
    print("⏭️ 跳过(需要实际图片文件)\n")
|
||||
|
||||
|
||||
async def test_tool_calling():
    """Offer a ``get_weather`` tool to the model and print any tool calls it makes.

    Credentials are read from the standard ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables, with the original
    placeholders as the fallback. Failures are reported with a traceback
    instead of being raised.
    """
    import os
    import traceback

    print("=" * 60)
    print("测试 4: 工具调用功能")
    print("=" * 60)

    from src.llm_models.payload_content.tool_option import ToolOptionBuilder, ToolParamType

    provider = APIProvider(
        name="bedrock_test",
        base_url="",
        api_key=os.environ.get("AWS_ACCESS_KEY_ID", "YOUR_AWS_ACCESS_KEY_ID"),
        client_type="bedrock",
        extra_params={
            "aws_secret_key": os.environ.get("AWS_SECRET_ACCESS_KEY", "YOUR_AWS_SECRET_ACCESS_KEY"),
            "region": "us-east-1",
        },
    )

    model = ModelInfo(
        model_identifier="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
        name="claude-3.5-sonnet-bedrock",
        api_provider="bedrock_test",
    )

    # Define the tool.
    tool_builder = ToolOptionBuilder()
    tool_builder.set_name("get_weather").set_description("获取指定城市的天气信息").add_param(
        name="city", param_type=ToolParamType.STRING, description="城市名称", required=True
    )
    tool = tool_builder.build()

    client = BedrockClient(provider)
    builder = MessageBuilder()
    builder.add_user_message("北京今天天气怎么样?")

    try:
        response = await client.get_response(
            model_info=model, message_list=[builder.build()], tool_options=[tool], max_tokens=200
        )

        if response.tool_calls:
            print("✅ 模型调用了工具:")
            for call in response.tool_calls:
                print(f" - 工具名: {call.func_name}")
                print(f" - 参数: {call.args}")
        else:
            print(f"⚠️ 模型没有调用工具,而是直接回复: {response.content}")

        print("\n测试通过!✅\n")
    except Exception as e:
        # Best-effort manual check: report the failure and let the script continue.
        print(f"❌ 测试失败: {e!s}")
        traceback.print_exc()
|
||||
|
||||
|
||||
async def main():
    """Entry point: print the banner, run the enabled tests, print the footer."""
    separator = "=" * 60

    print("\n🚀 AWS Bedrock 客户端测试开始\n")
    print("⚠️ 请确保已配置 AWS 凭证!")
    print("⚠️ 修改脚本中的 'YOUR_AWS_ACCESS_KEY_ID' 和 'YOUR_AWS_SECRET_ACCESS_KEY'\n")

    # Only the basic conversation test is enabled; the others are left disabled:
    await test_basic_conversation()
    # await test_streaming()
    # await test_multimodal()
    # await test_tool_calling()

    print(separator)
    print("🎉 所有测试完成!")
    print(separator)


if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user