- 添加详细的SubAgent使用指南(README.md) - 创建完整的Pydantic模型示例(example_models.py) - 实现基础使用示例,展示核心功能(basic_example.py) - 构建复杂文本分析应用示例(text_analysis_example.py) - 提供数字提取实验运行器作为参考示例 - 包含多Agent协作、批量处理、性能监控等高级功能 - 支持交互式演示和完整的错误处理机制
377 lines
13 KiB
Python
377 lines
13 KiB
Python
"""
|
||
SubAgent示例Pydantic模型定义
|
||
|
||
提供各种场景下的结构化输出模型,展示SubAgent系统的灵活性
|
||
"""
|
||
|
||
from typing import List, Dict, Any, Optional, Literal, Union
|
||
from pydantic import BaseModel, Field, validator
|
||
from datetime import datetime
|
||
|
||
|
||
class BasicResponse(BaseModel):
|
||
"""基础响应模型"""
|
||
|
||
message: str = Field(description="响应消息")
|
||
success: bool = Field(description="处理是否成功", default=True)
|
||
timestamp: datetime = Field(default_factory=datetime.now, description="响应时间")
|
||
|
||
|
||
class SentimentAnalysis(BaseModel):
|
||
"""情感分析结果模型"""
|
||
|
||
sentiment: Literal["positive", "negative", "neutral"] = Field(description="情感倾向")
|
||
confidence: float = Field(description="置信度", ge=0.0, le=1.0)
|
||
explanation: str = Field(description="分析说明")
|
||
keywords: List[str] = Field(description="关键词列表", default_factory=list)
|
||
|
||
@validator('confidence')
|
||
def validate_confidence(cls, v):
|
||
"""验证置信度范围"""
|
||
if not 0.0 <= v <= 1.0:
|
||
raise ValueError('置信度必须在0.0到1.0之间')
|
||
return round(v, 3)
|
||
|
||
|
||
class KeywordExtraction(BaseModel):
|
||
"""关键词提取项"""
|
||
|
||
keyword: str = Field(description="关键词")
|
||
frequency: int = Field(description="出现频次", ge=1)
|
||
importance: float = Field(description="重要性评分", ge=0.0, le=1.0)
|
||
category: str = Field(description="关键词分类", default="general")
|
||
|
||
class Config:
|
||
json_encoders = {
|
||
float: lambda v: round(v, 3) if v is not None else None
|
||
}
|
||
|
||
|
||
class TextAnalysisResult(BaseModel):
|
||
"""文本分析完整结果"""
|
||
|
||
# 基本信息
|
||
text_length: int = Field(description="文本长度(字符数)", ge=0)
|
||
word_count: int = Field(description="词汇数量", ge=0)
|
||
language: str = Field(description="检测到的语言", default="zh")
|
||
|
||
# 分析结果
|
||
summary: str = Field(description="文本摘要")
|
||
sentiment: SentimentAnalysis = Field(description="情感分析结果")
|
||
keywords: List[KeywordExtraction] = Field(description="关键词提取结果")
|
||
|
||
# 质量评估
|
||
readability: Literal["high", "medium", "low"] = Field(description="可读性评估")
|
||
complexity: float = Field(description="复杂度评分", ge=0.0, le=1.0)
|
||
|
||
@validator('text_length')
|
||
def validate_text_length(cls, v):
|
||
"""验证文本长度"""
|
||
if v < 0:
|
||
raise ValueError('文本长度不能为负数')
|
||
return v
|
||
|
||
@validator('keywords')
|
||
def validate_keywords(cls, v):
|
||
"""验证关键词列表"""
|
||
if len(v) > 20:
|
||
# 只保留前20个最重要的关键词
|
||
v = sorted(v, key=lambda x: x.importance, reverse=True)[:20]
|
||
return v
|
||
|
||
|
||
class CategoryClassification(BaseModel):
|
||
"""分类结果项"""
|
||
|
||
category: str = Field(description="分类名称")
|
||
confidence: float = Field(description="分类置信度", ge=0.0, le=1.0)
|
||
probability: float = Field(description="分类概率", ge=0.0, le=1.0)
|
||
|
||
@validator('confidence', 'probability')
|
||
def round_float_values(cls, v):
|
||
"""保留3位小数"""
|
||
return round(v, 3)
|
||
|
||
|
||
class DocumentClassificationResult(BaseModel):
|
||
"""文档分类结果"""
|
||
|
||
primary_category: str = Field(description="主要分类")
|
||
confidence: float = Field(description="主分类置信度", ge=0.0, le=1.0)
|
||
|
||
all_categories: List[CategoryClassification] = Field(
|
||
description="所有分类结果(按置信度排序)"
|
||
)
|
||
|
||
features_used: List[str] = Field(
|
||
description="使用的特征列表",
|
||
default_factory=list
|
||
)
|
||
|
||
processing_time: Optional[float] = Field(
|
||
description="处理时间(秒)",
|
||
default=None
|
||
)
|
||
|
||
@validator('all_categories')
|
||
def sort_categories(cls, v):
|
||
"""按置信度降序排列"""
|
||
return sorted(v, key=lambda x: x.confidence, reverse=True)
|
||
|
||
|
||
class DataExtractionItem(BaseModel):
|
||
"""数据提取项"""
|
||
|
||
field_name: str = Field(description="字段名称")
|
||
field_value: Union[str, int, float, bool, None] = Field(description="字段值")
|
||
confidence: float = Field(description="提取置信度", ge=0.0, le=1.0)
|
||
source_text: str = Field(description="来源文本片段")
|
||
extraction_method: str = Field(description="提取方法", default="llm")
|
||
|
||
|
||
class StructuredDataExtraction(BaseModel):
|
||
"""结构化数据提取结果"""
|
||
|
||
extracted_data: Dict[str, Any] = Field(description="提取的结构化数据")
|
||
extraction_items: List[DataExtractionItem] = Field(description="详细提取项目")
|
||
|
||
# 质量评估
|
||
extraction_quality: Literal["excellent", "good", "fair", "poor"] = Field(
|
||
description="提取质量评估"
|
||
)
|
||
completeness: float = Field(
|
||
description="完整性评分",
|
||
ge=0.0,
|
||
le=1.0
|
||
)
|
||
accuracy: float = Field(
|
||
description="准确性评分",
|
||
ge=0.0,
|
||
le=1.0
|
||
)
|
||
|
||
# 统计信息
|
||
total_fields: int = Field(description="总字段数", ge=0)
|
||
extracted_fields: int = Field(description="成功提取字段数", ge=0)
|
||
failed_fields: int = Field(description="提取失败字段数", ge=0)
|
||
|
||
@validator('extracted_fields', 'failed_fields')
|
||
def validate_field_counts(cls, v, values):
|
||
"""验证字段计数"""
|
||
total = values.get('total_fields', 0)
|
||
if v > total:
|
||
raise ValueError('提取字段数不能超过总字段数')
|
||
return v
|
||
|
||
|
||
class TaskExecutionResult(BaseModel):
|
||
"""任务执行结果"""
|
||
|
||
# 任务信息
|
||
task_id: str = Field(description="任务ID")
|
||
task_type: str = Field(description="任务类型")
|
||
status: Literal["completed", "failed", "partial"] = Field(description="执行状态")
|
||
|
||
# 执行详情
|
||
result_data: Optional[Dict[str, Any]] = Field(description="结果数据", default=None)
|
||
error_message: Optional[str] = Field(description="错误信息", default=None)
|
||
warnings: List[str] = Field(description="警告信息", default_factory=list)
|
||
|
||
# 性能指标
|
||
execution_time: float = Field(description="执行时间(秒)", ge=0.0)
|
||
memory_usage: Optional[float] = Field(description="内存使用量(MB)", default=None)
|
||
|
||
# 质量评估
|
||
success_rate: float = Field(description="成功率", ge=0.0, le=1.0)
|
||
quality_score: float = Field(description="质量评分", ge=0.0, le=1.0)
|
||
|
||
@validator('success_rate', 'quality_score')
|
||
def round_scores(cls, v):
|
||
"""保留3位小数"""
|
||
return round(v, 3)
|
||
|
||
|
||
class ComprehensiveAnalysisResult(BaseModel):
|
||
"""综合分析结果(组合多个分析)"""
|
||
|
||
# 基本信息
|
||
analysis_id: str = Field(description="分析ID")
|
||
input_summary: str = Field(description="输入数据摘要")
|
||
analysis_timestamp: datetime = Field(default_factory=datetime.now)
|
||
|
||
# 各项分析结果
|
||
text_analysis: Optional[TextAnalysisResult] = Field(
|
||
description="文本分析结果",
|
||
default=None
|
||
)
|
||
classification: Optional[DocumentClassificationResult] = Field(
|
||
description="分类结果",
|
||
default=None
|
||
)
|
||
data_extraction: Optional[StructuredDataExtraction] = Field(
|
||
description="数据提取结果",
|
||
default=None
|
||
)
|
||
|
||
# 综合评估
|
||
overall_quality: Literal["excellent", "good", "fair", "poor"] = Field(
|
||
description="整体质量评估"
|
||
)
|
||
confidence_level: float = Field(
|
||
description="整体置信度",
|
||
ge=0.0,
|
||
le=1.0
|
||
)
|
||
|
||
# 处理统计
|
||
total_processing_time: float = Field(description="总处理时间(秒)", ge=0.0)
|
||
components_completed: int = Field(description="完成的组件数量", ge=0)
|
||
components_failed: int = Field(description="失败的组件数量", ge=0)
|
||
|
||
recommendations: List[str] = Field(
|
||
description="改进建议",
|
||
default_factory=list
|
||
)
|
||
|
||
@validator('confidence_level')
|
||
def validate_confidence(cls, v):
|
||
"""验证并格式化置信度"""
|
||
return round(v, 3)
|
||
|
||
|
||
# 测试模型定义
|
||
def test_models():
|
||
"""测试所有模型定义"""
|
||
print("正在测试示例模型定义...")
|
||
|
||
try:
|
||
# 测试基础响应模型
|
||
basic = BasicResponse(message="测试消息")
|
||
print(f"✅ BasicResponse模型测试成功: {basic.message}")
|
||
|
||
# 测试情感分析模型
|
||
sentiment = SentimentAnalysis(
|
||
sentiment="positive",
|
||
confidence=0.95,
|
||
explanation="积极的情感表达",
|
||
keywords=["好", "满意", "推荐"]
|
||
)
|
||
print(f"✅ SentimentAnalysis模型测试成功: {sentiment.sentiment}")
|
||
|
||
# 测试关键词提取模型
|
||
keyword = KeywordExtraction(
|
||
keyword="人工智能",
|
||
frequency=5,
|
||
importance=0.8,
|
||
category="technology"
|
||
)
|
||
print(f"✅ KeywordExtraction模型测试成功: {keyword.keyword}")
|
||
|
||
# 测试文本分析结果模型
|
||
text_result = TextAnalysisResult(
|
||
text_length=150,
|
||
word_count=25,
|
||
language="zh",
|
||
summary="这是一个测试文本摘要",
|
||
sentiment=sentiment,
|
||
keywords=[keyword],
|
||
readability="high",
|
||
complexity=0.3
|
||
)
|
||
print(f"✅ TextAnalysisResult模型测试成功: {text_result.summary}")
|
||
|
||
# 测试分类结果模型
|
||
classification = DocumentClassificationResult(
|
||
primary_category="技术文档",
|
||
confidence=0.92,
|
||
all_categories=[
|
||
CategoryClassification(
|
||
category="技术文档",
|
||
confidence=0.92,
|
||
probability=0.87
|
||
)
|
||
]
|
||
)
|
||
print(f"✅ DocumentClassificationResult模型测试成功: {classification.primary_category}")
|
||
|
||
# 测试数据提取模型
|
||
extraction_item = DataExtractionItem(
|
||
field_name="标题",
|
||
field_value="SubAgent系统指南",
|
||
confidence=0.98,
|
||
source_text="# SubAgent系统指南",
|
||
extraction_method="pattern_matching"
|
||
)
|
||
|
||
data_extraction = StructuredDataExtraction(
|
||
extracted_data={"title": "SubAgent系统指南"},
|
||
extraction_items=[extraction_item],
|
||
extraction_quality="excellent",
|
||
completeness=1.0,
|
||
accuracy=0.95,
|
||
total_fields=1,
|
||
extracted_fields=1,
|
||
failed_fields=0
|
||
)
|
||
print(f"✅ StructuredDataExtraction模型测试成功: {data_extraction.extraction_quality}")
|
||
|
||
# 测试任务执行结果模型
|
||
task_result = TaskExecutionResult(
|
||
task_id="task_001",
|
||
task_type="text_analysis",
|
||
status="completed",
|
||
result_data={"status": "success"},
|
||
execution_time=2.5,
|
||
success_rate=1.0,
|
||
quality_score=0.95
|
||
)
|
||
print(f"✅ TaskExecutionResult模型测试成功: {task_result.status}")
|
||
|
||
# 测试综合分析结果模型
|
||
comprehensive = ComprehensiveAnalysisResult(
|
||
analysis_id="analysis_001",
|
||
input_summary="测试输入摘要",
|
||
text_analysis=text_result,
|
||
classification=classification,
|
||
data_extraction=data_extraction,
|
||
overall_quality="excellent",
|
||
confidence_level=0.93,
|
||
total_processing_time=5.2,
|
||
components_completed=3,
|
||
components_failed=0,
|
||
recommendations=["继续保持高质量输出"]
|
||
)
|
||
print(f"✅ ComprehensiveAnalysisResult模型测试成功: {comprehensive.overall_quality}")
|
||
|
||
# 测试JSON序列化
|
||
json_str = comprehensive.model_dump_json(indent=2)
|
||
print(f"✅ JSON序列化测试成功,长度: {len(json_str)}字符")
|
||
|
||
# 列出所有模型字段
|
||
print("\n📋 模型字段信息:")
|
||
for model_name, model_class in [
|
||
("BasicResponse", BasicResponse),
|
||
("SentimentAnalysis", SentimentAnalysis),
|
||
("TextAnalysisResult", TextAnalysisResult),
|
||
("DocumentClassificationResult", DocumentClassificationResult),
|
||
("StructuredDataExtraction", StructuredDataExtraction),
|
||
("ComprehensiveAnalysisResult", ComprehensiveAnalysisResult),
|
||
]:
|
||
fields = list(model_class.model_fields.keys())
|
||
print(f" {model_name}: {len(fields)} 个字段 - {fields[:3]}{'...' if len(fields) > 3 else ''}")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
print(f"❌ 模型测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
|
||
if __name__ == "__main__":
|
||
success = test_models()
|
||
if success:
|
||
print("\n🎉 所有示例模型测试通过!")
|
||
else:
|
||
print("\n💥 模型测试失败,请检查定义") |