377 lines
13 KiB
Python
Raw Normal View History

"""
SubAgent示例Pydantic模型定义
提供各种场景下的结构化输出模型展示SubAgent系统的灵活性
"""
from typing import List, Dict, Any, Optional, Literal, Union
from pydantic import BaseModel, Field, validator
from datetime import datetime
class BasicResponse(BaseModel):
    """Basic response model: a message, a success flag and a timestamp."""

    message: str = Field(..., description="响应消息")
    success: bool = Field(True, description="处理是否成功")
    # Filled in with datetime.now() at construction time when not supplied.
    timestamp: datetime = Field(description="响应时间", default_factory=datetime.now)
class SentimentAnalysis(BaseModel):
    """Sentiment analysis result model."""

    sentiment: Literal["positive", "negative", "neutral"] = Field(
        ...,
        description="情感倾向",
    )
    confidence: float = Field(..., ge=0.0, le=1.0, description="置信度")
    explanation: str = Field(..., description="分析说明")
    keywords: List[str] = Field(default_factory=list, description="关键词列表")

    @validator('confidence')
    def validate_confidence(cls, v):
        """Check that the confidence is within [0.0, 1.0] and round it.

        Returns:
            The confidence rounded to 3 decimal places.

        Raises:
            ValueError: If the value is outside the closed range [0.0, 1.0].
        """
        if 0.0 <= v <= 1.0:
            return round(v, 3)
        raise ValueError('置信度必须在0.0到1.0之间')
class KeywordExtraction(BaseModel):
    """A single extracted keyword."""

    keyword: str = Field(..., description="关键词")
    frequency: int = Field(..., ge=1, description="出现频次")
    importance: float = Field(..., ge=0.0, le=1.0, description="重要性评分")
    category: str = Field("general", description="关键词分类")

    class Config:
        # Encoder registered for floats: round to 3 decimals, pass None through.
        json_encoders = {
            float: lambda value: None if value is None else round(value, 3),
        }
class TextAnalysisResult(BaseModel):
    """Complete result of a text analysis."""

    # --- Basic information ---
    text_length: int = Field(..., ge=0, description="文本长度(字符数)")
    word_count: int = Field(..., ge=0, description="词汇数量")
    language: str = Field("zh", description="检测到的语言")

    # --- Analysis results ---
    summary: str = Field(..., description="文本摘要")
    sentiment: SentimentAnalysis = Field(..., description="情感分析结果")
    keywords: List[KeywordExtraction] = Field(..., description="关键词提取结果")

    # --- Quality assessment ---
    readability: Literal["high", "medium", "low"] = Field(
        ...,
        description="可读性评估",
    )
    complexity: float = Field(..., ge=0.0, le=1.0, description="复杂度评分")

    @validator('text_length')
    def validate_text_length(cls, v):
        """Reject a negative text length.

        Raises:
            ValueError: If ``v`` is below zero.
        """
        if v < 0:
            raise ValueError('文本长度不能为负数')
        return v

    @validator('keywords')
    def validate_keywords(cls, v):
        """Cap the keyword list at 20 entries.

        Lists of 20 items or fewer are returned unchanged (original order).
        Longer lists are sorted by descending ``importance`` and truncated to
        the first 20 entries.
        """
        if len(v) <= 20:
            return v
        ranked = sorted(v, key=lambda item: item.importance, reverse=True)
        return ranked[:20]
class CategoryClassification(BaseModel):
    """A single classification result entry."""

    category: str = Field(..., description="分类名称")
    confidence: float = Field(..., ge=0.0, le=1.0, description="分类置信度")
    probability: float = Field(..., ge=0.0, le=1.0, description="分类概率")

    @validator('confidence', 'probability')
    def round_float_values(cls, v):
        """Round the score to 3 decimal places."""
        rounded = round(v, 3)
        return rounded
class DocumentClassificationResult(BaseModel):
    """Document classification result."""

    primary_category: str = Field(..., description="主要分类")
    confidence: float = Field(..., ge=0.0, le=1.0, description="主分类置信度")
    all_categories: List[CategoryClassification] = Field(
        ...,
        description="所有分类结果(按置信度排序)",
    )
    features_used: List[str] = Field(default_factory=list, description="使用的特征列表")
    processing_time: Optional[float] = Field(None, description="处理时间(秒)")

    @validator('all_categories')
    def sort_categories(cls, v):
        """Return the categories ordered by descending confidence."""
        by_confidence = sorted(
            v,
            key=lambda entry: entry.confidence,
            reverse=True,
        )
        return by_confidence
class DataExtractionItem(BaseModel):
    """A single extracted data item."""

    field_name: str = Field(..., description="字段名称")
    # The value may be a string, int, float, bool or None.
    field_value: Union[str, int, float, bool, None] = Field(description="字段值")
    confidence: float = Field(..., ge=0.0, le=1.0, description="提取置信度")
    source_text: str = Field(..., description="来源文本片段")
    extraction_method: str = Field("llm", description="提取方法")
