78 lines
2.5 KiB
Python
78 lines
2.5 KiB
Python
|
|
"""
|
|||
|
|
数字提取Pydantic模型定义
|
|||
|
|
|
|||
|
|
用于结构化解析从文本中提取的数字和相关解释信息
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from typing import List, Union

from pydantic import BaseModel, Field, field_serializer, model_validator
|
|||
|
|
|
|||
|
|
|
|||
|
|
class NumberExtraction(BaseModel):
    """A single number extracted from text.

    Attributes:
        number: The extracted numeric value (int or float).
        explanation: Explanation of what the number means.
        context: The text snippet in which the number appeared.
        unit: The number's unit, if any; defaults to an empty string.
    """

    number: Union[int, float] = Field(description="提取的数字值")
    explanation: str = Field(description="对该数字的解释说明")
    context: str = Field(description="数字出现的上下文片段")
    unit: str = Field(default="", description="数字的单位(如果有)")

    @field_serializer("number", when_used="json")
    def serialize_number(self, value: Union[int, float]) -> Union[int, float]:
        """Round float values to 6 decimal places when serializing to JSON.

        This replaces the former class-based ``Config.json_encoders`` entry,
        which is deprecated in Pydantic v2. As before, only JSON output is
        affected (``model_dump()`` in Python mode returns the raw value).

        Args:
            value: The validated ``number`` field value.

        Returns:
            ``round(value, 6)`` for floats; integers unchanged.
        """
        # Only floats were targeted by the old float encoder; ints keep their type.
        if isinstance(value, float):
            return round(value, 6)
        return value
|
|||
|
|
|
|||
|
|
|
|||
|
|
class NumberExtractionResult(BaseModel):
    """Complete result of extracting numbers from a text.

    Attributes:
        extractions: The list of extracted number items.
        summary: A summary of the numbers found across the whole text.
        total_count: Number of extracted items. After validation it is always
            made equal to ``len(extractions)``, overriding the supplied value.
    """

    extractions: List[NumberExtraction] = Field(description="提取的数字项列表")
    summary: str = Field(description="对整个文本中数字的总结")
    total_count: int = Field(description="提取的数字总数", ge=0)

    @model_validator(mode="after")
    def sync_total_count(self) -> "NumberExtractionResult":
        """Ensure ``total_count`` matches the length of ``extractions``.

        The previous ``__post_init__`` hook is a dataclass convention that
        ``pydantic.BaseModel`` never invokes, so the correction never ran.
        An "after" model validator runs on every construction/validation.

        Returns:
            The (possibly corrected) model instance, as Pydantic requires.
        """
        actual = len(self.extractions)
        if self.total_count != actual:
            # Trust the actual list over a possibly inconsistent reported count.
            self.total_count = actual
        return self
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 测试模型定义
|
|||
|
|
def test_models():
    """Smoke-test the number-extraction Pydantic models.

    Builds a sample ``NumberExtraction`` and a ``NumberExtractionResult``
    wrapping it, serializes the result to JSON, and prints the field names of
    both models, printing a status line after each step.

    Returns:
        True when every step succeeds; False if any step raises (the exception
        message is printed).
    """
    print("正在测试数字提取模型...")

    try:
        # A single extracted item.
        item = NumberExtraction(
            number=95.2,
            explanation="模型准确率",
            context="模型在测试集上达到了95.2%的准确率",
            unit="%",
        )
        print(f"✅ NumberExtraction模型测试成功: {item}")

        # The full result containing that one item.
        full_result = NumberExtractionResult(
            extractions=[item],
            summary="发现1个准确率数值",
            total_count=1,
        )
        print("✅ NumberExtractionResult模型测试成功")

        # JSON serialization of the full result.
        serialized = full_result.model_dump_json(indent=2)
        print(f"✅ JSON序列化测试成功,长度: {len(serialized)}字符")

        # Declared field names of each model, in definition order.
        for model_cls in (NumberExtraction, NumberExtractionResult):
            print(f"✅ {model_cls.__name__}字段: {list(model_cls.model_fields)}")
    except Exception as exc:
        print(f"❌ 模型测试失败: {exc}")
        return False

    return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # test_models() reports failures by printing and returning False rather than
    # raising, so map its result to the process exit status; otherwise a failed
    # smoke test would still exit with status 0.
    raise SystemExit(0 if test_models() else 1)
|