MedResearcher/number_extraction_models.py

78 lines
2.5 KiB
Python
Raw Normal View History

"""
数字提取Pydantic模型定义
用于结构化解析从文本中提取的数字和相关解释信息
"""
from typing import List, Union

from pydantic import BaseModel, ConfigDict, Field
class NumberExtraction(BaseModel):
    """A single number extracted from text.

    Attributes:
        number: The extracted numeric value (int or float).
        explanation: Explanation of what the number means.
        context: Snippet of the text in which the number appears.
        unit: Unit of the number, if any; defaults to an empty string.
    """

    number: Union[int, float] = Field(description="提取的数字值")
    explanation: str = Field(description="对该数字的解释说明")
    context: str = Field(description="数字出现的上下文片段")
    unit: str = Field(default="", description="数字的单位(如果有)")

    # Pydantic v2 configuration. The nested v1-style `class Config` is
    # deprecated in v2 (and this module already relies on v2 APIs such as
    # `model_dump_json` / `model_fields`). The encoder registered for `float`
    # rounds float values to 6 decimal places during JSON serialization.
    # It only ever receives floats, so no `None` check is needed.
    model_config = ConfigDict(
        json_encoders={float: lambda v: round(v, 6)},
    )
class NumberExtractionResult(BaseModel):
    """Complete result of a number-extraction pass.

    Attributes:
        extractions: The extracted number items.
        summary: Summary of the numbers found in the whole text.
        total_count: Number of extracted items. After validation it is
            forced to equal ``len(extractions)``.
    """

    extractions: List[NumberExtraction] = Field(description="提取的数字项列表")
    summary: str = Field(description="对整个文本中数字的总结")
    total_count: int = Field(description="提取的数字总数", ge=0)

    def model_post_init(self, __context: object) -> None:
        """Reconcile ``total_count`` with ``len(extractions)`` after validation.

        Pydantic v2 calls ``model_post_init`` once field validation has
        finished. A ``__post_init__`` method is a dataclass convention that
        ``BaseModel`` never invokes, so the hook has to live here to run.
        """
        self._sync_total_count()

    def __post_init__(self):
        """Backward-compatible alias for the ``total_count`` reconciliation.

        Pydantic does not call this automatically; it is kept only so that any
        explicit callers continue to work.
        """
        self._sync_total_count()

    def _sync_total_count(self) -> None:
        """Overwrite ``total_count`` when it disagrees with ``extractions``."""
        actual_count = len(self.extractions)
        if self.total_count != actual_count:
            # Write directly to the instance, bypassing BaseModel.__setattr__
            # (and therefore any assignment validation / frozen checks).
            object.__setattr__(self, "total_count", actual_count)
# Smoke test for the model definitions.
def test_models():
    """Build sample models, serialize them, and print each step's outcome.

    Returns:
        True when every step succeeds. False when any exception is raised;
        in that case the exception message is printed.
    """
    print("正在测试数字提取模型...")
    try:
        # A single extracted-number item.
        sample_item = NumberExtraction(
            number=95.2,
            explanation="模型准确率",
            context="模型在测试集上达到了95.2%的准确率",
            unit="%",
        )
        print(f"✅ NumberExtraction模型测试成功: {sample_item}")

        # A full result wrapping that item.
        sample_result = NumberExtractionResult(
            extractions=[sample_item],
            summary="发现1个准确率数值",
            total_count=1,
        )
        print("✅ NumberExtractionResult模型测试成功")

        # JSON round-trip: only the serialized length is reported.
        serialized = sample_result.model_dump_json(indent=2)
        print(f"✅ JSON序列化测试成功长度: {len(serialized)}字符")

        # List the declared field names of each model class.
        for model_cls in (NumberExtraction, NumberExtractionResult):
            field_names = list(model_cls.model_fields.keys())
            print(f"✅ {model_cls.__name__}字段: {field_names}")
    except Exception as exc:
        print(f"❌ 模型测试失败: {exc}")
        return False
    return True
if __name__ == "__main__":
    # Turn the smoke-test outcome into the process exit status so that shell
    # scripts / CI can detect a failure (non-zero exit when it returns False).
    import sys

    sys.exit(0 if test_models() else 1)