diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 300031de8b..4648ce9e01 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -23,6 +23,7 @@ overload, ) +import annotated_types from pydantic import BaseModel, EmailStr from pydantic.fields import FieldInfo as PydanticFieldInfo from sqlalchemy import ( @@ -247,10 +248,10 @@ def Field( exclude: Set[int | str] | Mapping[int | str, Any] | Any = None, include: Set[int | str] | Mapping[int | str, Any] | Any = None, const: bool | None = None, - gt: float | None = None, - ge: float | None = None, - lt: float | None = None, - le: float | None = None, + gt: annotated_types.SupportsGt | None = None, + ge: annotated_types.SupportsGe | None = None, + lt: annotated_types.SupportsLt | None = None, + le: annotated_types.SupportsLe | None = None, multiple_of: float | None = None, max_digits: int | None = None, decimal_places: int | None = None, @@ -290,10 +291,10 @@ def Field( exclude: Set[int | str] | Mapping[int | str, Any] | Any = None, include: Set[int | str] | Mapping[int | str, Any] | Any = None, const: bool | None = None, - gt: float | None = None, - ge: float | None = None, - lt: float | None = None, - le: float | None = None, + gt: annotated_types.SupportsGt | None = None, + ge: annotated_types.SupportsGe | None = None, + lt: annotated_types.SupportsLt | None = None, + le: annotated_types.SupportsLe | None = None, multiple_of: float | None = None, max_digits: int | None = None, decimal_places: int | None = None, @@ -342,10 +343,10 @@ def Field( exclude: Set[int | str] | Mapping[int | str, Any] | Any = None, include: Set[int | str] | Mapping[int | str, Any] | Any = None, const: bool | None = None, - gt: float | None = None, - ge: float | None = None, - lt: float | None = None, - le: float | None = None, + gt: annotated_types.SupportsGt | None = None, + ge: annotated_types.SupportsGe | None = None, + lt: annotated_types.SupportsLt | None = None, + le: annotated_types.SupportsLe | None = None, multiple_of: float | None = None, max_digits: int | None = None, decimal_places: int | None = None, @@ -375,10 +376,10 @@ def Field( exclude: Set[int | str] | Mapping[int | str, Any] | Any = None, include: Set[int | str] | Mapping[int | str, Any] | Any = None, const: bool | None = None, - gt: float | None = None, - ge: float | None = None, - lt: float | None = None, - le: float | None = None, + gt: annotated_types.SupportsGt | None = None, + ge: annotated_types.SupportsGe | None = None, + lt: annotated_types.SupportsLt | None = None, + le: annotated_types.SupportsLe | None = None, multiple_of: float | None = None, max_digits: int | None = None, decimal_places: int | None = None, diff --git a/tests/test_pydantic/test_field.py b/tests/test_pydantic/test_field.py index 11f4150d98..43d47c125b 100644 --- a/tests/test_pydantic/test_field.py +++ b/tests/test_pydantic/test_field.py @@ -54,3 +54,83 @@ class Model(SQLModel): instance = Model(id=123, foo="bar") assert "foo=" not in repr(instance) + + +def test_gt(): + class Model(SQLModel): + int_value: int = Field(gt=10) + tuple_value: tuple[int, int] = Field(gt=(1, 2)) + + Model(int_value=11, tuple_value=(1, 3)) + + with pytest.raises(ValidationError) as exc_info: + Model(int_value=10, tuple_value=(1, 3)) + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["type"] == "greater_than" + assert exc_info.value.errors()[0]["loc"] == ("int_value",) + + with pytest.raises(ValidationError) as exc_info_2: + Model(int_value=11, tuple_value=(1, 2)) + assert len(exc_info_2.value.errors()) == 1 + assert exc_info_2.value.errors()[0]["type"] == "greater_than" + assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",) + + +def test_ge(): + class Model(SQLModel): + int_value: int = Field(ge=10) + tuple_value: tuple[int, int] = Field(ge=(1, 2)) + + Model(int_value=10, tuple_value=(1, 2)) + + with pytest.raises(ValidationError) as exc_info: + Model(int_value=9, tuple_value=(1, 2)) + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["type"] == "greater_than_equal" + assert exc_info.value.errors()[0]["loc"] == ("int_value",) + + with pytest.raises(ValidationError) as exc_info_2: + Model(int_value=10, tuple_value=(1, 1)) + assert len(exc_info_2.value.errors()) == 1 + assert exc_info_2.value.errors()[0]["type"] == "greater_than_equal" + assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",) + + +def test_lt(): + class Model(SQLModel): + int_value: int = Field(lt=10) + tuple_value: tuple[int, int] = Field(lt=(1, 2)) + + Model(int_value=9, tuple_value=(1, 1)) + + with pytest.raises(ValidationError) as exc_info: + Model(int_value=10, tuple_value=(1, 1)) + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["type"] == "less_than" + assert exc_info.value.errors()[0]["loc"] == ("int_value",) + + with pytest.raises(ValidationError) as exc_info_2: + Model(int_value=9, tuple_value=(1, 2)) + assert len(exc_info_2.value.errors()) == 1 + assert exc_info_2.value.errors()[0]["type"] == "less_than" + assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",) + + +def test_le(): + class Model(SQLModel): + int_value: int = Field(le=10) + tuple_value: tuple[int, int] = Field(le=(1, 2)) + + Model(int_value=10, tuple_value=(1, 2)) + + with pytest.raises(ValidationError) as exc_info: + Model(int_value=11, tuple_value=(1, 2)) + assert len(exc_info.value.errors()) == 1 + assert exc_info.value.errors()[0]["type"] == "less_than_equal" + assert exc_info.value.errors()[0]["loc"] == ("int_value",) + + with pytest.raises(ValidationError) as exc_info_2: + Model(int_value=10, tuple_value=(1, 3)) + assert len(exc_info_2.value.errors()) == 1 + assert exc_info_2.value.errors()[0]["type"] == "less_than_equal" + assert exc_info_2.value.errors()[0]["loc"] == ("tuple_value",)