"""
Example Pydantic model definitions for SubAgent.

Provides structured-output models for a range of scenarios (sentiment
analysis, keyword extraction, classification, data extraction, task results)
to showcase the flexibility of the SubAgent system.

The module targets the Pydantic V2 API: validators use ``field_validator``
and the self-test relies on ``model_dump_json`` / ``model_fields``.
"""

import traceback
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    ValidationInfo,
    field_validator,
    validator,  # noqa: F401 - unused here; kept so the import surface is unchanged
)

# NOTE: the class docstrings of the models below are deliberately left in
# their original language. Pydantic copies a model's docstring into the
# ``description`` of its generated JSON schema, so that text is visible at
# runtime; English explanations are given as ``#`` comments instead.


# Minimal response envelope: a message, a success flag and a creation time.
class BasicResponse(BaseModel):
    """基础响应模型"""

    message: str = Field(description="响应消息")
    success: bool = Field(description="处理是否成功", default=True)
    # Naive local time taken when the instance is created.
    timestamp: datetime = Field(default_factory=datetime.now, description="响应时间")


# Result of a sentiment analysis: label, confidence, explanation, keywords.
class SentimentAnalysis(BaseModel):
    """情感分析结果模型"""

    sentiment: Literal["positive", "negative", "neutral"] = Field(description="情感倾向")
    confidence: float = Field(description="置信度", ge=0.0, le=1.0)
    explanation: str = Field(description="分析说明")
    keywords: List[str] = Field(description="关键词列表", default_factory=list)

    @field_validator('confidence')
    @classmethod
    def validate_confidence(cls, v: float) -> float:
        """Check the confidence range and round it to 3 decimal places.

        Raises:
            ValueError: if ``v`` is outside ``[0.0, 1.0]``.
        """
        # Defensive re-check; the ``ge``/``le`` constraints on the field
        # express the same bounds.
        if not 0.0 <= v <= 1.0:
            raise ValueError('置信度必须在0.0到1.0之间')
        return round(v, 3)


# One extracted keyword with its frequency, importance score and category.
class KeywordExtraction(BaseModel):
    """关键词提取项"""

    # V2 replacement for the deprecated inner ``class Config``; same encoder.
    model_config = ConfigDict(
        json_encoders={
            float: lambda v: round(v, 3) if v is not None else None
        }
    )

    keyword: str = Field(description="关键词")
    frequency: int = Field(description="出现频次", ge=1)
    importance: float = Field(description="重要性评分", ge=0.0, le=1.0)
    category: str = Field(description="关键词分类", default="general")


# Complete text-analysis result: basic stats, analysis output, quality scores.
class TextAnalysisResult(BaseModel):
    """文本分析完整结果"""

    # Basic information
    text_length: int = Field(description="文本长度(字符数)", ge=0)
    word_count: int = Field(description="词汇数量", ge=0)
    language: str = Field(description="检测到的语言", default="zh")

    # Analysis results
    summary: str = Field(description="文本摘要")
    sentiment: SentimentAnalysis = Field(description="情感分析结果")
    keywords: List[KeywordExtraction] = Field(description="关键词提取结果")

    # Quality assessment
    readability: Literal["high", "medium", "low"] = Field(description="可读性评估")
    complexity: float = Field(description="复杂度评分", ge=0.0, le=1.0)

    @field_validator('text_length')
    @classmethod
    def validate_text_length(cls, v: int) -> int:
        """Reject negative text lengths.

        Raises:
            ValueError: if ``v`` is negative.
        """
        if v < 0:
            raise ValueError('文本长度不能为负数')
        return v

    @field_validator('keywords')
    @classmethod
    def validate_keywords(cls, v: List[KeywordExtraction]) -> List[KeywordExtraction]:
        """Cap the keyword list at the 20 most important entries.

        Lists of 20 items or fewer are returned unchanged (original order);
        longer lists are sorted by ``importance`` descending and truncated.
        """
        if len(v) > 20:
            # Keep only the 20 most important keywords.
            v = sorted(v, key=lambda x: x.importance, reverse=True)[:20]
        return v


# One candidate category with its confidence and probability.
class CategoryClassification(BaseModel):
    """分类结果项"""

    category: str = Field(description="分类名称")
    confidence: float = Field(description="分类置信度", ge=0.0, le=1.0)
    probability: float = Field(description="分类概率", ge=0.0, le=1.0)

    @field_validator('confidence', 'probability')
    @classmethod
    def round_float_values(cls, v: float) -> float:
        """Round the value to 3 decimal places."""
        return round(v, 3)