class StructuredDataExtraction(BaseModel):
    """Structured data extraction result."""

    extracted_data: Dict[str, Any] = Field(description="提取的结构化数据")
    extraction_items: List[DataExtractionItem] = Field(description="详细提取项目")

    # Quality assessment
    extraction_quality: Literal["excellent", "good", "fair", "poor"] = Field(
        description="提取质量评估"
    )
    completeness: float = Field(
        description="完整性评分",
        ge=0.0,
        le=1.0
    )
    accuracy: float = Field(
        description="准确性评分",
        ge=0.0,
        le=1.0
    )

    # Statistics. total_fields is declared before the two counters below, so
    # it is available in ``values`` when their validators run (provided it
    # passed its own validation).
    total_fields: int = Field(description="总字段数", ge=0)
    extracted_fields: int = Field(description="成功提取字段数", ge=0)
    failed_fields: int = Field(description="提取失败字段数", ge=0)

    @validator('extracted_fields')
    def validate_field_counts(cls, v, values):
        """Ensure ``extracted_fields`` does not exceed ``total_fields``.

        Args:
            v: The extracted-field count being validated.
            values: Previously validated fields of the model.

        Returns:
            The count, unchanged.

        Raises:
            ValueError: If the count is greater than ``total_fields``.
        """
        total = values.get('total_fields')
        # total_fields is missing only when it failed its own validation; that
        # error is already reported, so do not add a misleading second one.
        if total is not None and v > total:
            raise ValueError('提取字段数不能超过总字段数')
        return v

    @validator('failed_fields')
    def validate_failed_field_count(cls, v, values):
        """Ensure ``failed_fields`` does not exceed ``total_fields``.

        Kept separate from the extracted-field check so the error message
        names the field that is actually out of range.

        Args:
            v: The failed-field count being validated.
            values: Previously validated fields of the model.

        Returns:
            The count, unchanged.

        Raises:
            ValueError: If the count is greater than ``total_fields``.
        """
        total = values.get('total_fields')
        # Same reasoning as above: skip the comparison when total_fields is
        # unavailable because it was itself invalid.
        if total is not None and v > total:
            raise ValueError('提取失败字段数不能超过总字段数')
        return v
class TaskExecutionResult(BaseModel):
    """Task execution result."""

    # --- Task information ---
    task_id: str = Field(..., description="任务ID")
    task_type: str = Field(..., description="任务类型")
    status: Literal["completed", "failed", "partial"] = Field(
        ...,
        description="执行状态",
    )

    # --- Execution details ---
    result_data: Optional[Dict[str, Any]] = Field(None, description="结果数据")
    error_message: Optional[str] = Field(None, description="错误信息")
    warnings: List[str] = Field(default_factory=list, description="警告信息")

    # --- Performance metrics ---
    execution_time: float = Field(..., ge=0.0, description="执行时间(秒)")
    memory_usage: Optional[float] = Field(None, description="内存使用量MB")

    # --- Quality assessment ---
    success_rate: float = Field(..., ge=0.0, le=1.0, description="成功率")
    quality_score: float = Field(..., ge=0.0, le=1.0, description="质量评分")

    @validator('success_rate', 'quality_score')
    def round_scores(cls, v):
        """Round the score to 3 decimal places."""
        rounded = round(v, 3)
        return rounded
class ComprehensiveAnalysisResult(BaseModel):
    """Comprehensive analysis result (combines several analyses)."""

    # --- Basic information ---
    analysis_id: str = Field(..., description="分析ID")
    input_summary: str = Field(..., description="输入数据摘要")
    # Defaults to datetime.now() at construction time.
    analysis_timestamp: datetime = Field(default_factory=datetime.now)

    # --- Individual analysis results (each one is optional) ---
    text_analysis: Optional[TextAnalysisResult] = Field(None, description="文本分析结果")
    classification: Optional[DocumentClassificationResult] = Field(
        None,
        description="分类结果",
    )
    data_extraction: Optional[StructuredDataExtraction] = Field(
        None,
        description="数据提取结果",
    )

    # --- Overall assessment ---
    overall_quality: Literal["excellent", "good", "fair", "poor"] = Field(
        ...,
        description="整体质量评估",
    )
    confidence_level: float = Field(..., ge=0.0, le=1.0, description="整体置信度")

    # --- Processing statistics ---
    total_processing_time: float = Field(..., ge=0.0, description="总处理时间(秒)")
    components_completed: int = Field(..., ge=0, description="完成的组件数量")
    components_failed: int = Field(..., ge=0, description="失败的组件数量")
    recommendations: List[str] = Field(default_factory=list, description="改进建议")

    @validator('confidence_level')
    def validate_confidence(cls, v):
        """Round the overall confidence to 3 decimal places."""
        rounded = round(v, 3)
        return rounded
