diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index f2b989e2fa..64fbc51d39 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -6,7 +6,7 @@ overload, ) -from sqlalchemy import util +from sqlalchemy import TextClause, util from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams from sqlalchemy.engine.result import Result, ScalarResult, TupleResult @@ -66,12 +66,25 @@ async def exec( _add_event: Any | None = None, ) -> CursorResult[Any]: ... + @overload + async def exec( + self, + statement: TextClause, + *, + params: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: dict[str, Any] | None = None, + _parent_execute_state: Any | None = None, + _add_event: Any | None = None, + ) -> CursorResult[Any]: ... + async def exec( self, statement: Select[_TSelectParam] | SelectOfScalar[_TSelectParam] | Executable[_TSelectParam] - | UpdateBase, + | UpdateBase + | TextClause, *, params: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 5c721ae0da..83a3495d79 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -5,7 +5,7 @@ overload, ) -from sqlalchemy import util +from sqlalchemy import TextClause, util from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams from sqlalchemy.engine.result import Result, ScalarResult, TupleResult @@ -59,12 +59,25 @@ def exec( _add_event: Any | None = None, ) -> CursorResult[Any]: ... + @overload + def exec( + self, + statement: TextClause, + *, + params: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: dict[str, Any] | None = None, + _parent_execute_state: Any | None = None, + _add_event: Any | None = None, + ) -> CursorResult[Any]: ... + def exec( self, statement: Select[_TSelectParam] | SelectOfScalar[_TSelectParam] | Executable[_TSelectParam] - | UpdateBase, + | UpdateBase + | TextClause, *, params: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, diff --git a/tests/test_exec_text.py b/tests/test_exec_text.py new file mode 100644 index 0000000000..b69cde8824 --- /dev/null +++ b/tests/test_exec_text.py @@ -0,0 +1,47 @@ +from sqlmodel import Field, Session, SQLModel, create_engine, text + + +def test_select_using_text_statement(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + secret_name: str + age: int | None = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with Session(engine) as session: + res = session.exec(text("SELECT * FROM hero")).all() + assert len(res) == 1 + assert res[0] == (1, "Deadpond", "Dive Wilson", None) + + +def test_insert_using_text_statement(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + secret_name: str + age: int | None = None + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + res = session.exec( + text( + "INSERT INTO hero (name, secret_name) VALUES ('Deadpond', 'Dive Wilson')" + ) + ) + session.commit() + + assert res.rowcount == 1