# Document classification outcome: primary category plus all candidates.
class DocumentClassificationResult(BaseModel):
    """文档分类结果"""

    primary_category: str = Field(description="主要分类")
    confidence: float = Field(description="主分类置信度", ge=0.0, le=1.0)
    all_categories: List[CategoryClassification] = Field(
        description="所有分类结果(按置信度排序)"
    )
    features_used: List[str] = Field(
        description="使用的特征列表",
        default_factory=list
    )
    processing_time: Optional[float] = Field(
        description="处理时间(秒)",
        default=None
    )

    @field_validator('all_categories')
    @classmethod
    def sort_categories(cls, v: List[CategoryClassification]) -> List[CategoryClassification]:
        """Return the categories sorted by confidence, highest first."""
        return sorted(v, key=lambda x: x.confidence, reverse=True)


# One extracted field: name, value, confidence, source snippet and method.
class DataExtractionItem(BaseModel):
    """数据提取项"""

    field_name: str = Field(description="字段名称")
    field_value: Union[str, int, float, bool, None] = Field(description="字段值")
    confidence: float = Field(description="提取置信度", ge=0.0, le=1.0)
    source_text: str = Field(description="来源文本片段")
    extraction_method: str = Field(description="提取方法", default="llm")


# Structured data extraction result with quality scores and field counts.
class StructuredDataExtraction(BaseModel):
    """结构化数据提取结果"""

    extracted_data: Dict[str, Any] = Field(description="提取的结构化数据")
    extraction_items: List[DataExtractionItem] = Field(description="详细提取项目")

    # Quality assessment
    extraction_quality: Literal["excellent", "good", "fair", "poor"] = Field(
        description="提取质量评估"
    )
    completeness: float = Field(
        description="完整性评分",
        ge=0.0,
        le=1.0
    )
    accuracy: float = Field(
        description="准确性评分",
        ge=0.0,
        le=1.0
    )

    # Statistics
    total_fields: int = Field(description="总字段数", ge=0)
    extracted_fields: int = Field(description="成功提取字段数", ge=0)
    failed_fields: int = Field(description="提取失败字段数", ge=0)

    @field_validator('extracted_fields', 'failed_fields')
    @classmethod
    def validate_field_counts(cls, v: int, info: ValidationInfo) -> int:
        """Ensure a field count does not exceed ``total_fields``.

        Args:
            v: the count being validated.
            info: validation context; ``info.data`` holds the fields that
                were declared earlier and validated successfully.

        Raises:
            ValueError: if ``v`` is greater than ``total_fields``.
        """
        total = info.data.get('total_fields')
        if total is None:
            # ``total_fields`` failed its own validation and is already
            # reported; comparing against a made-up 0 would only add a
            # misleading second error.
            return v
        if v > total:
            # Name the field that is actually out of range.
            label = '提取失败字段数' if info.field_name == 'failed_fields' else '提取字段数'
            raise ValueError(f'{label}不能超过总字段数')
        return v


# Outcome of one task run: identity, status, details, performance, quality.
class TaskExecutionResult(BaseModel):
    """任务执行结果"""

    # Task information
    task_id: str = Field(description="任务ID")
    task_type: str = Field(description="任务类型")
    status: Literal["completed", "failed", "partial"] = Field(description="执行状态")

    # Execution details
    result_data: Optional[Dict[str, Any]] = Field(description="结果数据", default=None)
    error_message: Optional[str] = Field(description="错误信息", default=None)
    warnings: List[str] = Field(description="警告信息", default_factory=list)

    # Performance metrics
    execution_time: float = Field(description="执行时间(秒)", ge=0.0)
    memory_usage: Optional[float] = Field(description="内存使用量(MB)", default=None)

    # Quality assessment
    success_rate: float = Field(description="成功率", ge=0.0, le=1.0)
    quality_score: float = Field(description="质量评分", ge=0.0, le=1.0)

    @field_validator('success_rate', 'quality_score')
    @classmethod
    def round_scores(cls, v: float) -> float:
        """Round the score to 3 decimal places."""
        return round(v, 3)


# Aggregate result that combines the optional per-component analyses.
class ComprehensiveAnalysisResult(BaseModel):
    """综合分析结果(组合多个分析)"""

    # Basic information
    analysis_id: str = Field(description="分析ID")
    input_summary: str = Field(description="输入数据摘要")
    analysis_timestamp: datetime = Field(default_factory=datetime.now)

    # Individual analysis results
    text_analysis: Optional[TextAnalysisResult] = Field(
        description="文本分析结果",
        default=None
    )
    classification: Optional[DocumentClassificationResult] = Field(
        description="分类结果",
        default=None
    )
    data_extraction: Optional[StructuredDataExtraction] = Field(
        description="数据提取结果",
        default=None
    )

    # Overall assessment
    overall_quality: Literal["excellent", "good", "fair", "poor"] = Field(
        description="整体质量评估"
    )
    confidence_level: float = Field(
        description="整体置信度",
        ge=0.0,
        le=1.0
    )

    # Processing statistics
    total_processing_time: float = Field(description="总处理时间(秒)", ge=0.0)
    components_completed: int = Field(description="完成的组件数量", ge=0)
    components_failed: int = Field(description="失败的组件数量", ge=0)

    recommendations: List[str] = Field(
        description="改进建议",
        default_factory=list
    )

    @field_validator('confidence_level')
    @classmethod
    def validate_confidence(cls, v: float) -> float:
        """Round the overall confidence to 3 decimal places."""
        return round(v, 3)


