"""
Example Pydantic model definitions for the SubAgent system.

Provides structured-output models for a variety of scenarios, demonstrating
the flexibility of the SubAgent system.
"""

from typing import List, Dict, Any, Optional, Literal, Union
from pydantic import BaseModel, Field, field_validator, field_serializer, ValidationInfo
from datetime import datetime


class BasicResponse(BaseModel):
    """Basic response model."""

    message: str = Field(description="Response message")
    success: bool = Field(description="Whether processing succeeded", default=True)
    timestamp: datetime = Field(default_factory=datetime.now, description="Response time")


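# Illustrative usage sketch (not exercised by test_models below): BasicResponse
# can be round-tripped through JSON with the standard Pydantic v2 API. The
# payload here is invented for demonstration purposes.
def example_basic_response_roundtrip() -> BasicResponse:
    response = BasicResponse(message="ok")
    json_payload = response.model_dump_json()               # serialize, timestamp included
    return BasicResponse.model_validate_json(json_payload)  # parse it back

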
class SentimentAnalysis(BaseModel):
    """Sentiment analysis result model."""

    sentiment: Literal["positive", "negative", "neutral"] = Field(description="Sentiment polarity")
    confidence: float = Field(description="Confidence score", ge=0.0, le=1.0)
    explanation: str = Field(description="Explanation of the analysis")
    keywords: List[str] = Field(description="List of keywords", default_factory=list)

    @field_validator('confidence')
    @classmethod
    def validate_confidence(cls, v):
        """Validate the confidence range and round to three decimals."""
        if not 0.0 <= v <= 1.0:
            raise ValueError('confidence must be between 0.0 and 1.0')
        return round(v, 3)


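# Illustrative sketch of how a caller might validate raw LLM output against
# SentimentAnalysis; the JSON string is a hypothetical model reply. Note that
# validate_confidence rounds the confidence to three decimals.
def example_parse_llm_sentiment() -> SentimentAnalysis:
    raw_reply = '{"sentiment": "positive", "confidence": 0.91234, "explanation": "upbeat tone"}'
    result = SentimentAnalysis.model_validate_json(raw_reply)
    assert result.confidence == 0.912  # rounded by the validator
    return result

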
class KeywordExtraction(BaseModel):
    """A single extracted keyword."""

    keyword: str = Field(description="Keyword")
    frequency: int = Field(description="Occurrence count", ge=1)
    importance: float = Field(description="Importance score", ge=0.0, le=1.0)
    category: str = Field(description="Keyword category", default="general")

    @field_serializer('importance')
    def round_importance(self, v: float) -> float:
        """Round the importance score to three decimals when serializing."""
        return round(v, 3)


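# Illustrative sketch: keyword items are plain Pydantic models, so callers can
# rank them with regular Python. The sample values are invented for this example.
def example_rank_keywords() -> List[KeywordExtraction]:
    items = [
        KeywordExtraction(keyword="agent", frequency=3, importance=0.42),
        KeywordExtraction(keyword="pydantic", frequency=7, importance=0.91),
    ]
    return sorted(items, key=lambda item: item.importance, reverse=True)

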
class TextAnalysisResult(BaseModel):
    """Complete text analysis result."""

    # Basic information
    text_length: int = Field(description="Text length in characters", ge=0)
    word_count: int = Field(description="Word count", ge=0)
    language: str = Field(description="Detected language", default="zh")

    # Analysis results
    summary: str = Field(description="Text summary")
    sentiment: SentimentAnalysis = Field(description="Sentiment analysis result")
    keywords: List[KeywordExtraction] = Field(description="Extracted keywords")

    # Quality assessment
    readability: Literal["high", "medium", "low"] = Field(description="Readability assessment")
    complexity: float = Field(description="Complexity score", ge=0.0, le=1.0)

    @field_validator('text_length')
    @classmethod
    def validate_text_length(cls, v):
        """Validate the text length."""
        if v < 0:
            raise ValueError('text_length cannot be negative')
        return v

    @field_validator('keywords')
    @classmethod
    def validate_keywords(cls, v):
        """Cap the keyword list at the 20 most important entries."""
        if len(v) > 20:
            # Keep only the 20 most important keywords
            v = sorted(v, key=lambda x: x.importance, reverse=True)[:20]
        return v


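# Illustrative sketch: nested payloads (e.g. decoded from an LLM JSON reply)
# are coerced into the nested SentimentAnalysis / KeywordExtraction models by
# model_validate. The dictionary below is invented sample data.
def example_text_analysis_from_dict() -> TextAnalysisResult:
    payload = {
        "text_length": 42,
        "word_count": 9,
        "summary": "short sample summary",
        "sentiment": {"sentiment": "neutral", "confidence": 0.5, "explanation": "mixed"},
        "keywords": [{"keyword": "sample", "frequency": 2, "importance": 0.6}],
        "readability": "medium",
        "complexity": 0.4,
    }
    return TextAnalysisResult.model_validate(payload)

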
class CategoryClassification(BaseModel):
    """A single classification result."""

    category: str = Field(description="Category name")
    confidence: float = Field(description="Classification confidence", ge=0.0, le=1.0)
    probability: float = Field(description="Classification probability", ge=0.0, le=1.0)

    @field_validator('confidence', 'probability')
    @classmethod
    def round_float_values(cls, v):
        """Round to three decimals."""
        return round(v, 3)


class DocumentClassificationResult(BaseModel):
    """Document classification result."""

    primary_category: str = Field(description="Primary category")
    confidence: float = Field(description="Confidence of the primary category", ge=0.0, le=1.0)

    all_categories: List[CategoryClassification] = Field(
        description="All classification results (sorted by confidence)"
    )

    features_used: List[str] = Field(
        description="Features used for classification",
        default_factory=list
    )

    processing_time: Optional[float] = Field(
        description="Processing time in seconds",
        default=None
    )

    @field_validator('all_categories')
    @classmethod
    def sort_categories(cls, v):
        """Sort categories by confidence in descending order."""
        return sorted(v, key=lambda x: x.confidence, reverse=True)


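# Illustrative sketch: the sort_categories validator reorders all_categories by
# confidence no matter how the caller supplies them. Sample values are invented.
def example_sorted_classification() -> DocumentClassificationResult:
    result = DocumentClassificationResult(
        primary_category="news",
        confidence=0.8,
        all_categories=[
            CategoryClassification(category="blog", confidence=0.2, probability=0.2),
            CategoryClassification(category="news", confidence=0.8, probability=0.7),
        ],
    )
    assert result.all_categories[0].category == "news"  # highest confidence first
    return result

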
class DataExtractionItem(BaseModel):
    """A single extracted data field."""

    field_name: str = Field(description="Field name")
    field_value: Union[str, int, float, bool, None] = Field(description="Field value")
    confidence: float = Field(description="Extraction confidence", ge=0.0, le=1.0)
    source_text: str = Field(description="Source text snippet")
    extraction_method: str = Field(description="Extraction method", default="llm")


class StructuredDataExtraction(BaseModel):
    """Structured data extraction result."""

    extracted_data: Dict[str, Any] = Field(description="Extracted structured data")
    extraction_items: List[DataExtractionItem] = Field(description="Detailed extraction items")

    # Quality assessment
    extraction_quality: Literal["excellent", "good", "fair", "poor"] = Field(
        description="Extraction quality assessment"
    )
    completeness: float = Field(
        description="Completeness score",
        ge=0.0,
        le=1.0
    )
    accuracy: float = Field(
        description="Accuracy score",
        ge=0.0,
        le=1.0
    )

    # Statistics
    total_fields: int = Field(description="Total number of fields", ge=0)
    extracted_fields: int = Field(description="Number of successfully extracted fields", ge=0)
    failed_fields: int = Field(description="Number of fields that failed extraction", ge=0)

    @field_validator('extracted_fields', 'failed_fields')
    @classmethod
    def validate_field_counts(cls, v, info: ValidationInfo):
        """Ensure field counts do not exceed the total number of fields."""
        total = info.data.get('total_fields', 0)
        if v > total:
            raise ValueError('field count cannot exceed total_fields')
        return v


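# Illustrative sketch: validate_field_counts rejects inconsistent counts, so a
# payload claiming more extracted fields than total_fields raises a
# ValidationError. All values below are invented.
def example_inconsistent_extraction_counts() -> bool:
    from pydantic import ValidationError

    try:
        StructuredDataExtraction(
            extracted_data={},
            extraction_items=[],
            extraction_quality="fair",
            completeness=0.5,
            accuracy=0.5,
            total_fields=2,
            extracted_fields=3,  # exceeds total_fields -> rejected
            failed_fields=0,
        )
    except ValidationError:
        return True
    return False

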
class TaskExecutionResult(BaseModel):
    """Task execution result."""

    # Task information
    task_id: str = Field(description="Task ID")
    task_type: str = Field(description="Task type")
    status: Literal["completed", "failed", "partial"] = Field(description="Execution status")

    # Execution details
    result_data: Optional[Dict[str, Any]] = Field(description="Result data", default=None)
    error_message: Optional[str] = Field(description="Error message", default=None)
    warnings: List[str] = Field(description="Warning messages", default_factory=list)

    # Performance metrics
    execution_time: float = Field(description="Execution time in seconds", ge=0.0)
    memory_usage: Optional[float] = Field(description="Memory usage in MB", default=None)

    # Quality assessment
    success_rate: float = Field(description="Success rate", ge=0.0, le=1.0)
    quality_score: float = Field(description="Quality score", ge=0.0, le=1.0)

    @field_validator('success_rate', 'quality_score')
    @classmethod
    def round_scores(cls, v):
        """Round to three decimals."""
        return round(v, 3)


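# Illustrative sketch: a failed task would typically leave result_data empty and
# carry an error_message instead; the values below are invented for demonstration.
def example_failed_task_result() -> TaskExecutionResult:
    return TaskExecutionResult(
        task_id="task_demo_failed",
        task_type="data_extraction",
        status="failed",
        error_message="upstream model returned malformed JSON",
        execution_time=0.8,
        success_rate=0.0,
        quality_score=0.0,
    )

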
class ComprehensiveAnalysisResult(BaseModel):
    """Comprehensive analysis result (combining multiple analyses)."""

    # Basic information
    analysis_id: str = Field(description="Analysis ID")
    input_summary: str = Field(description="Summary of the input data")
    analysis_timestamp: datetime = Field(default_factory=datetime.now)

    # Individual analysis results
    text_analysis: Optional[TextAnalysisResult] = Field(
        description="Text analysis result",
        default=None
    )
    classification: Optional[DocumentClassificationResult] = Field(
        description="Classification result",
        default=None
    )
    data_extraction: Optional[StructuredDataExtraction] = Field(
        description="Data extraction result",
        default=None
    )

    # Overall assessment
    overall_quality: Literal["excellent", "good", "fair", "poor"] = Field(
        description="Overall quality assessment"
    )
    confidence_level: float = Field(
        description="Overall confidence level",
        ge=0.0,
        le=1.0
    )

    # Processing statistics
    total_processing_time: float = Field(description="Total processing time in seconds", ge=0.0)
    components_completed: int = Field(description="Number of completed components", ge=0)
    components_failed: int = Field(description="Number of failed components", ge=0)

    recommendations: List[str] = Field(
        description="Improvement recommendations",
        default_factory=list
    )

    @field_validator('confidence_level')
    @classmethod
    def validate_confidence(cls, v):
        """Validate and format the confidence level."""
        return round(v, 3)


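# Illustrative sketch: the JSON Schema of any of these models can be handed to a
# structured-output API or embedded in a prompt so that a SubAgent's reply can be
# validated against it. How the schema is consumed is up to the caller.
def example_comprehensive_schema() -> Dict[str, Any]:
    return ComprehensiveAnalysisResult.model_json_schema()

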
# Test the model definitions
def test_models():
    """Test all model definitions."""
    print("Testing the example model definitions...")

    try:
        # Test the basic response model
        basic = BasicResponse(message="测试消息")
        print(f"✅ BasicResponse test passed: {basic.message}")

        # Test the sentiment analysis model
        sentiment = SentimentAnalysis(
            sentiment="positive",
            confidence=0.95,
            explanation="积极的情感表达",
            keywords=["好", "满意", "推荐"]
        )
        print(f"✅ SentimentAnalysis test passed: {sentiment.sentiment}")

        # Test the keyword extraction model
        keyword = KeywordExtraction(
            keyword="人工智能",
            frequency=5,
            importance=0.8,
            category="technology"
        )
        print(f"✅ KeywordExtraction test passed: {keyword.keyword}")

        # Test the text analysis result model
        text_result = TextAnalysisResult(
            text_length=150,
            word_count=25,
            language="zh",
            summary="这是一个测试文本摘要",
            sentiment=sentiment,
            keywords=[keyword],
            readability="high",
            complexity=0.3
        )
        print(f"✅ TextAnalysisResult test passed: {text_result.summary}")

        # Test the classification result model
        classification = DocumentClassificationResult(
            primary_category="技术文档",
            confidence=0.92,
            all_categories=[
                CategoryClassification(
                    category="技术文档",
                    confidence=0.92,
                    probability=0.87
                )
            ]
        )
        print(f"✅ DocumentClassificationResult test passed: {classification.primary_category}")

        # Test the data extraction models
        extraction_item = DataExtractionItem(
            field_name="标题",
            field_value="SubAgent系统指南",
            confidence=0.98,
            source_text="# SubAgent系统指南",
            extraction_method="pattern_matching"
        )

        data_extraction = StructuredDataExtraction(
            extracted_data={"title": "SubAgent系统指南"},
            extraction_items=[extraction_item],
            extraction_quality="excellent",
            completeness=1.0,
            accuracy=0.95,
            total_fields=1,
            extracted_fields=1,
            failed_fields=0
        )
        print(f"✅ StructuredDataExtraction test passed: {data_extraction.extraction_quality}")

        # Test the task execution result model
        task_result = TaskExecutionResult(
            task_id="task_001",
            task_type="text_analysis",
            status="completed",
            result_data={"status": "success"},
            execution_time=2.5,
            success_rate=1.0,
            quality_score=0.95
        )
        print(f"✅ TaskExecutionResult test passed: {task_result.status}")

        # Test the comprehensive analysis result model
        comprehensive = ComprehensiveAnalysisResult(
            analysis_id="analysis_001",
            input_summary="测试输入摘要",
            text_analysis=text_result,
            classification=classification,
            data_extraction=data_extraction,
            overall_quality="excellent",
            confidence_level=0.93,
            total_processing_time=5.2,
            components_completed=3,
            components_failed=0,
            recommendations=["继续保持高质量输出"]
        )
        print(f"✅ ComprehensiveAnalysisResult test passed: {comprehensive.overall_quality}")

        # Test JSON serialization
        json_str = comprehensive.model_dump_json(indent=2)
        print(f"✅ JSON serialization test passed, length: {len(json_str)} characters")

        # List the fields of every model
        print("\n📋 Model field summary:")
        for model_name, model_class in [
            ("BasicResponse", BasicResponse),
            ("SentimentAnalysis", SentimentAnalysis),
            ("TextAnalysisResult", TextAnalysisResult),
            ("DocumentClassificationResult", DocumentClassificationResult),
            ("StructuredDataExtraction", StructuredDataExtraction),
            ("ComprehensiveAnalysisResult", ComprehensiveAnalysisResult),
        ]:
            fields = list(model_class.model_fields.keys())
            print(f"  {model_name}: {len(fields)} fields - {fields[:3]}{'...' if len(fields) > 3 else ''}")

        return True

    except Exception as e:
        print(f"❌ Model test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    success = test_models()
    if success:
        print("\n🎉 All example model tests passed!")
    else:
        print("\n💥 Model tests failed, please check the definitions")