# Smoke test for the model definitions
def test_models():
    """Build one instance of each example model and report progress on stdout.

    Each model is constructed with sample data, the comprehensive result is
    serialized to JSON, and a short field listing is printed for a selection
    of the models.

    Returns:
        bool: True when every step completed without raising; False when any
        exception was caught (its message and traceback are printed).
    """
    print("正在测试示例模型定义...")

    try:
        # Basic response model
        basic_response = BasicResponse(message="测试消息")
        print(f"✅ BasicResponse模型测试成功: {basic_response.message}")

        # Sentiment analysis model
        sentiment_result = SentimentAnalysis(
            sentiment="positive",
            confidence=0.95,
            explanation="积极的情感表达",
            keywords=["", "满意", "推荐"],
        )
        print(f"✅ SentimentAnalysis模型测试成功: {sentiment_result.sentiment}")

        # Keyword extraction model
        keyword_item = KeywordExtraction(
            keyword="人工智能",
            frequency=5,
            importance=0.8,
            category="technology",
        )
        print(f"✅ KeywordExtraction模型测试成功: {keyword_item.keyword}")

        # Text analysis result model (reuses the two instances built above)
        text_analysis = TextAnalysisResult(
            text_length=150,
            word_count=25,
            language="zh",
            summary="这是一个测试文本摘要",
            sentiment=sentiment_result,
            keywords=[keyword_item],
            readability="high",
            complexity=0.3,
        )
        print(f"✅ TextAnalysisResult模型测试成功: {text_analysis.summary}")

        # Classification result model
        top_category = CategoryClassification(
            category="技术文档",
            confidence=0.92,
            probability=0.87,
        )
        doc_classification = DocumentClassificationResult(
            primary_category="技术文档",
            confidence=0.92,
            all_categories=[top_category],
        )
        print(
            f"✅ DocumentClassificationResult模型测试成功: "
            f"{doc_classification.primary_category}"
        )

        # Data extraction models
        title_item = DataExtractionItem(
            field_name="标题",
            field_value="SubAgent系统指南",
            confidence=0.98,
            source_text="# SubAgent系统指南",
            extraction_method="pattern_matching",
        )
        extraction_result = StructuredDataExtraction(
            extracted_data={"title": "SubAgent系统指南"},
            extraction_items=[title_item],
            extraction_quality="excellent",
            completeness=1.0,
            accuracy=0.95,
            total_fields=1,
            extracted_fields=1,
            failed_fields=0,
        )
        print(
            f"✅ StructuredDataExtraction模型测试成功: "
            f"{extraction_result.extraction_quality}"
        )

        # Task execution result model
        task_outcome = TaskExecutionResult(
            task_id="task_001",
            task_type="text_analysis",
            status="completed",
            result_data={"status": "success"},
            execution_time=2.5,
            success_rate=1.0,
            quality_score=0.95,
        )
        print(f"✅ TaskExecutionResult模型测试成功: {task_outcome.status}")

        # Comprehensive analysis result model (combines the results above)
        full_report = ComprehensiveAnalysisResult(
            analysis_id="analysis_001",
            input_summary="测试输入摘要",
            text_analysis=text_analysis,
            classification=doc_classification,
            data_extraction=extraction_result,
            overall_quality="excellent",
            confidence_level=0.93,
            total_processing_time=5.2,
            components_completed=3,
            components_failed=0,
            recommendations=["继续保持高质量输出"],
        )
        print(
            f"✅ ComprehensiveAnalysisResult模型测试成功: "
            f"{full_report.overall_quality}"
        )

        # JSON serialization
        serialized = full_report.model_dump_json(indent=2)
        print(f"✅ JSON序列化测试成功长度: {len(serialized)}字符")

        # List the field names of selected models (first three shown)
        print("\n📋 模型字段信息:")
        listed_models = (
            BasicResponse,
            SentimentAnalysis,
            TextAnalysisResult,
            DocumentClassificationResult,
            StructuredDataExtraction,
            ComprehensiveAnalysisResult,
        )
        for model_cls in listed_models:
            # The class name is the same label the listing prints.
            label = model_cls.__name__
            field_names = list(model_cls.model_fields)
            suffix = '...' if len(field_names) > 3 else ''
            print(f" {label}: {len(field_names)} 个字段 - {field_names[:3]}{suffix}")
    except Exception as exc:
        # Top-level boundary of the smoke test: report and signal failure.
        print(f"❌ 模型测试失败: {exc}")
        import traceback
        traceback.print_exc()
        return False
    else:
        return True
if __name__ == "__main__":
    # Run the smoke test and print a one-line verdict.
    success = test_models()
    verdict = "\n🎉 所有示例模型测试通过!" if success else "\n💥 模型测试失败,请检查定义"
    print(verdict)