# Self-test for the model definitions
def test_models():
    """Instantiate every example model once and report the outcome.

    Builds one instance of each model, serializes the aggregate result to
    JSON and prints a short field listing for several models.

    Returns:
        bool: ``True`` if every step succeeded; ``False`` if any exception
        was raised (the exception and its traceback are printed).
    """
    print("正在测试示例模型定义...")

    try:
        # Basic response model
        basic = BasicResponse(message="测试消息")
        print(f"✅ BasicResponse模型测试成功: {basic.message}")

        # Sentiment analysis model
        sentiment = SentimentAnalysis(
            sentiment="positive",
            confidence=0.95,
            explanation="积极的情感表达",
            keywords=["好", "满意", "推荐"]
        )
        print(f"✅ SentimentAnalysis模型测试成功: {sentiment.sentiment}")

        # Keyword extraction model
        keyword = KeywordExtraction(
            keyword="人工智能",
            frequency=5,
            importance=0.8,
            category="technology"
        )
        print(f"✅ KeywordExtraction模型测试成功: {keyword.keyword}")

        # Text analysis result model
        text_result = TextAnalysisResult(
            text_length=150,
            word_count=25,
            language="zh",
            summary="这是一个测试文本摘要",
            sentiment=sentiment,
            keywords=[keyword],
            readability="high",
            complexity=0.3
        )
        print(f"✅ TextAnalysisResult模型测试成功: {text_result.summary}")

        # Classification result model
        classification = DocumentClassificationResult(
            primary_category="技术文档",
            confidence=0.92,
            all_categories=[
                CategoryClassification(
                    category="技术文档",
                    confidence=0.92,
                    probability=0.87
                )
            ]
        )
        print(f"✅ DocumentClassificationResult模型测试成功: {classification.primary_category}")

        # Data extraction models
        extraction_item = DataExtractionItem(
            field_name="标题",
            field_value="SubAgent系统指南",
            confidence=0.98,
            source_text="# SubAgent系统指南",
            extraction_method="pattern_matching"
        )

        data_extraction = StructuredDataExtraction(
            extracted_data={"title": "SubAgent系统指南"},
            extraction_items=[extraction_item],
            extraction_quality="excellent",
            completeness=1.0,
            accuracy=0.95,
            total_fields=1,
            extracted_fields=1,
            failed_fields=0
        )
        print(f"✅ StructuredDataExtraction模型测试成功: {data_extraction.extraction_quality}")

        # Task execution result model
        task_result = TaskExecutionResult(
            task_id="task_001",
            task_type="text_analysis",
            status="completed",
            result_data={"status": "success"},
            execution_time=2.5,
            success_rate=1.0,
            quality_score=0.95
        )
        print(f"✅ TaskExecutionResult模型测试成功: {task_result.status}")

        # Comprehensive analysis result model
        comprehensive = ComprehensiveAnalysisResult(
            analysis_id="analysis_001",
            input_summary="测试输入摘要",
            text_analysis=text_result,
            classification=classification,
            data_extraction=data_extraction,
            overall_quality="excellent",
            confidence_level=0.93,
            total_processing_time=5.2,
            components_completed=3,
            components_failed=0,
            recommendations=["继续保持高质量输出"]
        )
        print(f"✅ ComprehensiveAnalysisResult模型测试成功: {comprehensive.overall_quality}")

        # JSON serialization
        json_str = comprehensive.model_dump_json(indent=2)
        print(f"✅ JSON序列化测试成功,长度: {len(json_str)}字符")

        # List the fields of selected models
        print("\n📋 模型字段信息:")
        for model_name, model_class in [
            ("BasicResponse", BasicResponse),
            ("SentimentAnalysis", SentimentAnalysis),
            ("TextAnalysisResult", TextAnalysisResult),
            ("DocumentClassificationResult", DocumentClassificationResult),
            ("StructuredDataExtraction", StructuredDataExtraction),
            ("ComprehensiveAnalysisResult", ComprehensiveAnalysisResult),
        ]:
            fields = list(model_class.model_fields.keys())
            print(f" {model_name}: {len(fields)} 个字段 - {fields[:3]}{'...' if len(fields) > 3 else ''}")

        return True

    except Exception as e:
        # Top-level boundary of the demo: report the failure instead of
        # crashing, but keep the full traceback visible.
        print(f"❌ 模型测试失败: {e}")
        traceback.print_exc()
        return False


if __name__ == "__main__":
    success = test_models()
    if success:
        print("\n🎉 所有示例模型测试通过!")
    else:
        print("\n💥 模型测试失败,请检查定义")