From 2d2d31ed811480370d378f4b665d5406181c125d Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 18 Feb 2026 16:16:19 +0000 Subject: [PATCH 01/29] feat(server): implement `resource scoping` for `tasks` and `push notifications` Introduces caller identity isolation to ensure clients only access authorized resources, as mandated by the A2A spec. - Add 'owner' field to `TaskMixin` and `PushNotificationConfig` database models. - Add 'last_updated' field to `TaskMixin` for optimized sorting and indexing. - Update `DatabaseTaskStore`, `InMemoryTaskStore` and `DatabasePushNotificationConfigStore` to use `OwnerResolver`. - Add relevant Unit tests. - Add Alembic configuration to enable users to update their own databases with non-optional `owner` field in `tasks` table. --- alembic.ini | 45 ++++++++ alembic/README | 56 ++++++++++ alembic/env.py | 85 +++++++++++++++ alembic/script.py.mako | 28 +++++ .../6419d2d130f6_add_owner_to_task.py | 38 +++++++ pyproject.toml | 86 +++++++++++++++ src/a2a/server/models.py | 18 +++- src/a2a/server/owner_resolver.py | 18 ++++ ...database_push_notification_config_store.py | 62 +++++++---- src/a2a/server/tasks/database_task_store.py | 100 +++++++++++++----- src/a2a/server/tasks/inmemory_task_store.py | 85 +++++++++++---- .../tasks/push_notification_config_store.py | 17 ++- .../test_default_request_handler.py | 14 ++- ...database_push_notification_config_store.py | 89 ++++++++++++++++ .../server/tasks/test_database_task_store.py | 66 ++++++++++++ .../server/tasks/test_inmemory_task_store.py | 65 ++++++++++++ tests/server/test_owner_resolver.py | 31 ++++++ 17 files changed, 823 insertions(+), 80 deletions(-) create mode 100644 alembic.ini create mode 100644 alembic/README create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/6419d2d130f6_add_owner_to_task.py create mode 100644 src/a2a/server/owner_resolver.py create mode 100644 tests/server/test_owner_resolver.py diff --git 
a/alembic.ini b/alembic.ini new file mode 100644 index 000000000..58249b073 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,45 @@ +# A generic, single database configuration. + +[alembic] + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +# IMPORTANT: This is a placeholder and an example, and should be replaced with your actual database URL. +sqlalchemy.url = sqlite+aiosqlite:///./test.db + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 000000000..0c6d7dba1 --- /dev/null +++ b/alembic/README @@ -0,0 +1,56 @@ +# Database Migrations with Alembic + +This directory contains database migration scripts for the A2A SDK, managed by [Alembic](https://alembic.sqlalchemy.org/). + +## Configuration + +- `alembic.ini`: Global configuration for Alembic, including the database URL. +- `env.py`: Python script that runs when the Alembic environment is invoked. It configures the SQLAlchemy engine and connects it to the migration context. +- `versions/`: Directory containing individual migration scripts. + +## Common Commands + +All commands should be run from the project root using `uv run`. 
+ +### Viewing Status +```bash +# View current migration version of the database +uv run alembic current + +# View migration history +uv run alembic history --verbose +``` + +### Running Migrations +```bash +# Upgrade to the latest version +uv run alembic upgrade head + +# Downgrade to base (reverts all migrations; use `-1` to step back one version) +uv run alembic downgrade base +``` + +### Creating Migrations +```bash +# Create a new migration manually +uv run alembic revision -m "description of changes" + +# Create a new migration automatically (detects changes in models.py) +uv run alembic revision --autogenerate -m "description of changes" +``` + +## Troubleshooting + +### "duplicate column name" error +If you see an error like `sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) duplicate column name: owner`, it usually means the column was already created (perhaps by `Base.metadata.create_all()` in tests or development) but Alembic doesn't know about it yet. + +To fix this, "stamp" the database to tell Alembic it is already at the latest version: +```bash +uv run alembic stamp head +``` + +## How to add a new migration +1. Modify the models in `src/a2a/server/models.py`. +2. Run `uv run alembic revision --autogenerate -m "Add new field to Task"`. +3. Review the generated script in `alembic/versions/`. +4. Apply the migration with `uv run alembic upgrade head`. diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 000000000..d541fe140 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,85 @@ +import asyncio + +from logging.config import fileConfig + +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import async_engine_from_config + +from a2a.server.models import Base +from alembic import context + + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. 
+if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here for 'autogenerate' support +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option('sqlalchemy.url') + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={'paramstyle': 'named'}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations(): + """In this scenario we need to create an Engine + and associate a connection with the context. 
+ """ + connectable = async_engine_from_config( + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online(): + """Run migrations in 'online' mode.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 000000000..11016301e --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/6419d2d130f6_add_owner_to_task.py b/alembic/versions/6419d2d130f6_add_owner_to_task.py new file mode 100644 index 000000000..3b96a5c9e --- /dev/null +++ b/alembic/versions/6419d2d130f6_add_owner_to_task.py @@ -0,0 +1,38 @@ +"""add_owner_to_task + +Revision ID: 6419d2d130f6 +Revises: +Create Date: 2026-02-17 09:23:06.758085 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + + +# revision identifiers, used by Alembic. 
+revision: str = '6419d2d130f6' +down_revision: str | Sequence[str] | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.add_column( + 'tasks', + sa.Column( + 'owner', + sa.String(255), + nullable=False, + server_default='unknown', # Set your desired default value here + ), + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_column('tasks', 'owner') diff --git a/pyproject.toml b/pyproject.toml index 1a8f0af68..0d580f745 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -323,3 +323,89 @@ docstring-code-format = true docstring-code-line-length = "dynamic" quote-style = "single" indent-style = "space" + + +[tool.alembic] + +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = "%(here)s/alembic" + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s" +# Or organize into date-based subdirectories (requires recursive_version_locations = true) +# file_template = "%%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s" + +# additional paths to be prepended to sys.path. defaults to the current working directory. +prepend_sys_path = [ + "." +] + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. 
+# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# version_locations = [ +# "%(here)s/alembic/versions", +# "%(here)s/foo/bar" +# ] + + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = "utf-8" + +# This section defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples +# [[tool.alembic.post_write_hooks]] +# format using "black" - use the console_scripts runner, +# against the "black" entrypoint +# name = "black" +# type = "console_scripts" +# entrypoint = "black" +# options = "-l 79 REVISION_SCRIPT_FILENAME" +# +# [[tool.alembic.post_write_hooks]] +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# name = "ruff" +# type = "module" +# module = "ruff" +# options = "check --fix REVISION_SCRIPT_FILENAME" +# +# [[tool.alembic.post_write_hooks]] +# Alternatively, use the exec runner to execute a binary found on your PATH +# name = "ruff" +# type = "exec" +# executable = "ruff" +# options = "check --fix REVISION_SCRIPT_FILENAME" + diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index 4b0f7504c..636efedcb 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -1,3 +1,5 @@ +import datetime + from typing import TYPE_CHECKING, Any, Generic, TypeVar @@ -16,7 +18,7 @@ def override(func): # noqa: ANN001, ANN201 try: - from sqlalchemy import JSON, Dialect, LargeBinary, String + from sqlalchemy import JSON, Dialect, Index, LargeBinary, String from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -127,6 +129,8 @@ class TaskMixin: kind: Mapped[str] = mapped_column( String(16), nullable=False, default='task' ) + owner: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + last_updated: Mapped[datetime] = mapped_column(String(22), nullable=True) # Properly typed Pydantic fields with automatic serialization status: Mapped[TaskStatus] = mapped_column(PydanticType(TaskStatus)) @@ -152,6 +156,17 @@ def __repr__(self) -> str: f'context_id="{self.context_id}", status="{self.status}")>' ) + @declared_attr + @classmethod + def __table_args__(cls) -> tuple[Any, ...]: + """Define a unique index (owner, last_updated) for each table that uses the mixin.""" + tablename = getattr(cls, '__tablename__', 'tasks') 
+ return ( + Index( + f'idx_{tablename}_owner_last_updated', 'owner', 'last_updated' + ), + ) + def create_task_model( table_name: str = 'tasks', base: type[DeclarativeBase] = Base @@ -212,6 +227,7 @@ class PushNotificationConfigMixin: task_id: Mapped[str] = mapped_column(String(36), primary_key=True) config_id: Mapped[str] = mapped_column(String(255), primary_key=True) config_data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + owner: Mapped[str] = mapped_column(String(255), nullable=False, index=True) @override def __repr__(self) -> str: diff --git a/src/a2a/server/owner_resolver.py b/src/a2a/server/owner_resolver.py new file mode 100644 index 000000000..7c2756075 --- /dev/null +++ b/src/a2a/server/owner_resolver.py @@ -0,0 +1,18 @@ +from collections.abc import Callable + +from a2a.server.context import ServerCallContext + + +# Definition +OwnerResolver = Callable[[ServerCallContext], str] + + +# Example Default Implementation +def resolve_user_scope(context: ServerCallContext) -> str: + """Resolves the owner scope based on the user in the context.""" + if not context: + return 'unknown' + if not context.user: + raise ValueError('User not found in context.') + # Example: Basic user name. Adapt as needed for your user model. 
+ return context.user.user_name diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index e125f22a1..b1a30157f 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -8,11 +8,7 @@ try: - from sqlalchemy import ( - Table, - delete, - select, - ) + from sqlalchemy import Table, and_, delete, select from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -29,11 +25,13 @@ "or 'pip install a2a-sdk[sql]'" ) from e +from a2a.server.context import ServerCallContext from a2a.server.models import ( Base, PushNotificationConfigModel, create_push_notification_config_model, ) +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -59,6 +57,7 @@ class DatabasePushNotificationConfigStore(PushNotificationConfigStore): _initialized: bool config_model: type[PushNotificationConfigModel] _fernet: 'Fernet | None' + owner_resolver: OwnerResolver def __init__( self, @@ -66,6 +65,7 @@ def __init__( create_table: bool = True, table_name: str = 'push_notification_configs', encryption_key: str | bytes | None = None, + owner_resolver: OwnerResolver = resolve_user_scope, ) -> None: """Initializes the DatabasePushNotificationConfigStore. @@ -76,6 +76,7 @@ def __init__( encryption_key: A key for encrypting sensitive configuration data. If provided, `config_data` will be encrypted in the database. The key must be a URL-safe base64-encoded 32-byte key. + owner_resolver: Function to resolve the owner from the context. 
""" logger.debug( 'Initializing DatabasePushNotificationConfigStore with existing engine, table: %s', @@ -87,6 +88,7 @@ def __init__( ) self.create_table = create_table self._initialized = False + self.owner_resolver = owner_resolver self.config_model = ( PushNotificationConfigModel if table_name == 'push_notification_configs' @@ -139,7 +141,7 @@ async def _ensure_initialized(self) -> None: await self.initialize() def _to_orm( - self, task_id: str, config: PushNotificationConfig + self, task_id: str, config: PushNotificationConfig, owner: str ) -> PushNotificationConfigModel: """Maps a Pydantic PushNotificationConfig to a SQLAlchemy model instance. @@ -155,6 +157,7 @@ def _to_orm( return self.config_model( task_id=task_id, config_id=config.id, + owner=owner, config_data=data_to_store, ) @@ -223,30 +226,43 @@ def _from_orm( ) from e async def set_info( - self, task_id: str, notification_config: PushNotificationConfig + self, + task_id: str, + notification_config: PushNotificationConfig, + context: ServerCallContext | None = None, ) -> None: """Sets or updates the push notification configuration for a task.""" await self._ensure_initialized() + owner = self.owner_resolver(context) config_to_save = notification_config.model_copy() if config_to_save.id is None: config_to_save.id = task_id - db_config = self._to_orm(task_id, config_to_save) + db_config = self._to_orm(task_id, config_to_save, owner) async with self.async_session_maker.begin() as session: await session.merge(db_config) logger.debug( - 'Push notification config for task %s with config id %s saved/updated.', + 'Push notification config for task %s with config id %s for owner %s saved/updated.', task_id, config_to_save.id, + owner, ) - async def get_info(self, task_id: str) -> list[PushNotificationConfig]: - """Retrieves all push notification configurations for a task.""" + async def get_info( + self, + task_id: str, + context: ServerCallContext | None = None, + ) -> list[PushNotificationConfig]: + 
"""Retrieves all push notification configurations for a task, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker() as session: stmt = select(self.config_model).where( - self.config_model.task_id == task_id + and_( + self.config_model.task_id == task_id, + self.config_model.owner == owner, + ) ) result = await session.execute(stmt) models = result.scalars().all() @@ -257,24 +273,32 @@ async def get_info(self, task_id: str) -> list[PushNotificationConfig]: configs.append(self._from_orm(model)) except ValueError: # noqa: PERF203 logger.exception( - 'Could not deserialize push notification config for task %s, config %s', + 'Could not deserialize push notification config for task %s, config %s, owner %s', model.task_id, model.config_id, + owner, ) return configs async def delete_info( - self, task_id: str, config_id: str | None = None + self, + task_id: str, + config_id: str | None = None, + context: ServerCallContext | None = None, ) -> None: """Deletes push notification configurations for a task. If config_id is provided, only that specific configuration is deleted. - If config_id is None, all configurations for the task are deleted. + If config_id is None, all configurations for the task for the owner are deleted. 
""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker.begin() as session: stmt = delete(self.config_model).where( - self.config_model.task_id == task_id + and_( + self.config_model.task_id == task_id, + self.config_model.owner == owner, + ) ) if config_id is not None: stmt = stmt.where(self.config_model.config_id == config_id) @@ -283,13 +307,15 @@ async def delete_info( if result.rowcount > 0: logger.info( - 'Deleted %s push notification config(s) for task %s.', + 'Deleted %s push notification config(s) for task %s, owner %s.', result.rowcount, task_id, + owner, ) else: logger.warning( - 'Attempted to delete push notification config for task %s with config_id: %s that does not exist.', + 'Attempted to delete push notification config for task %s, owner %s with config_id: %s that does not exist.', task_id, + owner, config_id, ) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index 1605c601a..503be64d2 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -30,6 +30,7 @@ from a2a.server.context import ServerCallContext from a2a.server.models import Base, TaskModel, create_task_model +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.task_store import TaskStore, TasksPage from a2a.types import ListTasksParams, Task from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE @@ -50,12 +51,14 @@ class DatabaseTaskStore(TaskStore): create_table: bool _initialized: bool task_model: type[TaskModel] + owner_resolver: OwnerResolver def __init__( self, engine: AsyncEngine, create_table: bool = True, table_name: str = 'tasks', + owner_resolver: OwnerResolver = resolve_user_scope, ) -> None: """Initializes the DatabaseTaskStore. 
@@ -63,6 +66,7 @@ def __init__( engine: An existing SQLAlchemy AsyncEngine to be used by Task Store create_table: If true, create tasks table on initialization. table_name: Name of the database table. Defaults to 'tasks'. + owner_resolver: Function to resolve the owner from the context. """ logger.debug( 'Initializing DatabaseTaskStore with existing engine, table: %s', @@ -74,6 +78,7 @@ def __init__( ) self.create_table = create_table self._initialized = False + self.owner_resolver = owner_resolver self.task_model = ( TaskModel @@ -104,12 +109,14 @@ async def _ensure_initialized(self) -> None: if not self._initialized: await self.initialize() - def _to_orm(self, task: Task) -> TaskModel: + def _to_orm(self, task: Task, owner: str) -> TaskModel: """Maps a Pydantic Task to a SQLAlchemy TaskModel instance.""" return self.task_model( id=task.id, context_id=task.context_id, kind=task.kind, + owner=owner, + last_updated=task.status.timestamp, status=task.status, artifacts=task.artifacts, history=task.history, @@ -123,6 +130,7 @@ def _from_orm(self, task_model: TaskModel) -> Task: 'id': task_model.id, 'context_id': task_model.context_id, 'kind': task_model.kind, + 'owner': task_model.owner, 'status': task_model.status, 'artifacts': task_model.artifacts, 'history': task_model.history, @@ -134,38 +142,60 @@ def _from_orm(self, task_model: TaskModel) -> Task: async def save( self, task: Task, context: ServerCallContext | None = None ) -> None: - """Saves or updates a task in the database.""" + """Saves or updates a task in the database for the resolved owner.""" await self._ensure_initialized() - db_task = self._to_orm(task) + owner = self.owner_resolver(context) + db_task = self._to_orm(task, owner) async with self.async_session_maker.begin() as session: await session.merge(db_task) - logger.debug('Task %s saved/updated successfully.', task.id) + logger.debug( + 'Task %s for owner %s saved/updated successfully.', + task.id, + owner, + ) async def get( self, task_id: str, 
context: ServerCallContext | None = None ) -> Task | None: - """Retrieves a task from the database by ID.""" + """Retrieves a task from the database by ID, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker() as session: - stmt = select(self.task_model).where(self.task_model.id == task_id) + stmt = select(self.task_model).where( + and_( + self.task_model.id == task_id, + self.task_model.owner == owner, + ) + ) result = await session.execute(stmt) task_model = result.scalar_one_or_none() if task_model: task = self._from_orm(task_model) - logger.debug('Task %s retrieved successfully.', task_id) + logger.debug( + 'Task %s retrieved successfully for owner %s.', + task_id, + owner, + ) return task - logger.debug('Task %s not found in store.', task_id) + logger.debug( + 'Task %s not found in store for owner %s.', task_id, owner + ) return None async def list( self, params: ListTasksParams, context: ServerCallContext | None = None ) -> TasksPage: - """Retrieves all tasks from the database.""" + """Retrieves tasks from the database based on provided parameters, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) + logger.debug('Listing tasks for owner %s with params %s', owner, params) + async with self.async_session_maker() as session: - timestamp_col = self.task_model.status['timestamp'].as_string() - base_stmt = select(self.task_model) + timestamp_col = self.task_model.last_updated + base_stmt = select(self.task_model).where( + self.task_model.owner == owner + ) # Add filters if params.context_id: @@ -202,30 +232,36 @@ async def list( start_task = ( await session.execute( select(self.task_model).where( - self.task_model.id == start_task_id + and_( + self.task_model.id == start_task_id, + self.task_model.owner == owner, + ) ) ) ).scalar_one_or_none() if not start_task: raise ValueError(f'Invalid page token: {params.page_token}') - if 
start_task.status.timestamp: - stmt = stmt.where( - or_( - and_( - timestamp_col == start_task.status.timestamp, - self.task_model.id <= start_task.id, - ), - timestamp_col < start_task.status.timestamp, - timestamp_col.is_(None), + + start_task_timestamp = start_task.status.timestamp + where_clauses = [] + if start_task_timestamp: + where_clauses.append( + and_( + timestamp_col == start_task_timestamp, + self.task_model.id <= start_task_id, ) ) + where_clauses.append(timestamp_col < start_task_timestamp) + where_clauses.append(timestamp_col.is_(None)) else: - stmt = stmt.where( + where_clauses.append( and_( timestamp_col.is_(None), - self.task_model.id <= start_task.id, + self.task_model.id <= start_task_id, ) ) + stmt = stmt.where(or_(*where_clauses)) + page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE stmt = stmt.limit(page_size + 1) # Add 1 for next page token @@ -248,17 +284,27 @@ async def list( async def delete( self, task_id: str, context: ServerCallContext | None = None ) -> None: - """Deletes a task from the database by ID.""" + """Deletes a task from the database by ID, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker.begin() as session: - stmt = delete(self.task_model).where(self.task_model.id == task_id) + stmt = delete(self.task_model).where( + and_( + self.task_model.id == task_id, + self.task_model.owner == owner, + ) + ) result = await session.execute(stmt) # Commit is automatic when using session.begin() if result.rowcount > 0: - logger.info('Task %s deleted successfully.', task_id) + logger.info( + 'Task %s deleted successfully for owner %s.', task_id, owner + ) else: logger.warning( - 'Attempted to delete nonexistent task with id: %s', task_id + 'Attempted to delete nonexistent task with id: %s and owner %s', + task_id, + owner, ) diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index 
31d42a310..246282650 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -1,9 +1,11 @@ import asyncio import logging +from collections import defaultdict from datetime import datetime, timezone from a2a.server.context import ServerCallContext +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.task_store import TaskStore, TasksPage from a2a.types import ListTasksParams, Task from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE @@ -16,45 +18,70 @@ class InMemoryTaskStore(TaskStore): """In-memory implementation of TaskStore. - Stores task objects in a dictionary in memory. Task data is lost when the - server process stops. + Stores task objects in a nested dictionary in memory, keyed by owner then task_id. + Task data is lost when the server process stops. """ - def __init__(self) -> None: + def __init__( + self, + owner_resolver: OwnerResolver = resolve_user_scope, + ) -> None: """Initializes the InMemoryTaskStore.""" logger.debug('Initializing InMemoryTaskStore') - self.tasks: dict[str, Task] = {} + self.tasks: dict[str, dict[str, Task]] = defaultdict(dict) self.lock = asyncio.Lock() + self.owner_resolver = owner_resolver async def save( self, task: Task, context: ServerCallContext | None = None ) -> None: - """Saves or updates a task in the in-memory store.""" + """Saves or updates a task in the in-memory store for the resolved owner.""" + owner = self.owner_resolver(context) + async with self.lock: - self.tasks[task.id] = task - logger.debug('Task %s saved successfully.', task.id) + self.tasks[owner][task.id] = task + logger.debug( + 'Task %s for owner %s saved successfully.', task.id, owner + ) async def get( self, task_id: str, context: ServerCallContext | None = None ) -> Task | None: - """Retrieves a task from the in-memory store by ID.""" + """Retrieves a task from the in-memory store by ID, for the given owner.""" + owner = 
self.owner_resolver(context) async with self.lock: - logger.debug('Attempting to get task with id: %s', task_id) - task = self.tasks.get(task_id) - if task: - logger.debug('Task %s retrieved successfully.', task_id) - else: - logger.debug('Task %s not found in store.', task_id) - return task + logger.debug( + 'Attempting to get task with id: %s for owner: %s', + task_id, + owner, + ) + owner_tasks = self.tasks.get(owner) + if owner_tasks: + task = owner_tasks.get(task_id) + if task: + logger.debug( + 'Task %s retrieved successfully for owner %s.', + task_id, + owner, + ) + return task + logger.debug( + 'Task %s not found in store for owner %s.', task_id, owner + ) + return None async def list( self, params: ListTasksParams, context: ServerCallContext | None = None, ) -> TasksPage: - """Retrieves a list of tasks from the store.""" + """Retrieves a list of tasks from the store, for the given owner.""" + owner = self.owner_resolver(context) + logger.debug('Listing tasks for owner %s with params %s', owner, params) + async with self.lock: - tasks = list(self.tasks.values()) + owner_tasks = self.tasks.get(owner, {}) + tasks = list(owner_tasks.values()) # Filter tasks if params.context_id: @@ -118,13 +145,25 @@ async def list( async def delete( self, task_id: str, context: ServerCallContext | None = None ) -> None: - """Deletes a task from the in-memory store by ID.""" + """Deletes a task from the in-memory store by ID, for the given owner.""" + owner = self.owner_resolver(context) async with self.lock: - logger.debug('Attempting to delete task with id: %s', task_id) - if task_id in self.tasks: - del self.tasks[task_id] - logger.debug('Task %s deleted successfully.', task_id) + logger.debug( + 'Attempting to delete task with id: %s for owner %s', + task_id, + owner, + ) + if owner in self.tasks and task_id in self.tasks[owner]: + del self.tasks[owner][task_id] + logger.debug( + 'Task %s deleted successfully for owner %s.', task_id, owner + ) + if not self.tasks[owner]: + 
del self.tasks[owner] + logger.debug('Removed empty owner %s from store.', owner) else: logger.warning( - 'Attempted to delete nonexistent task with id: %s', task_id + 'Attempted to delete nonexistent task with id: %s for owner %s', + task_id, + owner, ) diff --git a/src/a2a/server/tasks/push_notification_config_store.py b/src/a2a/server/tasks/push_notification_config_store.py index efe46b40a..388d86c1e 100644 --- a/src/a2a/server/tasks/push_notification_config_store.py +++ b/src/a2a/server/tasks/push_notification_config_store.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from a2a.server.context import ServerCallContext from a2a.types import PushNotificationConfig @@ -8,16 +9,26 @@ class PushNotificationConfigStore(ABC): @abstractmethod async def set_info( - self, task_id: str, notification_config: PushNotificationConfig + self, + task_id: str, + notification_config: PushNotificationConfig, + context: ServerCallContext | None = None, ) -> None: """Sets or updates the push notification configuration for a task.""" @abstractmethod - async def get_info(self, task_id: str) -> list[PushNotificationConfig]: + async def get_info( + self, + task_id: str, + context: ServerCallContext | None = None, + ) -> list[PushNotificationConfig]: """Retrieves the push notification configuration for a task.""" @abstractmethod async def delete_info( - self, task_id: str, config_id: str | None = None + self, + task_id: str, + config_id: str | None = None, + context: ServerCallContext | None = None, ) -> None: """Deletes the push notification configuration for a task.""" diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index daeba947f..16c85d400 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -915,9 +915,8 @@ async def test_on_message_send_non_blocking(): ), ) - result = await 
request_handler.on_message_send( - params, create_server_call_context() - ) + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) assert result is not None assert isinstance(result, Task) @@ -927,7 +926,7 @@ async def test_on_message_send_non_blocking(): task: Task | None = None for _ in range(5): await asyncio.sleep(0.1) - task = await task_store.get(result.id) + task = await task_store.get(result.id, context) assert task is not None if task.status.state == TaskState.completed: break @@ -964,9 +963,8 @@ async def test_on_message_send_limit_history(): ), ) - result = await request_handler.on_message_send( - params, create_server_call_context() - ) + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) # verify that history_length is honored assert result is not None @@ -975,7 +973,7 @@ async def test_on_message_send_limit_history(): assert result.status.state == TaskState.completed # verify that history is still persisted to the store - task = await task_store.get(result.id) + task = await task_store.get(result.id, context) assert task is not None assert task.history is not None and len(task.history) > 1 diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index 0c3bd4683..26d968912 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -3,6 +3,8 @@ from collections.abc import AsyncGenerator import pytest +from a2a.server.context import ServerCallContext +from a2a.auth.user import User # Skip entire test module if SQLAlchemy is not installed @@ -94,6 +96,21 @@ ) +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return 
True + + @property + def user_name(self) -> str: + return self._user_name + + @pytest_asyncio.fixture(params=DB_CONFIGS) async def db_store_parameterized( request, @@ -547,6 +564,7 @@ async def test_parsing_error_after_successful_decryption( task_id=task_id, config_id=config_id, config_data=encrypted_data, + owner='test-owner', ) session.add(db_model) await session.commit() @@ -563,3 +581,74 @@ async def test_parsing_error_after_successful_decryption( with pytest.raises(ValueError): db_store_parameterized._from_orm(db_model_retrieved) # type: ignore + + +@pytest.mark.asyncio +async def test_owner_resource_scoping( + db_store_parameterized: DatabasePushNotificationConfigStore, +) -> None: + """Test that operations are scoped to the correct owner.""" + config_store = db_store_parameterized + + context_user1 = ServerCallContext(user=TestUser(user_name='user1')) + context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + + # Create configs for different owners + task1_u1_config1 = PushNotificationConfig( + id='t1-u1-c1', url='http://u1.com/1' + ) + task1_u1_config2 = PushNotificationConfig( + id='t1-u1-c2', url='http://u1.com/2' + ) + task1_u2_config1 = PushNotificationConfig( + id='t1-u2-c1', url='http://u2.com/1' + ) + task2_u1_config1 = PushNotificationConfig( + id='t2-u1-c1', url='http://u1.com/3' + ) + + await config_store.set_info('task1', task1_u1_config1, context_user1) + await config_store.set_info('task1', task1_u1_config2, context_user1) + await config_store.set_info('task1', task1_u2_config1, context_user2) + await config_store.set_info('task2', task2_u1_config1, context_user1) + + # Test GET_INFO + # User 1 should get only their configs for task1 + u1_task1_configs = await config_store.get_info('task1', context_user1) + assert len(u1_task1_configs) == 2 + assert {c.id for c in u1_task1_configs} == {'t1-u1-c1', 't1-u1-c2'} + + # User 2 should get only their configs for task1 + u2_task1_configs = await config_store.get_info('task1', 
context_user2) + assert len(u2_task1_configs) == 1 + assert u2_task1_configs[0].id == 't1-u2-c1' + + # User 2 should get no configs for task2 + u2_task2_configs = await config_store.get_info('task2', context_user2) + assert len(u2_task2_configs) == 0 + + # User 1 should get their config for task2 + u1_task2_configs = await config_store.get_info('task2', context_user1) + assert len(u1_task2_configs) == 1 + assert u1_task2_configs[0].id == 't2-u1-c1' + + # Test DELETE_INFO + # User 2 deleting User 1's config should not work + await config_store.delete_info('task1', 't1-u1-c1', context_user2) + u1_task1_configs = await config_store.get_info('task1', context_user1) + assert len(u1_task1_configs) == 2 + + # User 1 deleting their own config + await config_store.delete_info('task1', 't1-u1-c1', context_user1) + u1_task1_configs = await config_store.get_info('task1', context_user1) + assert len(u1_task1_configs) == 1 + assert u1_task1_configs[0].id == 't1-u1-c2' + + # User 1 deleting all configs for task2 + await config_store.delete_info('task2', context=context_user1) + u1_task2_configs = await config_store.get_info('task2', context_user1) + assert len(u1_task2_configs) == 0 + + # Cleanup remaining + await config_store.delete_info('task1', context=context_user1) + await config_store.delete_info('task1', context=context_user2) diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index 495d2e4fd..5c35b391a 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -28,6 +28,23 @@ TaskStatus, TextPart, ) +from a2a.auth.user import User +from a2a.server.context import ServerCallContext + + +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name 
# DSNs for different databases @@ -608,4 +625,53 @@ async def test_metadata_field_mapping( await db_store_parameterized.delete('task-metadata-test-4') +@pytest.mark.asyncio +async def test_owner_resource_scoping( + db_store_parameterized: DatabaseTaskStore, +) -> None: + """Test that operations are scoped to the correct owner.""" + task_store = db_store_parameterized + + context_user1 = ServerCallContext(user=TestUser(user_name='user1')) + context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + + # Create tasks for different owners + task1_user1 = MINIMAL_TASK_OBJ.model_copy(update={'id': 'u1-task1'}) + task2_user1 = MINIMAL_TASK_OBJ.model_copy(update={'id': 'u1-task2'}) + task1_user2 = MINIMAL_TASK_OBJ.model_copy(update={'id': 'u2-task1'}) + + await task_store.save(task1_user1, context_user1) + await task_store.save(task2_user1, context_user1) + await task_store.save(task1_user2, context_user2) + + # Test GET + assert await task_store.get('u1-task1', context_user1) is not None + assert await task_store.get('u1-task1', context_user2) is None + assert await task_store.get('u2-task1', context_user1) is None + assert await task_store.get('u2-task1', context_user2) is not None + + # Test LIST + params = ListTasksParams() + page_user1 = await task_store.list(params, context_user1) + assert len(page_user1.tasks) == 2 + assert {t.id for t in page_user1.tasks} == {'u1-task1', 'u1-task2'} + assert page_user1.total_size == 2 + + page_user2 = await task_store.list(params, context_user2) + assert len(page_user2.tasks) == 1 + assert {t.id for t in page_user2.tasks} == {'u2-task1'} + assert page_user2.total_size == 1 + + # Test DELETE + await task_store.delete('u1-task1', context_user2) # Should not delete + assert await task_store.get('u1-task1', context_user1) is not None + + await task_store.delete('u1-task1', context_user1) # Should delete + assert await task_store.get('u1-task1', context_user1) is None + + # Cleanup remaining tasks + await 
task_store.delete('u1-task2', context_user1) + await task_store.delete('u2-task1', context_user2) + + # Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml). diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index ee91b9261..2fd77e0b0 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -1,9 +1,26 @@ from typing import Any +from a2a.server.context import ServerCallContext import pytest from a2a.server.tasks import InMemoryTaskStore from a2a.types import ListTasksParams, Task, TaskState, TaskStatus +from a2a.auth.user import User + + +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name MINIMAL_TASK: dict[str, Any] = { @@ -259,3 +276,51 @@ async def test_in_memory_task_store_delete_nonexistent() -> None: """Test deleting a nonexistent task.""" store = InMemoryTaskStore() await store.delete('nonexistent') + + +@pytest.mark.asyncio +async def test_owner_resource_scoping() -> None: + """Test that operations are scoped to the correct owner.""" + store = InMemoryTaskStore() + task = Task(**MINIMAL_TASK) + + context_user1 = ServerCallContext(user=TestUser(user_name='user1')) + context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + + # Create tasks for different owners + task1_user1 = task.model_copy(update={'id': 'u1-task1'}) + task2_user1 = task.model_copy(update={'id': 'u1-task2'}) + task1_user2 = task.model_copy(update={'id': 'u2-task1'}) + + await store.save(task1_user1, context_user1) + await store.save(task2_user1, context_user1) + await store.save(task1_user2, context_user2) + + # Test GET + assert await store.get('u1-task1', context_user1) is not None + 
assert await store.get('u1-task1', context_user2) is None + assert await store.get('u2-task1', context_user1) is None + assert await store.get('u2-task1', context_user2) is not None + + # Test LIST + params = ListTasksParams() + page_user1 = await store.list(params, context_user1) + assert len(page_user1.tasks) == 2 + assert {t.id for t in page_user1.tasks} == {'u1-task1', 'u1-task2'} + assert page_user1.total_size == 2 + + page_user2 = await store.list(params, context_user2) + assert len(page_user2.tasks) == 1 + assert {t.id for t in page_user2.tasks} == {'u2-task1'} + assert page_user2.total_size == 1 + + # Test DELETE + await store.delete('u1-task1', context_user2) # Should not delete + assert await store.get('u1-task1', context_user1) is not None + + await store.delete('u1-task1', context_user1) # Should delete + assert await store.get('u1-task1', context_user1) is None + + # Cleanup remaining tasks + await store.delete('u1-task2', context_user1) + await store.delete('u2-task1', context_user2) diff --git a/tests/server/test_owner_resolver.py b/tests/server/test_owner_resolver.py new file mode 100644 index 000000000..8a0686865 --- /dev/null +++ b/tests/server/test_owner_resolver.py @@ -0,0 +1,31 @@ +from a2a.auth.user import User + +from a2a.server.context import ServerCallContext +from a2a.server.owner_resolver import resolve_user_scope + + +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def test_resolve_user_scope_valid_user(): + """Test resolve_user_scope with a valid user in the context.""" + user = TestUser(user_name='testuser') + context = ServerCallContext(user=user) + assert resolve_user_scope(context) == 'testuser' + + +def test_resolve_user_scope_no_context(): + """Test resolve_user_scope when the context is None.""" + 
assert resolve_user_scope(None) == 'unknown' From 6093f7f197e2a1ab0fce09c32befe6c9dd24583c Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 18 Feb 2026 16:26:29 +0000 Subject: [PATCH 02/29] fix: add poolclass to allow.txt --- .github/actions/spelling/allow.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 8d0b13c8c..cda0a4b3d 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -69,6 +69,7 @@ oauthoidc oidc opensource otherurl +poolclass postgres POSTGRES postgresql From 6600b4719076e3a96b43a113c29f874006adc80b Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 09:58:44 +0000 Subject: [PATCH 03/29] fix: test_inmemory_task_store.py merge caused error --- tests/server/tasks/test_inmemory_task_store.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index ed30f2356..97befb755 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -269,15 +269,23 @@ async def test_in_memory_task_store_delete_nonexistent() -> None: async def test_owner_resource_scoping() -> None: """Test that operations are scoped to the correct owner.""" store = InMemoryTaskStore() - task = Task(**MINIMAL_TASK) + task = create_minimal_task() context_user1 = ServerCallContext(user=TestUser(user_name='user1')) context_user2 = ServerCallContext(user=TestUser(user_name='user2')) # Create tasks for different owners - task1_user1 = task.model_copy(update={'id': 'u1-task1'}) - task2_user1 = task.model_copy(update={'id': 'u1-task2'}) - task1_user2 = task.model_copy(update={'id': 'u2-task1'}) + task1_user1 = Task() + task1_user1.CopyFrom(task) + task1_user1.id = 'u1-task1' + + task2_user1 = Task() + task2_user1.CopyFrom(task) + task2_user1.id = 'u1-task2' + + task1_user2 = Task() + 
task1_user2.CopyFrom(task) + task1_user2.id = 'u2-task1' await store.save(task1_user1, context_user1) await store.save(task2_user1, context_user1) From ea89bbbe6296ee180fa2bfa43262d074b13a88e6 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 10:47:12 +0000 Subject: [PATCH 04/29] fix: - add alembic to dev field in pyproject.toml - fix elmbic README.md error - make ServerCallContext optional in OwnerResolver --- alembic/README | 3 +++ pyproject.toml | 1 + src/a2a/server/owner_resolver.py | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/alembic/README b/alembic/README index 0c6d7dba1..06ec9e9a8 100644 --- a/alembic/README +++ b/alembic/README @@ -27,6 +27,9 @@ uv run alembic history --verbose uv run alembic upgrade head # Downgrade by one version +uv run alembic downgrade -1 + +# Revert all migrations uv run alembic downgrade base ``` diff --git a/pyproject.toml b/pyproject.toml index 6be567814..7e3b6a2f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ style = "pep440" [dependency-groups] dev = [ + "alembic>=1.14.0", "mypy>=1.15.0", "PyJWT>=2.0.0", "pytest>=8.3.5", diff --git a/src/a2a/server/owner_resolver.py b/src/a2a/server/owner_resolver.py index 6c50cd79f..4fa310b92 100644 --- a/src/a2a/server/owner_resolver.py +++ b/src/a2a/server/owner_resolver.py @@ -4,7 +4,7 @@ # Definition -OwnerResolver = Callable[[ServerCallContext], str] +OwnerResolver = Callable[[ServerCallContext | None], str] # Example Default Implementation From 9301b8c77b4fea37b506deb487164007d9029501 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 10:55:14 +0000 Subject: [PATCH 05/29] fix: update uv.lock --- uv.lock | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/uv.lock b/uv.lock index 2cecfc177..748ef3ee6 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= 
'3.14'", @@ -70,6 +70,7 @@ telemetry = [ [package.dev-dependencies] dev = [ { name = "a2a-sdk", extra = ["all"] }, + { name = "alembic" }, { name = "autoflake" }, { name = "mypy" }, { name = "no-implicit-optional" }, @@ -135,6 +136,7 @@ provides-extras = ["all", "encryption", "grpc", "http-server", "mysql", "postgre [package.metadata.requires-dev] dev = [ { name = "a2a-sdk", extras = ["all"], editable = "." }, + { name = "alembic", specifier = ">=1.14.0" }, { name = "autoflake" }, { name = "mypy", specifier = ">=1.15.0" }, { name = "no-implicit-optional" }, @@ -177,6 +179,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" }, ] +[[package]] +name = "alembic" +version = "1.18.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/13/8b084e0f2efb0275a1d534838844926f798bd766566b1375174e2448cd31/alembic-1.18.4.tar.gz", hash = "sha256:cb6e1fd84b6174ab8dbb2329f86d631ba9559dd78df550b57804d607672cedbc", size = 2056725, upload-time = "2026-02-10T16:00:47.195Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/29/6533c317b74f707ea28f8d633734dbda2119bbadfc61b2f3640ba835d0f7/alembic-1.18.4-py3-none-any.whl", hash = "sha256:a5ed4adcf6d8a4cb575f3d759f071b03cd6e5c7618eb796cb52497be25bfe19a", size = 263893, upload-time = "2026-02-10T16:00:49.997Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -1277,6 +1294,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/94/d1/433b3c06e78f23486fe4fdd19bc134657eb30997d2054b0dbf52bbf3382e/librt-0.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:92249938ab744a5890580d3cb2b22042f0dce71cdaa7c1369823df62bedf7cbc", size = 48753, upload-time = "2026-02-12T14:53:38.539Z" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -2323,7 +2352,7 @@ wheels = [ [[package]] name = "virtualenv" -version = "20.37.0" +version = "20.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, @@ -2331,9 +2360,9 @@ dependencies = [ { name = "platformdirs" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/ef/d9d4ce633df789bf3430bd81fb0d8b9d9465dfc1d1f0deb3fb62cd80f5c2/virtualenv-20.37.0.tar.gz", hash = "sha256:6f7e2064ed470aa7418874e70b6369d53b66bcd9e9fd5389763e96b6c94ccb7c", size = 5864710, upload-time = "2026-02-16T16:17:59.42Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d2/03/a94d404ca09a89a7301a7008467aed525d4cdeb9186d262154dd23208709/virtualenv-20.38.0.tar.gz", hash = "sha256:94f39b1abaea5185bf7ea5a46702b56f1d0c9aa2f41a6c2b8b0af4ddc74c10a7", size = 5864558, upload-time = 
"2026-02-19T07:48:02.385Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/4b/6cf85b485be7ec29db837ec2a1d8cd68bc1147b1abf23d8636c5bd65b3cc/virtualenv-20.37.0-py3-none-any.whl", hash = "sha256:5d3951c32d57232ae3569d4de4cc256c439e045135ebf43518131175d9be435d", size = 5837480, upload-time = "2026-02-16T16:17:57.341Z" }, + { url = "https://files.pythonhosted.org/packages/42/d7/394801755d4c8684b655d35c665aea7836ec68320304f62ab3c94395b442/virtualenv-20.38.0-py3-none-any.whl", hash = "sha256:d6e78e5889de3a4742df2d3d44e779366325a90cf356f15621fddace82431794", size = 5837778, upload-time = "2026-02-19T07:47:59.778Z" }, ] [[package]] From 62dce316c78fba89d6a5bc920372799528b1c60f Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 14:36:51 +0000 Subject: [PATCH 06/29] fix: - remove redundant 'index=True' in owner field declaration - add owner resource scoping to `InMemoryPushNotificationConfigStore` and a related unit test --- src/a2a/server/models.py | 6 +- ...inmemory_push_notification_config_store.py | 117 +++++++++++--- .../server/tasks/test_database_task_store.py | 2 +- .../tasks/test_inmemory_push_notifications.py | 151 ++++++++++++++---- .../server/tasks/test_inmemory_task_store.py | 6 +- 5 files changed, 222 insertions(+), 60 deletions(-) diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index b1e013e6b..a7e80d81c 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -148,7 +148,7 @@ class TaskMixin: kind: Mapped[str] = mapped_column( String(16), nullable=False, default='task' ) - owner: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + owner: Mapped[str] = mapped_column(String(255), nullable=False) last_updated: Mapped[str] = mapped_column(String(22), nullable=True) # Properly typed Pydantic fields with automatic serialization @@ -175,10 +175,10 @@ def __repr__(self) -> str: f'context_id="{self.context_id}", status="{self.status}")>' ) - @declared_attr + @declared_attr.directive 
@classmethod def __table_args__(cls) -> tuple[Any, ...]: - """Define a unique index (owner, last_updated) for each table that uses the mixin.""" + """Define a composite index (owner, last_updated) for each table that uses the mixin.""" tablename = getattr(cls, '__tablename__', 'tasks') return ( Index( diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index 707156593..54d6e1894 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -1,6 +1,10 @@ import asyncio import logging +from collections import defaultdict + +from a2a.server.context import ServerCallContext +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -13,56 +17,117 @@ class InMemoryPushNotificationConfigStore(PushNotificationConfigStore): """In-memory implementation of PushNotificationConfigStore interface. - Stores push notification configurations in memory + Stores push notification configurations in a nested dictionary in memory, + keyed by owner then task_id. 
""" - def __init__(self) -> None: + def __init__( + self, + owner_resolver: OwnerResolver = resolve_user_scope, + ) -> None: """Initializes the InMemoryPushNotificationConfigStore.""" self.lock = asyncio.Lock() self._push_notification_infos: dict[ - str, list[PushNotificationConfig] - ] = {} + str, dict[str, list[PushNotificationConfig]] + ] = defaultdict(dict) + self.owner_resolver = owner_resolver async def set_info( - self, task_id: str, notification_config: PushNotificationConfig + self, + task_id: str, + notification_config: PushNotificationConfig, + context: ServerCallContext | None = None, ) -> None: """Sets or updates the push notification configuration for a task in memory.""" + owner = self.owner_resolver(context) async with self.lock: - if task_id not in self._push_notification_infos: - self._push_notification_infos[task_id] = [] + owner_infos = self._push_notification_infos[owner] + if task_id not in owner_infos: + owner_infos[task_id] = [] if not notification_config.id: notification_config.id = task_id - for config in self._push_notification_infos[task_id]: + # Remove existing config with the same ID + for config in owner_infos[task_id]: if config.id == notification_config.id: - self._push_notification_infos[task_id].remove(config) + owner_infos[task_id].remove(config) break - self._push_notification_infos[task_id].append(notification_config) - - async def get_info(self, task_id: str) -> list[PushNotificationConfig]: - """Retrieves the push notification configuration for a task from memory.""" + owner_infos[task_id].append(notification_config) + logger.debug( + 'Push notification config for task %s with config id %s for owner %s saved/updated.', + task_id, + notification_config.id, + owner, + ) + + async def get_info( + self, + task_id: str, + context: ServerCallContext | None = None, + ) -> list[PushNotificationConfig]: + """Retrieves all push notification configurations for a task from memory, for the given owner.""" + owner = 
self.owner_resolver(context) async with self.lock: - return self._push_notification_infos.get(task_id) or [] + owner_infos = self._push_notification_infos.get(owner) + if owner_infos: + return list(owner_infos.get(task_id, [])) + return [] async def delete_info( - self, task_id: str, config_id: str | None = None + self, + task_id: str, + config_id: str | None = None, + context: ServerCallContext | None = None, ) -> None: - """Deletes the push notification configuration for a task from memory.""" - async with self.lock: - if config_id is None: - config_id = task_id + """Deletes push notification configurations for a task from memory. - if task_id in self._push_notification_infos: - configurations = self._push_notification_infos[task_id] - if not configurations: - return + If config_id is provided, only that specific configuration is deleted. + If config_id is None, all configurations for the task for the owner are deleted. + """ + owner = self.owner_resolver(context) + async with self.lock: + owner_infos = self._push_notification_infos.get(owner) + if not owner_infos or task_id not in owner_infos: + logger.warning( + 'Attempted to delete push notification config for task %s, owner %s that does not exist.', + task_id, + owner, + ) + return + if config_id is None: + del owner_infos[task_id] + logger.info( + 'Deleted all push notification configs for task %s, owner %s.', + task_id, + owner, + ) + else: + configurations = owner_infos[task_id] + found = False for config in configurations: if config.id == config_id: configurations.remove(config) + found = True break - - if len(configurations) == 0: - del self._push_notification_infos[task_id] + if found: + logger.info( + 'Deleted push notification config %s for task %s, owner %s.', + config_id, + task_id, + owner, + ) + if len(configurations) == 0: + del owner_infos[task_id] + else: + logger.warning( + 'Attempted to delete push notification config %s for task %s, owner %s that does not exist.', + config_id, + task_id, + 
owner, + ) + + if not owner_infos: + del self._push_notification_infos[owner] diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index e1396d082..e6b67701c 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -33,6 +33,7 @@ ) from a2a.auth.user import User from a2a.server.context import ServerCallContext +from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE class TestUser(User): @@ -48,7 +49,6 @@ def is_authenticated(self) -> bool: @property def user_name(self) -> str: return self._user_name -from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE # DSNs for different databases diff --git a/tests/server/tasks/test_inmemory_push_notifications.py b/tests/server/tasks/test_inmemory_push_notifications.py index bbb01de2c..f1de00782 100644 --- a/tests/server/tasks/test_inmemory_push_notifications.py +++ b/tests/server/tasks/test_inmemory_push_notifications.py @@ -5,6 +5,8 @@ import httpx from google.protobuf.json_format import MessageToDict +from a2a.auth.user import User +from a2a.server.context import ServerCallContext from a2a.server.tasks.base_push_notification_sender import ( BasePushNotificationSender, ) @@ -43,6 +45,21 @@ def create_sample_push_config( return PushNotificationConfig(id=config_id, url=url, token=token) +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + class TestInMemoryPushNotifier(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) @@ -60,10 +77,8 @@ async def test_set_info_adds_new_config(self) -> None: await self.config_store.set_info(task_id, config) - self.assertIn(task_id, 
self.config_store._push_notification_infos) - self.assertEqual( - self.config_store._push_notification_infos[task_id], [config] - ) + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(retrieved, [config]) async def test_set_info_appends_to_existing_config(self) -> None: task_id = 'task_update' @@ -77,15 +92,10 @@ async def test_set_info_appends_to_existing_config(self) -> None: ) await self.config_store.set_info(task_id, updated_config) - self.assertIn(task_id, self.config_store._push_notification_infos) - self.assertEqual( - self.config_store._push_notification_infos[task_id][0], - initial_config, - ) - self.assertEqual( - self.config_store._push_notification_infos[task_id][1], - updated_config, - ) + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(len(retrieved), 2) + self.assertEqual(retrieved[0], initial_config) + self.assertEqual(retrieved[1], updated_config) async def test_set_info_without_config_id(self) -> None: task_id = 'task1' @@ -94,21 +104,17 @@ async def test_set_info_without_config_id(self) -> None: ) await self.config_store.set_info(task_id, initial_config) - assert ( - self.config_store._push_notification_infos[task_id][0].id == task_id - ) + retrieved = await self.config_store.get_info(task_id) + assert retrieved[0].id == task_id updated_config = PushNotificationConfig( url='http://initial.url/callback_new' ) await self.config_store.set_info(task_id, updated_config) - self.assertIn(task_id, self.config_store._push_notification_infos) - assert len(self.config_store._push_notification_infos[task_id]) == 1 - self.assertEqual( - self.config_store._push_notification_infos[task_id][0].url, - updated_config.url, - ) + retrieved = await self.config_store.get_info(task_id) + assert len(retrieved) == 1 + self.assertEqual(retrieved[0].url, updated_config.url) async def test_get_info_existing_config(self) -> None: task_id = 'task_get_exist' @@ -128,9 +134,12 @@ async def test_delete_info_existing_config(self) -> 
None: config = create_sample_push_config(url='http://delete.this/callback') await self.config_store.set_info(task_id, config) - self.assertIn(task_id, self.config_store._push_notification_infos) + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(len(retrieved), 1) + await self.config_store.delete_info(task_id, config_id=config.id) - self.assertNotIn(task_id, self.config_store._push_notification_infos) + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(len(retrieved), 0) async def test_delete_info_non_existent_config(self) -> None: task_id = 'task_delete_non_exist' @@ -141,9 +150,8 @@ async def test_delete_info_non_existent_config(self) -> None: self.fail( f'delete_info raised {e} unexpectedly for nonexistent task_id' ) - self.assertNotIn( - task_id, self.config_store._push_notification_infos - ) # Should still not be there + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(len(retrieved), 0) async def test_send_notification_success(self) -> None: task_id = 'task_send_success' @@ -295,6 +303,95 @@ async def test_send_notification_with_auth( ) # auth is not passed by current implementation mock_response.raise_for_status.assert_called_once() + async def test_owner_resource_scoping(self) -> None: + """Test that operations are scoped to the correct owner.""" + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) + + # Create configs for different owners + task1_u1_config1 = PushNotificationConfig( + id='t1-u1-c1', url='http://u1.com/1' + ) + task1_u1_config2 = PushNotificationConfig( + id='t1-u1-c2', url='http://u1.com/2' + ) + task1_u2_config1 = PushNotificationConfig( + id='t1-u2-c1', url='http://u2.com/1' + ) + task2_u1_config1 = PushNotificationConfig( + id='t2-u1-c1', url='http://u1.com/3' + ) + + await self.config_store.set_info( + 'task1', task1_u1_config1, context_user1 + ) + await 
self.config_store.set_info( + 'task1', task1_u1_config2, context_user1 + ) + await self.config_store.set_info( + 'task1', task1_u2_config1, context_user2 + ) + await self.config_store.set_info( + 'task2', task2_u1_config1, context_user1 + ) + + # Test GET_INFO + # User 1 should get only their configs for task1 + u1_task1_configs = await self.config_store.get_info( + 'task1', context_user1 + ) + self.assertEqual(len(u1_task1_configs), 2) + self.assertEqual( + {c.id for c in u1_task1_configs}, {'t1-u1-c1', 't1-u1-c2'} + ) + + # User 2 should get only their configs for task1 + u2_task1_configs = await self.config_store.get_info( + 'task1', context_user2 + ) + self.assertEqual(len(u2_task1_configs), 1) + self.assertEqual(u2_task1_configs[0].id, 't1-u2-c1') + + # User 2 should get no configs for task2 + u2_task2_configs = await self.config_store.get_info( + 'task2', context_user2 + ) + self.assertEqual(len(u2_task2_configs), 0) + + # User 1 should get their config for task2 + u1_task2_configs = await self.config_store.get_info( + 'task2', context_user1 + ) + self.assertEqual(len(u1_task2_configs), 1) + self.assertEqual(u1_task2_configs[0].id, 't2-u1-c1') + + # Test DELETE_INFO + # User 2 deleting User 1's config should not work + await self.config_store.delete_info('task1', 't1-u1-c1', context_user2) + u1_task1_configs = await self.config_store.get_info( + 'task1', context_user1 + ) + self.assertEqual(len(u1_task1_configs), 2) + + # User 1 deleting their own config + await self.config_store.delete_info('task1', 't1-u1-c1', context_user1) + u1_task1_configs = await self.config_store.get_info( + 'task1', context_user1 + ) + self.assertEqual(len(u1_task1_configs), 1) + self.assertEqual(u1_task1_configs[0].id, 't1-u1-c2') + + # User 1 deleting all configs for task2 + await self.config_store.delete_info('task2', context=context_user1) + u1_task2_configs = await self.config_store.get_info( + 'task2', context_user1 + ) + self.assertEqual(len(u1_task2_configs), 0) + + # Cleanup 
remaining + await self.config_store.delete_info('task1', context=context_user1) + await self.config_store.delete_info('task1', context=context_user2) + if __name__ == '__main__': unittest.main() diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index 8f6849e7a..f6093b64e 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -9,7 +9,7 @@ from a2a.auth.user import User -class TestUser(User): +class SampleUser(User): """A test implementation of the User interface.""" def __init__(self, user_name: str): @@ -273,8 +273,8 @@ async def test_owner_resource_scoping() -> None: store = InMemoryTaskStore() task = create_minimal_task() - context_user1 = ServerCallContext(user=TestUser(user_name='user1')) - context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) # Create tasks for different owners task1_user1 = Task() From 99cf89ff51704214b8e6928f44af2d2d777e3485 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 15:34:46 +0000 Subject: [PATCH 07/29] fix: fix linter issues --- alembic/__init__.py | 1 + alembic/env.py | 17 +++++++++++++---- .../versions/6419d2d130f6_add_owner_to_task.py | 2 +- 3 files changed, 15 insertions(+), 5 deletions(-) create mode 100644 alembic/__init__.py diff --git a/alembic/__init__.py b/alembic/__init__.py new file mode 100644 index 000000000..7b55fb93e --- /dev/null +++ b/alembic/__init__.py @@ -0,0 +1 @@ +"Alembic database migration package." 
diff --git a/alembic/env.py b/alembic/env.py index d541fe140..dcc644655 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -23,7 +23,7 @@ # other values from the config, defined by the needs of env.py, # can be acquired: -# my_important_option = config.get_main_option("my_important_option") +# my_important_option = config.get_main_option("my_important_option") # noqa: ERA001 # ... etc. @@ -51,14 +51,23 @@ def run_migrations_offline() -> None: context.run_migrations() -def do_run_migrations(connection): +def do_run_migrations(connection) -> None: + """Run migrations in 'online' mode. + + This function is called within a synchronous context (via run_sync) + to configure the migration context with the provided connection + and target metadata, then execute the migrations within a transaction. + + Args: + connection: The SQLAlchemy connection to use for the migrations. + """ context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() -async def run_async_migrations(): +async def run_async_migrations() -> None: """In this scenario we need to create an Engine and associate a connection with the context. """ @@ -74,7 +83,7 @@ async def run_async_migrations(): await connectable.dispose() -def run_migrations_online(): +def run_migrations_online() -> None: """Run migrations in 'online' mode.""" asyncio.run(run_async_migrations()) diff --git a/alembic/versions/6419d2d130f6_add_owner_to_task.py b/alembic/versions/6419d2d130f6_add_owner_to_task.py index 3b96a5c9e..6e2ede603 100644 --- a/alembic/versions/6419d2d130f6_add_owner_to_task.py +++ b/alembic/versions/6419d2d130f6_add_owner_to_task.py @@ -1,4 +1,4 @@ -"""add_owner_to_task +"""add_owner_to_task. 
Revision ID: 6419d2d130f6 Revises: From feb5033bfc7af52bf21b7d64053af43b5b2f11cc Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 15:44:09 +0000 Subject: [PATCH 08/29] Fix: fix some more linter errors --- alembic/env.py | 8 +++++--- alembic/versions/__init__.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 alembic/versions/__init__.py diff --git a/alembic/env.py b/alembic/env.py index dcc644655..f516c886c 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -2,7 +2,7 @@ from logging.config import fileConfig -from sqlalchemy import pool +from sqlalchemy import pool, Connection from sqlalchemy.ext.asyncio import async_engine_from_config from a2a.server.models import Base @@ -51,7 +51,7 @@ def run_migrations_offline() -> None: context.run_migrations() -def do_run_migrations(connection) -> None: +def do_run_migrations(connection: Connection) -> None: """Run migrations in 'online' mode. This function is called within a synchronous context (via run_sync) @@ -68,7 +68,9 @@ def do_run_migrations(connection) -> None: async def run_async_migrations() -> None: - """In this scenario we need to create an Engine + """Run migrations using an Engine. + + In this scenario we need to create an Engine and associate a connection with the context. 
""" connectable = async_engine_from_config( diff --git a/alembic/versions/__init__.py b/alembic/versions/__init__.py new file mode 100644 index 000000000..23a018c29 --- /dev/null +++ b/alembic/versions/__init__.py @@ -0,0 +1 @@ +"""Alembic versioned migrations for the A2A project.""" From 212ad37c73fbcb27dd0b4c1623d9a20b1e482215 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 15:48:58 +0000 Subject: [PATCH 09/29] fix: more linter errors fixed --- alembic/env.py | 2 +- alembic/versions/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/alembic/env.py b/alembic/env.py index f516c886c..07864de4d 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -2,7 +2,7 @@ from logging.config import fileConfig -from sqlalchemy import pool, Connection +from sqlalchemy import Connection, pool from sqlalchemy.ext.asyncio import async_engine_from_config from a2a.server.models import Base diff --git a/alembic/versions/__init__.py b/alembic/versions/__init__.py index 23a018c29..574828c67 100644 --- a/alembic/versions/__init__.py +++ b/alembic/versions/__init__.py @@ -1 +1 @@ -"""Alembic versioned migrations for the A2A project.""" +"""Alembic migrations scripts for the A2A project.""" From f7b5c1cc1e3787c079ef7a5c6ba3824879cbc874 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 16:53:06 +0000 Subject: [PATCH 10/29] fix: make parameter `ServerCallContext` non-optional in `PushNotificationConfigStore` methods. 
--- .../default_request_handler.py | 11 +- .../tasks/base_push_notification_sender.py | 8 +- ...database_push_notification_config_store.py | 6 +- ...inmemory_push_notification_config_store.py | 6 +- .../tasks/push_notification_config_store.py | 6 +- .../test_default_request_handler.py | 54 ++++---- .../request_handlers/test_jsonrpc_handler.py | 4 +- ...database_push_notification_config_store.py | 127 ++++++++++++------ .../tasks/test_inmemory_push_notifications.py | 123 +++++++++++------ .../tasks/test_push_notification_sender.py | 65 ++++++--- 10 files changed, 267 insertions(+), 143 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 63d0fdc74..9860d96e2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -228,7 +228,7 @@ async def _run_event_stream( async def _setup_message_execution( self, params: SendMessageRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: """Common setup logic for both streaming and non-streaming message handling. 
@@ -284,7 +284,7 @@ async def _setup_message_execution( and params.configuration.push_notification_config ): await self._push_config_store.set_info( - task_id, params.configuration.push_notification_config + task_id, params.configuration.push_notification_config, context ) queue = await self._queue_manager.create_or_tap(task_id) @@ -498,6 +498,7 @@ async def on_create_task_push_notification_config( await self._push_config_store.set_info( task_id, params.config, + context, ) return TaskPushNotificationConfig( @@ -524,7 +525,7 @@ async def on_get_task_push_notification_config( raise ServerError(error=TaskNotFoundError()) push_notification_configs: list[PushNotificationConfig] = ( - await self._push_config_store.get_info(task_id) or [] + await self._push_config_store.get_info(task_id, context) or [] ) for config in push_notification_configs: @@ -596,7 +597,7 @@ async def on_list_task_push_notification_configs( raise ServerError(error=TaskNotFoundError()) push_notification_config_list = await self._push_config_store.get_info( - task_id + task_id, context ) return ListTaskPushNotificationConfigsResponse( @@ -627,4 +628,4 @@ async def on_delete_task_push_notification_config( if not task: raise ServerError(error=TaskNotFoundError()) - await self._push_config_store.delete_info(task_id, config_id) + await self._push_config_store.delete_info(task_id, context, config_id) diff --git a/src/a2a/server/tasks/base_push_notification_sender.py b/src/a2a/server/tasks/base_push_notification_sender.py index 4e4444923..84f544f5e 100644 --- a/src/a2a/server/tasks/base_push_notification_sender.py +++ b/src/a2a/server/tasks/base_push_notification_sender.py @@ -5,6 +5,7 @@ from google.protobuf.json_format import MessageToDict +from a2a.server.context import ServerCallContext from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -22,19 +23,24 @@ def __init__( self, httpx_client: httpx.AsyncClient, config_store: PushNotificationConfigStore, + context: 
ServerCallContext, ) -> None: """Initializes the BasePushNotificationSender. Args: httpx_client: An async HTTP client instance to send notifications. config_store: A PushNotificationConfigStore instance to retrieve configurations. + context: The `ServerCallContext` that this push notification is produced under. """ self._client = httpx_client self._config_store = config_store + self._call_context: ServerCallContext = context async def send_notification(self, task: Task) -> None: """Sends a push notification for a task if configuration exists.""" - push_configs = await self._config_store.get_info(task.id) + push_configs = await self._config_store.get_info( + task.id, self._call_context + ) if not push_configs: return diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index 32dd47fd8..be8f16121 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -241,7 +241,7 @@ async def set_info( self, task_id: str, notification_config: PushNotificationConfig, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> None: """Sets or updates the push notification configuration for a task.""" await self._ensure_initialized() @@ -266,7 +266,7 @@ async def set_info( async def get_info( self, task_id: str, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> list[PushNotificationConfig]: """Retrieves all push notification configurations for a task, for the given owner.""" await self._ensure_initialized() @@ -297,8 +297,8 @@ async def get_info( async def delete_info( self, task_id: str, + context: ServerCallContext, config_id: str | None = None, - context: ServerCallContext | None = None, ) -> None: """Deletes push notification configurations for a task. 
diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index 54d6e1894..4de8b82fa 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -36,7 +36,7 @@ async def set_info( self, task_id: str, notification_config: PushNotificationConfig, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> None: """Sets or updates the push notification configuration for a task in memory.""" owner = self.owner_resolver(context) @@ -65,7 +65,7 @@ async def set_info( async def get_info( self, task_id: str, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> list[PushNotificationConfig]: """Retrieves all push notification configurations for a task from memory, for the given owner.""" owner = self.owner_resolver(context) @@ -78,8 +78,8 @@ async def get_info( async def delete_info( self, task_id: str, + context: ServerCallContext, config_id: str | None = None, - context: ServerCallContext | None = None, ) -> None: """Deletes push notification configurations for a task from memory. 
diff --git a/src/a2a/server/tasks/push_notification_config_store.py b/src/a2a/server/tasks/push_notification_config_store.py index e47060d7d..f1db64664 100644 --- a/src/a2a/server/tasks/push_notification_config_store.py +++ b/src/a2a/server/tasks/push_notification_config_store.py @@ -12,7 +12,7 @@ async def set_info( self, task_id: str, notification_config: PushNotificationConfig, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> None: """Sets or updates the push notification configuration for a task.""" @@ -20,7 +20,7 @@ async def set_info( async def get_info( self, task_id: str, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> list[PushNotificationConfig]: """Retrieves the push notification configuration for a task.""" @@ -28,7 +28,7 @@ async def get_info( async def delete_info( self, task_id: str, + context: ServerCallContext, config_id: str | None = None, - context: ServerCallContext | None = None, ) -> None: """Deletes the push notification configuration for a task.""" diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 691731c94..410c2d21f 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -546,6 +546,7 @@ async def mock_current_result(): lambda self: mock_current_result() ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -560,12 +561,10 @@ async def mock_current_result(): return_value=sample_initial_task, ), ): # Ensure task object is returned - await request_handler.on_message_send( - params, create_server_call_context() - ) + await request_handler.on_message_send(params, context) mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config + task_id, push_config, context ) # Other assertions for full flow if 
needed (e.g., agent execution) mock_agent_executor.execute.assert_awaited_once() @@ -665,6 +664,7 @@ async def mock_consume_and_break_on_interrupt( mock_consume_and_break_on_interrupt ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -680,9 +680,7 @@ async def mock_consume_and_break_on_interrupt( ), ): # Execute the non-blocking request - result = await request_handler.on_message_send( - params, create_server_call_context() - ) + result = await request_handler.on_message_send(params, context) # Verify the result is the initial task (non-blocking behavior) assert result == initial_task @@ -700,7 +698,7 @@ async def mock_consume_and_break_on_interrupt( # Verify that the push notification config was stored mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config + task_id, push_config, context ) @@ -763,6 +761,7 @@ async def mock_current_result(): lambda self: mock_current_result() ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -773,12 +772,10 @@ async def mock_current_result(): return_value=None, ), ): - await request_handler.on_message_send( - params, create_server_call_context() - ) + await request_handler.on_message_send(params, context) mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config + task_id, push_config, context ) # Other assertions for full flow if needed (e.g., agent execution) mock_agent_executor.execute.assert_awaited_once() @@ -1382,6 +1379,7 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs): side_effect=[get_current_result_coro1(), get_current_result_coro2()] ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -1397,16 +1395,16 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs): ), ): # Consume the stream - async for 
_ in request_handler.on_message_send_stream( - params, create_server_call_context() - ): + async for _ in request_handler.on_message_send_stream(params, context): pass await asyncio.wait_for(execute_called.wait(), timeout=0.1) # Assertions # 1. set_info called once at the beginning if task exists (or after task is created from message) - mock_push_config_store.set_info.assert_any_call(task_id, push_config) + mock_push_config_store.set_info.assert_any_call( + task_id, push_config, context + ) # 2. send_notification called for each task event yielded by aggregator assert mock_push_sender.send_notification.await_count == 2 @@ -2082,7 +2080,9 @@ async def test_get_task_push_notification_config_info_not_found(): exc_info.value.error, InternalError ) # Current code raises InternalError mock_task_store.get.assert_awaited_once_with('non_existent_task', context) - mock_push_store.get_info.assert_awaited_once_with('non_existent_task') + mock_push_store.get_info.assert_awaited_once_with( + 'non_existent_task', context + ) @pytest.mark.asyncio @@ -2236,7 +2236,7 @@ async def test_on_message_send_stream(): async def consume_stream(): events = [] async for event in request_handler.on_message_send_stream( - message_params + message_params, create_server_call_context() ): events.append(event) if len(events) >= 3: @@ -2340,8 +2340,9 @@ async def test_list_task_push_notification_config_info_with_config(): ) push_store = InMemoryPushNotificationConfigStore() - await push_store.set_info('task_1', push_config1) - await push_store.set_info('task_1', push_config2) + context = create_server_call_context() + await push_store.set_info('task_1', push_config1, context) + await push_store.set_info('task_1', push_config2, context) request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), @@ -2467,6 +2468,7 @@ async def test_delete_no_task_push_notification_config_info(): await push_store.set_info( 'task_2', PushNotificationConfig(id='config_1', url='http://example.com'), + 
create_server_call_context(), ) request_handler = DefaultRequestHandler( @@ -2509,9 +2511,10 @@ async def test_delete_task_push_notification_config_info_with_config(): ) push_store = InMemoryPushNotificationConfigStore() - await push_store.set_info('task_1', push_config1) - await push_store.set_info('task_1', push_config2) - await push_store.set_info('task_2', push_config1) + context = create_server_call_context() + await push_store.set_info('task_1', push_config1, context) + await push_store.set_info('task_1', push_config2, context) + await push_store.set_info('task_2', push_config1, context) request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), @@ -2550,8 +2553,9 @@ async def test_delete_task_push_notification_config_info_with_config_and_no_id() # insertion without id should replace the existing config push_store = InMemoryPushNotificationConfigStore() - await push_store.set_info('task_1', push_config) - await push_store.set_info('task_1', push_config) + context = create_server_call_context() + await push_store.set_info('task_1', push_config, context) + await push_store.set_info('task_1', push_config, context) request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index fca1175af..90b7be1c8 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -554,7 +554,7 @@ async def test_set_push_notification_success(self) -> None: self.assertIsInstance(response, dict) self.assertTrue(is_success_response(response)) mock_push_notification_store.set_info.assert_called_once_with( - mock_task.id, push_config + mock_task.id, push_config, None ) async def test_get_push_notification_success(self) -> None: @@ -601,7 +601,7 @@ async def test_on_message_stream_new_message_send_push_notification_success( mock_httpx_client = 
AsyncMock(spec=httpx.AsyncClient) push_notification_store = InMemoryPushNotificationConfigStore() push_notification_sender = BasePushNotificationSender( - mock_httpx_client, push_notification_store + mock_httpx_client, push_notification_store, ServerCallContext() ) request_handler = DefaultRequestHandler( mock_agent_executor, diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index 9336493a2..042ff8000 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -104,7 +104,7 @@ def _create_timestamp() -> Timestamp: ) -class TestUser(User): +class SampleUser(User): """A test implementation of the User interface.""" def __init__(self, user_name: str): @@ -119,6 +119,9 @@ def user_name(self) -> str: return self._user_name +MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + + @pytest_asyncio.fixture(params=DB_CONFIGS) async def db_store_parameterized( request, @@ -198,8 +201,10 @@ async def test_set_and_get_info_single_config( task_id = 'task-1' config = PushNotificationConfig(id='config-1', url='http://example.com') - await db_store_parameterized.set_info(task_id, config) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0] == config @@ -215,9 +220,15 @@ async def test_set_and_get_info_multiple_configs( config1 = PushNotificationConfig(id='config-1', url='http://example.com/1') config2 = PushNotificationConfig(id='config-2', url='http://example.com/2') - await db_store_parameterized.set_info(task_id, config1) - await db_store_parameterized.set_info(task_id, config2) - retrieved_configs = await 
db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info( + task_id, config1, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, config2, MINIMAL_CALL_CONTEXT + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 2 assert config1 in retrieved_configs @@ -238,9 +249,15 @@ async def test_set_info_updates_existing_config( id=config_id, url='http://updated.url' ) - await db_store_parameterized.set_info(task_id, initial_config) - await db_store_parameterized.set_info(task_id, updated_config) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info( + task_id, initial_config, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, updated_config, MINIMAL_CALL_CONTEXT + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0].url == 'http://updated.url' @@ -254,8 +271,10 @@ async def test_set_info_defaults_config_id_to_task_id( task_id = 'task-1' config = PushNotificationConfig(url='http://example.com') # id is None - await db_store_parameterized.set_info(task_id, config) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0].id == task_id @@ -267,7 +286,7 @@ async def test_get_info_not_found( ): """Test getting info for a task with no configs returns an empty list.""" retrieved_configs = await db_store_parameterized.get_info( - 'non-existent-task' + 'non-existent-task', MINIMAL_CALL_CONTEXT ) assert retrieved_configs == [] @@ -281,11 +300,19 @@ async def test_delete_info_specific_config( config1 = 
PushNotificationConfig(id='config-1', url='http://a.com') config2 = PushNotificationConfig(id='config-2', url='http://b.com') - await db_store_parameterized.set_info(task_id, config1) - await db_store_parameterized.set_info(task_id, config2) + await db_store_parameterized.set_info( + task_id, config1, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, config2, MINIMAL_CALL_CONTEXT + ) - await db_store_parameterized.delete_info(task_id, 'config-1') - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.delete_info( + task_id, MINIMAL_CALL_CONTEXT, 'config-1' + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0] == config2 @@ -301,11 +328,19 @@ async def test_delete_info_all_for_task( config1 = PushNotificationConfig(id='config-1', url='http://a.com') config2 = PushNotificationConfig(id='config-2', url='http://b.com') - await db_store_parameterized.set_info(task_id, config1) - await db_store_parameterized.set_info(task_id, config2) + await db_store_parameterized.set_info( + task_id, config1, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, config2, MINIMAL_CALL_CONTEXT + ) - await db_store_parameterized.delete_info(task_id, None) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.delete_info( + task_id, MINIMAL_CALL_CONTEXT, None + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved_configs == [] @@ -316,7 +351,9 @@ async def test_delete_info_not_found( ): """Test that deleting a non-existent config does not raise an error.""" # Should not raise - await db_store_parameterized.delete_info('task-1', 'non-existent-config') + await db_store_parameterized.delete_info( + 'task-1', MINIMAL_CALL_CONTEXT, 'non-existent-config' + ) @pytest.mark.asyncio @@ 
-330,7 +367,7 @@ async def test_data_is_encrypted_in_db( ) plain_json = MessageToJson(config) - await db_store_parameterized.set_info(task_id, config) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Directly query the database to inspect the raw data async_session = async_sessionmaker( @@ -360,7 +397,7 @@ async def test_decryption_error_with_wrong_key( task_id = 'wrong-key-task' config = PushNotificationConfig(id='config-1', url='http://secret.url') - await db_store_parameterized.set_info(task_id, config) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # 2. Try to read with a different key # Directly query the database to inspect the raw data @@ -369,7 +406,7 @@ async def test_decryption_error_with_wrong_key( db_store_parameterized.engine, encryption_key=wrong_key ) - retrieved_configs = await store2.get_info(task_id) + retrieved_configs = await store2.get_info(task_id, MINIMAL_CALL_CONTEXT) assert retrieved_configs == [] # _from_orm should raise a ValueError @@ -394,13 +431,13 @@ async def test_decryption_error_with_no_key( task_id = 'wrong-key-task' config = PushNotificationConfig(id='config-1', url='http://secret.url') - await db_store_parameterized.set_info(task_id, config) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # 2. 
Try to read with no key set # Directly query the database to inspect the raw data store2 = DatabasePushNotificationConfigStore(db_store_parameterized.engine) - retrieved_configs = await store2.get_info(task_id) + retrieved_configs = await store2.get_info(task_id, MINIMAL_CALL_CONTEXT) assert retrieved_configs == [] # _from_orm should raise a ValueError @@ -437,8 +474,10 @@ async def test_custom_table_name( config = PushNotificationConfig(id='config-1', url='http://custom.url') # This will create the table on first use - await custom_store.set_info(task_id, config) - retrieved_configs = await custom_store.get_info(task_id) + await custom_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) + retrieved_configs = await custom_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0] == config @@ -482,9 +521,9 @@ async def test_set_and_get_info_multiple_configs_no_key( config1 = PushNotificationConfig(id='config-1', url='http://example.com/1') config2 = PushNotificationConfig(id='config-2', url='http://example.com/2') - await store.set_info(task_id, config1) - await store.set_info(task_id, config2) - retrieved_configs = await store.get_info(task_id) + await store.set_info(task_id, config1, MINIMAL_CALL_CONTEXT) + await store.set_info(task_id, config2, MINIMAL_CALL_CONTEXT) + retrieved_configs = await store.get_info(task_id, MINIMAL_CALL_CONTEXT) assert len(retrieved_configs) == 2 assert config1 in retrieved_configs @@ -508,7 +547,7 @@ async def test_data_is_not_encrypted_in_db_if_no_key_is_set( config = PushNotificationConfig(id='config-1', url='http://example.com/1') plain_json = MessageToJson(config) - await store.set_info(task_id, config) + await store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Directly query the database to inspect the raw data async_session = async_sessionmaker( @@ -539,10 +578,12 @@ async def test_decryption_fallback_for_unencrypted_data( task_id = 'mixed-encryption-task' config = 
PushNotificationConfig(id='config-1', url='http://plain.url') - await unencrypted_store.set_info(task_id, config) + await unencrypted_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # 2. Try to read with the encryption-enabled store from the fixture - retrieved_configs = await db_store_parameterized.get_info(task_id) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) # Should fall back to parsing as plain JSON and not fail assert len(retrieved_configs) == 1 @@ -572,13 +613,15 @@ async def test_parsing_error_after_successful_decryption( task_id=task_id, config_id=config_id, config_data=encrypted_data, - owner='test-owner', + owner='user', ) session.add(db_model) await session.commit() # 3. get_info should log an error and return an empty list - retrieved_configs = await db_store_parameterized.get_info(task_id) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved_configs == [] # 4. 
_from_orm should raise a ValueError @@ -598,8 +641,8 @@ async def test_owner_resource_scoping( """Test that operations are scoped to the correct owner.""" config_store = db_store_parameterized - context_user1 = ServerCallContext(user=TestUser(user_name='user1')) - context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) # Create configs for different owners task1_u1_config1 = PushNotificationConfig( @@ -642,12 +685,16 @@ async def test_owner_resource_scoping( # Test DELETE_INFO # User 2 deleting User 1's config should not work - await config_store.delete_info('task1', 't1-u1-c1', context_user2) + await config_store.delete_info('task1', context_user2, 't1-u1-c1') u1_task1_configs = await config_store.get_info('task1', context_user1) assert len(u1_task1_configs) == 2 # User 1 deleting their own config - await config_store.delete_info('task1', 't1-u1-c1', context_user1) + await config_store.delete_info( + 'task1', + context_user1, + 't1-u1-c1', + ) u1_task1_configs = await config_store.get_info('task1', context_user1) assert len(u1_task1_configs) == 1 assert u1_task1_configs[0].id == 't1-u1-c2' diff --git a/tests/server/tasks/test_inmemory_push_notifications.py b/tests/server/tasks/test_inmemory_push_notifications.py index f1de00782..0024a95a6 100644 --- a/tests/server/tasks/test_inmemory_push_notifications.py +++ b/tests/server/tasks/test_inmemory_push_notifications.py @@ -26,7 +26,7 @@ # logging.disable(logging.CRITICAL) -def create_sample_task( +def _create_sample_task( task_id: str = 'task123', status_state: TaskState = TaskState.TASK_STATE_COMPLETED, ) -> Task: @@ -37,7 +37,7 @@ def create_sample_task( ) -def create_sample_push_config( +def _create_sample_push_config( url: str = 'http://example.com/callback', config_id: str = 'cfg1', token: str | None = None, @@ -60,12 +60,17 @@ def user_name(self) -> str: 
return self._user_name +MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + + class TestInMemoryPushNotifier(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) self.config_store = InMemoryPushNotificationConfigStore() self.notifier = BasePushNotificationSender( - httpx_client=self.mock_httpx_client, config_store=self.config_store + httpx_client=self.mock_httpx_client, + config_store=self.config_store, + context=MINIMAL_CALL_CONTEXT, ) # Corrected argument name def test_constructor_stores_client(self) -> None: @@ -73,26 +78,34 @@ def test_constructor_stores_client(self) -> None: async def test_set_info_adds_new_config(self) -> None: task_id = 'task_new' - config = create_sample_push_config(url='http://new.url/callback') + config = _create_sample_push_config(url='http://new.url/callback') - await self.config_store.set_info(task_id, config) + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(retrieved, [config]) async def test_set_info_appends_to_existing_config(self) -> None: task_id = 'task_update' - initial_config = create_sample_push_config( + initial_config = _create_sample_push_config( url='http://initial.url/callback', config_id='cfg_initial' ) - await self.config_store.set_info(task_id, initial_config) + await self.config_store.set_info( + task_id, initial_config, MINIMAL_CALL_CONTEXT + ) - updated_config = create_sample_push_config( + updated_config = _create_sample_push_config( url='http://updated.url/callback', config_id='cfg_updated' ) - await self.config_store.set_info(task_id, updated_config) + await self.config_store.set_info( + task_id, updated_config, MINIMAL_CALL_CONTEXT + ) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + 
task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(len(retrieved), 2) self.assertEqual(retrieved[0], initial_config) self.assertEqual(retrieved[1], updated_config) @@ -102,62 +115,84 @@ async def test_set_info_without_config_id(self) -> None: initial_config = PushNotificationConfig( url='http://initial.url/callback' ) - await self.config_store.set_info(task_id, initial_config) + await self.config_store.set_info( + task_id, initial_config, MINIMAL_CALL_CONTEXT + ) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved[0].id == task_id updated_config = PushNotificationConfig( url='http://initial.url/callback_new' ) - await self.config_store.set_info(task_id, updated_config) + await self.config_store.set_info( + task_id, updated_config, MINIMAL_CALL_CONTEXT + ) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved) == 1 self.assertEqual(retrieved[0].url, updated_config.url) async def test_get_info_existing_config(self) -> None: task_id = 'task_get_exist' - config = create_sample_push_config(url='http://get.this/callback') - await self.config_store.set_info(task_id, config) + config = _create_sample_push_config(url='http://get.this/callback') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) - retrieved_config = await self.config_store.get_info(task_id) + retrieved_config = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(retrieved_config, [config]) async def test_get_info_non_existent_config(self) -> None: task_id = 'task_get_non_exist' - retrieved_config = await self.config_store.get_info(task_id) + retrieved_config = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved_config == [] async def test_delete_info_existing_config(self) -> None: task_id = 
'task_delete_exist' - config = create_sample_push_config(url='http://delete.this/callback') - await self.config_store.set_info(task_id, config) + config = _create_sample_push_config(url='http://delete.this/callback') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(len(retrieved), 1) - await self.config_store.delete_info(task_id, config_id=config.id) - retrieved = await self.config_store.get_info(task_id) + await self.config_store.delete_info( + task_id, config_id=config.id, context=MINIMAL_CALL_CONTEXT + ) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(len(retrieved), 0) async def test_delete_info_non_existent_config(self) -> None: task_id = 'task_delete_non_exist' # Ensure it doesn't raise an error try: - await self.config_store.delete_info(task_id) + await self.config_store.delete_info( + task_id, context=MINIMAL_CALL_CONTEXT + ) except Exception as e: self.fail( f'delete_info raised {e} unexpectedly for nonexistent task_id' ) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(len(retrieved), 0) async def test_send_notification_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/here') - await self.config_store.set_info(task_id, config) + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/here') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Mock the post call to simulate success mock_response = AsyncMock(spec=httpx.Response) @@ -180,11 +215,11 @@ async def test_send_notification_success(self) -> None: async def 
test_send_notification_with_token_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config( + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config( url='http://notify.me/here', token='unique_token' ) - await self.config_store.set_info(task_id, config) + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Mock the post call to simulate success mock_response = AsyncMock(spec=httpx.Response) @@ -211,7 +246,7 @@ async def test_send_notification_with_token_success(self) -> None: async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' - task_data = create_sample_task(task_id=task_id) + task_data = _create_sample_task(task_id=task_id) await self.notifier.send_notification(task_data) # Pass only task_data @@ -222,9 +257,9 @@ async def test_send_notification_http_status_error( self, mock_logger: MagicMock ) -> None: task_id = 'task_send_http_err' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/http_error') - await self.config_store.set_info(task_id, config) + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/http_error') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) mock_response = MagicMock( spec=httpx.Response @@ -252,9 +287,9 @@ async def test_send_notification_request_error( self, mock_logger: MagicMock ) -> None: task_id = 'task_send_req_err' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/req_error') - await self.config_store.set_info(task_id, config) + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/req_error') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) request_error = httpx.RequestError('Network 
issue', request=MagicMock()) self.mock_httpx_client.post.side_effect = request_error @@ -279,11 +314,11 @@ async def test_send_notification_with_auth( still works even if the config has an authentication field set. """ task_id = 'task_send_auth' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/auth') + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/auth') # The current implementation doesn't use the authentication field # It only supports token-based auth via the token field - await self.config_store.set_info(task_id, config) + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -367,14 +402,14 @@ async def test_owner_resource_scoping(self) -> None: # Test DELETE_INFO # User 2 deleting User 1's config should not work - await self.config_store.delete_info('task1', 't1-u1-c1', context_user2) + await self.config_store.delete_info('task1', context_user2, 't1-u1-c1') u1_task1_configs = await self.config_store.get_info( 'task1', context_user1 ) self.assertEqual(len(u1_task1_configs), 2) # User 1 deleting their own config - await self.config_store.delete_info('task1', 't1-u1-c1', context_user1) + await self.config_store.delete_info('task1', context_user1, 't1-u1-c1') u1_task1_configs = await self.config_store.get_info( 'task1', context_user1 ) diff --git a/tests/server/tasks/test_push_notification_sender.py b/tests/server/tasks/test_push_notification_sender.py index a7b5f7603..985ae6b7a 100644 --- a/tests/server/tasks/test_push_notification_sender.py +++ b/tests/server/tasks/test_push_notification_sender.py @@ -5,6 +5,8 @@ import httpx from google.protobuf.json_format import MessageToDict +from a2a.auth.user import User +from a2a.server.context import ServerCallContext from a2a.server.tasks.base_push_notification_sender import ( 
BasePushNotificationSender, ) @@ -17,7 +19,22 @@ ) -def create_sample_task( +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def _create_sample_task( task_id: str = 'task123', status_state: TaskState = TaskState.TASK_STATE_COMPLETED, ) -> Task: @@ -28,7 +45,7 @@ def create_sample_task( ) -def create_sample_push_config( +def _create_sample_push_config( url: str = 'http://example.com/callback', config_id: str = 'cfg1', token: str | None = None, @@ -36,6 +53,9 @@ def create_sample_push_config( return PushNotificationConfig(id=config_id, url=url, token=token) +MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + + class TestBasePushNotificationSender(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) @@ -43,6 +63,7 @@ def setUp(self) -> None: self.sender = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.mock_config_store, + context=MINIMAL_CALL_CONTEXT, ) def test_constructor_stores_client_and_config_store(self) -> None: @@ -51,8 +72,8 @@ def test_constructor_stores_client_and_config_store(self) -> None: async def test_send_notification_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/here') + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/here') self.mock_config_store.get_info.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) @@ -61,7 +82,9 @@ async def test_send_notification_success(self) -> None: await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with + 
self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) # assert httpx_client post method got invoked with right parameters self.mock_httpx_client.post.assert_awaited_once_with( @@ -73,8 +96,8 @@ async def test_send_notification_success(self) -> None: async def test_send_notification_with_token_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config( + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config( url='http://notify.me/here', token='unique_token' ) self.mock_config_store.get_info.return_value = [config] @@ -85,7 +108,9 @@ async def test_send_notification_with_token_success(self) -> None: await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) # assert httpx_client post method got invoked with right parameters self.mock_httpx_client.post.assert_awaited_once_with( @@ -97,12 +122,14 @@ async def test_send_notification_with_token_success(self) -> None: async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' - task_data = create_sample_task(task_id=task_id) + task_data = _create_sample_task(task_id=task_id) self.mock_config_store.get_info.return_value = [] await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_id) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.mock_httpx_client.post.assert_not_called() @patch('a2a.server.tasks.base_push_notification_sender.logger') @@ -110,8 +137,8 @@ async def test_send_notification_http_status_error( self, mock_logger: MagicMock ) -> None: task_id = 'task_send_http_err' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/http_error') 
+ task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/http_error') self.mock_config_store.get_info.return_value = [config] mock_response = MagicMock(spec=httpx.Response) @@ -124,7 +151,9 @@ async def test_send_notification_http_status_error( await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_id) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, json=MessageToDict(StreamResponse(task=task_data)), @@ -134,11 +163,11 @@ async def test_send_notification_http_status_error( async def test_send_notification_multiple_configs(self) -> None: task_id = 'task_multiple_configs' - task_data = create_sample_task(task_id=task_id) - config1 = create_sample_push_config( + task_data = _create_sample_task(task_id=task_id) + config1 = _create_sample_push_config( url='http://notify.me/cfg1', config_id='cfg1' ) - config2 = create_sample_push_config( + config2 = _create_sample_push_config( url='http://notify.me/cfg2', config_id='cfg2' ) self.mock_config_store.get_info.return_value = [config1, config2] @@ -149,7 +178,9 @@ async def test_send_notification_multiple_configs(self) -> None: await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_id) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(self.mock_httpx_client.post.call_count, 2) # Check calls for config1 From 38d7df6706e61fdd300086bd2ee8d3c96289ade3 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 17:01:27 +0000 Subject: [PATCH 11/29] fix: add ServerCallContext to tests/e2e/push_notifications/agent_app.py --- tests/e2e/push_notifications/agent_app.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/e2e/push_notifications/agent_app.py 
b/tests/e2e/push_notifications/agent_app.py index ef8276c4e..dfe71566a 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -4,6 +4,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import ( @@ -148,6 +149,7 @@ def create_agent_app( push_sender=BasePushNotificationSender( httpx_client=notification_client, config_store=push_config_store, + context=ServerCallContext(), ), ), ) From c4b282a999317b4a5afa73d035d9692212be5245 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 17:10:05 +0000 Subject: [PATCH 12/29] fix: small fix --- .../default_request_handler.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 24c4b586e..104b256de 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -227,7 +227,7 @@ async def _run_event_stream( async def _setup_message_execution( self, params: SendMessageRequest, - context: ServerCallContext, + context: ServerCallContext | None, ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: """Common setup logic for both streaming and non-streaming message handling. 
@@ -283,7 +283,9 @@ async def _setup_message_execution( and params.configuration.push_notification_config ): await self._push_config_store.set_info( - task_id, params.configuration.push_notification_config, context + task_id, + params.configuration.push_notification_config, + context or ServerCallContext(), ) queue = await self._queue_manager.create_or_tap(task_id) @@ -495,7 +497,7 @@ async def on_create_task_push_notification_config( await self._push_config_store.set_info( task_id, params.config, - context, + context or ServerCallContext(), ) return TaskPushNotificationConfig( @@ -522,7 +524,10 @@ async def on_get_task_push_notification_config( raise ServerError(error=TaskNotFoundError()) push_notification_configs: list[PushNotificationConfig] = ( - await self._push_config_store.get_info(task_id, context) or [] + await self._push_config_store.get_info( + task_id, context or ServerCallContext() + ) + or [] ) for config in push_notification_configs: @@ -598,7 +603,7 @@ async def on_list_task_push_notification_configs( raise ServerError(error=TaskNotFoundError()) push_notification_config_list = await self._push_config_store.get_info( - task_id, context + task_id, context or ServerCallContext() ) return ListTaskPushNotificationConfigsResponse( @@ -629,4 +634,6 @@ async def on_delete_task_push_notification_config( if not task: raise ServerError(error=TaskNotFoundError()) - await self._push_config_store.delete_info(task_id, context, config_id) + await self._push_config_store.delete_info( + task_id, context or ServerCallContext(), config_id + ) From 0090ecc3b73c6eeda52c326af6261068ae7d0e40 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 17:22:30 +0000 Subject: [PATCH 13/29] fix: fix unit test error --- tests/server/request_handlers/test_jsonrpc_handler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 
2ab43b44d..aa448f354 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -550,11 +550,12 @@ async def test_set_push_notification_success(self) -> None: task_id=mock_task.id, config=push_config, ) - response = await handler.set_push_notification_config(request) + context = ServerCallContext() + response = await handler.set_push_notification_config(request, context) self.assertIsInstance(response, dict) self.assertTrue(is_success_response(response)) mock_push_notification_store.set_info.assert_called_once_with( - mock_task.id, push_config, None + mock_task.id, push_config, context ) async def test_get_push_notification_success(self) -> None: From 0f51ef3832ac2e0394d2da482cf7feecffcd05f8 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 18:06:29 +0000 Subject: [PATCH 14/29] fi: fix --- src/a2a/server/tasks/inmemory_task_store.py | 19 +++++++++---------- .../server/tasks/test_inmemory_task_store.py | 8 ++++++++ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index 45d5c5b93..84c6556e8 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -55,16 +55,15 @@ async def get( task_id, owner, ) - owner_tasks = self.tasks.get(owner) - if owner_tasks: - task = owner_tasks.get(task_id) - if task: - logger.debug( - 'Task %s retrieved successfully for owner %s.', - task_id, - owner, - ) - return task + owner_tasks = self.tasks.get(owner, {}) + task = owner_tasks.get(task_id) + if task: + logger.debug( + 'Task %s retrieved successfully for owner %s.', + task_id, + owner, + ) + return task logger.debug( 'Task %s not found in store for owner %s.', task_id, owner ) diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index f6093b64e..6aa1bb7e5 100644 --- 
a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -275,6 +275,9 @@ async def test_owner_resource_scoping() -> None: context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) + context_user3 = ServerCallContext( + user=SampleUser(user_name='user3') + ) # For testing non-existent user # Create tasks for different owners task1_user1 = Task() @@ -298,6 +301,7 @@ async def test_owner_resource_scoping() -> None: assert await store.get('u1-task1', context_user2) is None assert await store.get('u2-task1', context_user1) is None assert await store.get('u2-task1', context_user2) is not None + assert await store.get('u2-task1', context_user3) is None # Test LIST params = ListTasksRequest() @@ -311,6 +315,10 @@ async def test_owner_resource_scoping() -> None: assert {t.id for t in page_user2.tasks} == {'u2-task1'} assert page_user2.total_size == 1 + page_user3 = await store.list(params, context_user3) + assert len(page_user3.tasks) == 0 + assert page_user3.total_size == 0 + # Test DELETE await store.delete('u1-task1', context_user2) # Should not delete assert await store.get('u1-task1', context_user1) is not None From 00e5eacede1ebf8742795cb739e38e1fa5121c22 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 18:06:51 +0000 Subject: [PATCH 15/29] fix: fix --- ...inmemory_push_notification_config_store.py | 10 ++++----- src/a2a/server/tasks/inmemory_task_store.py | 21 +++++++++++-------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index 4de8b82fa..eb336e329 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -70,10 +70,8 @@ async def get_info( """Retrieves all push notification 
configurations for a task from memory, for the given owner.""" owner = self.owner_resolver(context) async with self.lock: - owner_infos = self._push_notification_infos.get(owner) - if owner_infos: - return list(owner_infos.get(task_id, [])) - return [] + owner_infos = self._push_notification_infos.get(owner, {}) + return list(owner_infos.get(task_id, [])) async def delete_info( self, @@ -88,8 +86,8 @@ async def delete_info( """ owner = self.owner_resolver(context) async with self.lock: - owner_infos = self._push_notification_infos.get(owner) - if not owner_infos or task_id not in owner_infos: + owner_infos = self._push_notification_infos.get(owner, {}) + if task_id not in owner_infos: logger.warning( 'Attempted to delete push notification config for task %s, owner %s that does not exist.', task_id, diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index 84c6556e8..019fd773e 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -160,17 +160,20 @@ async def delete( task_id, owner, ) - if owner in self.tasks and task_id in self.tasks[owner]: - del self.tasks[owner][task_id] - logger.debug( - 'Task %s deleted successfully for owner %s.', task_id, owner - ) - if not self.tasks[owner]: - del self.tasks[owner] - logger.debug('Removed empty owner %s from store.', owner) - else: + + owner_tasks = self.tasks.get(owner, {}) + if task_id not in owner_tasks: logger.warning( 'Attempted to delete nonexistent task with id: %s for owner %s', task_id, owner, ) + return + + del owner_tasks[task_id] + logger.debug( + 'Task %s deleted successfully for owner %s.', task_id, owner + ) + if not owner_tasks: + del self.tasks[owner] + logger.debug('Removed empty owner %s from store.', owner) From 5d65aa666528d9a6cfc8000901c58efca0a9d270 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 23 Feb 2026 15:27:44 +0000 Subject: [PATCH 16/29] chore: distribute alembic migration files. 
Add helper _get_owner_push_notification_infos function --- alembic/README | 59 ------------- pyproject.toml | 80 +----------------- alembic.ini => src/a2a/alembic.ini | 10 --- src/a2a/cli.py | 65 ++++++++++++++ src/a2a/migrations/README | 84 +++++++++++++++++++ {alembic => src/a2a/migrations}/__init__.py | 0 {alembic => src/a2a/migrations}/env.py | 17 ++-- .../a2a/migrations}/script.py.mako | 0 .../6419d2d130f6_add_owner_to_task.py | 10 ++- .../a2a/migrations}/versions/__init__.py | 0 src/a2a/server/models.py | 2 +- ...inmemory_push_notification_config_store.py | 7 +- src/a2a/server/tasks/inmemory_task_store.py | 9 +- 13 files changed, 183 insertions(+), 160 deletions(-) delete mode 100644 alembic/README rename alembic.ini => src/a2a/alembic.ini (59%) create mode 100644 src/a2a/cli.py create mode 100644 src/a2a/migrations/README rename {alembic => src/a2a/migrations}/__init__.py (100%) rename {alembic => src/a2a/migrations}/env.py (86%) rename {alembic => src/a2a/migrations}/script.py.mako (100%) rename {alembic => src/a2a/migrations}/versions/6419d2d130f6_add_owner_to_task.py (74%) rename {alembic => src/a2a/migrations}/versions/__init__.py (100%) diff --git a/alembic/README b/alembic/README deleted file mode 100644 index 06ec9e9a8..000000000 --- a/alembic/README +++ /dev/null @@ -1,59 +0,0 @@ -# Database Migrations with Alembic - -This directory contains database migration scripts for the A2A SDK, managed by [Alembic](https://alembic.sqlalchemy.org/). - -## Configuration - -- `alembic.ini`: Global configuration for Alembic, including the database URL. -- `env.py`: Python script that runs when the Alembic environment is invoked. It configures the SQLAlchemy engine and connects it to the migration context. -- `versions/`: Directory containing individual migration scripts. - -## Common Commands - -All commands should be run from the project root using `uv run`. 
- -### Viewing Status -```bash -# View current migration version of the database -uv run alembic current - -# View migration history -uv run alembic history --verbose -``` - -### Running Migrations -```bash -# Upgrade to the latest version -uv run alembic upgrade head - -# Downgrade by one version -uv run alembic downgrade -1 - -# Revert all migrations -uv run alembic downgrade base -``` - -### Creating Migrations -```bash -# Create a new migration manually -uv run alembic revision -m "description of changes" - -# Create a new migration automatically (detects changes in models.py) -uv run alembic revision --autogenerate -m "description of changes" -``` - -## Troubleshooting - -### "duplicate column name" error -If you see an error like `sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) duplicate column name: owner`, it usually means the column was already created (perhaps by `Base.metadata.create_all()` in tests or development) but Alembic doesn't know about it yet. - -To fix this, "stamp" the database to tell Alembic it is already at the latest version: -```bash -uv run alembic stamp head -``` - -## How to add a new migration -1. Modify the models in `src/a2a/server/models.py`. -2. Run `uv run alembic revision --autogenerate -m "Add new field to Task"`. -3. Review the generated script in `alembic/versions/`. -4. Apply the migration with `uv run alembic upgrade head`. diff --git a/pyproject.toml b/pyproject.toml index 7e3b6a2f7..8caa7c70d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -353,84 +353,12 @@ indent-style = "space" [tool.alembic] # path to migration scripts. -# this is typically a path given in POSIX (e.g. 
forward slashes) -# format, relative to the token %(here)s which refers to the location of this -# ini file -script_location = "%(here)s/alembic" - -# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s -# Uncomment the line below if you want the files to be prepended with date and time -# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file -# for all available tokens -# file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s" -# Or organize into date-based subdirectories (requires recursive_version_locations = true) -# file_template = "%%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s" +script_location = "src/a2a/migrations" # additional paths to be prepended to sys.path. defaults to the current working directory. prepend_sys_path = [ - "." + "src" ] -# timezone to use when rendering the date within the migration file -# as well as the filename. -# If specified, requires the tzdata library which can be installed by adding -# `alembic[tz]` to the pip requirements. -# string value is passed to ZoneInfo() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; This defaults -# to /versions. When using multiple version -# directories, initial revisions must be specified with --version-path. 
-# version_locations = [ -# "%(here)s/alembic/versions", -# "%(here)s/foo/bar" -# ] - - -# set to 'true' to search source files recursively -# in each "version_locations" directory -# new in Alembic version 1.10 -# recursive_version_locations = false - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = "utf-8" - -# This section defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples -# [[tool.alembic.post_write_hooks]] -# format using "black" - use the console_scripts runner, -# against the "black" entrypoint -# name = "black" -# type = "console_scripts" -# entrypoint = "black" -# options = "-l 79 REVISION_SCRIPT_FILENAME" -# -# [[tool.alembic.post_write_hooks]] -# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module -# name = "ruff" -# type = "module" -# module = "ruff" -# options = "check --fix REVISION_SCRIPT_FILENAME" -# -# [[tool.alembic.post_write_hooks]] -# Alternatively, use the exec runner to execute a binary found on your PATH -# name = "ruff" -# type = "exec" -# executable = "ruff" -# options = "check --fix REVISION_SCRIPT_FILENAME" - +[project.scripts] +a2a-db = "a2a.cli:run_migrations" diff --git a/alembic.ini b/src/a2a/alembic.ini similarity index 59% rename from alembic.ini rename to src/a2a/alembic.ini index 58249b073..ed0e60bc2 100644 --- a/alembic.ini +++ b/src/a2a/alembic.ini @@ -1,15 +1,5 @@ # A generic, single database configuration. -[alembic] - -# database URL. This is consumed by the user-maintained env.py script only. -# other means of configuring database URLs may be customized within the env.py -# file. -# IMPORTANT: This is a placeholder and an example, and should be replaced with your actual database URL. 
-sqlalchemy.url = sqlite+aiosqlite:///./test.db - - -# Logging configuration [loggers] keys = root,sqlalchemy,alembic diff --git a/src/a2a/cli.py b/src/a2a/cli.py new file mode 100644 index 000000000..7b8a3680a --- /dev/null +++ b/src/a2a/cli.py @@ -0,0 +1,65 @@ +import argparse +import os +from alembic.config import Config +from alembic import command +from importlib.resources import files + +def run_migrations(): + """CLI tool to manage database migrations.""" + parser = argparse.ArgumentParser(description="A2A Database Migration Tool") + + # Global options + parser.add_argument("-o", "--owner", help="Value for the 'owner' column (used in specific migrations). If not set defaults to 'unknown'") + parser.add_argument("-u", "--database-url", help="Database URL to use for the migrations. If not set, the DATABASE_URL environment variable will be used.") + + subparsers = parser.add_subparsers(dest="cmd", help="Migration command") + + # Upgrade command + up_parser = subparsers.add_parser("upgrade", help="Upgrade to a later version") + up_parser.add_argument("revision", nargs="?", default="head", help="Revision target (default: head)") + up_parser.add_argument("-o", "--owner", dest="sub_owner", help="Alias for top-level --owner") + up_parser.add_argument("-u", "--database-url", dest="sub_database_url", help="Alias for top-level --database-url") + + # Downgrade command + down_parser = subparsers.add_parser("downgrade", help="Revert to a previous version") + down_parser.add_argument("revision", nargs="?", default="base", help="Revision target (e.g., -1, base or a specific ID)") + down_parser.add_argument("-u", "--database-url", dest="sub_database_url", help="Alias for top-level --database-url") + + args = parser.parse_args() + + # Consolidate owner value + owner = args.owner or getattr(args, "sub_owner", None) + db_url = args.database_url or getattr(args, "sub_database_url", None) + if db_url: + os.environ["DATABASE_URL"] = db_url + + # Default to upgrade head if no 
command is provided + if not args.cmd: + args.cmd = "upgrade" + args.revision = "head" + + # 1. Locate the bundled alembic.ini + ini_path = files('a2a').joinpath('alembic.ini') + cfg = Config(str(ini_path)) + + # 2. Dynamically set the script location + migrations_path = files('a2a').joinpath('migrations') + cfg.set_main_option("script_location", str(migrations_path)) + + # 3. Pass custom arguments to the migration context + if owner: + if args.cmd == "downgrade": + parser.error("The --owner option is not supported for the 'downgrade' command.") + cfg.set_main_option("owner", owner) + + + # 3. Execute the requested command + if args.cmd == "upgrade": + print(f"Upgrading database to {args.revision}...") + command.upgrade(cfg, args.revision) + elif args.cmd == "downgrade": + print(f"Downgrading database to {args.revision}...") + command.downgrade(cfg, args.revision) + + print("Done.") + diff --git a/src/a2a/migrations/README b/src/a2a/migrations/README new file mode 100644 index 000000000..ebd418b93 --- /dev/null +++ b/src/a2a/migrations/README @@ -0,0 +1,84 @@ +# A2A SDK Database Migrations + +This directory handles the database schema updates for the A2A SDK. It uses a bundled CLI tool to simplify the migration process for both users and developers of the SDK. + +## User Guide (For Integrators) + +When you install the `a2a-sdk`, you get a built-in command `a2a-db` to manage your database schema. + +### 1. Set your Database URL +Migrations require the `DATABASE_URL` environment variable to be set with an **async-compatible** driver or you can use the `-u` flag to specify the database URL for a single command. + +```bash +# SQLite example (SQLite is file-based; it takes a file path, not credentials or a host) +export DATABASE_URL="sqlite+aiosqlite:///./your_database_name.db" + +# PostgreSQL example +export DATABASE_URL="postgresql+asyncpg://user:pass@localhost/your_database_name" + +# MySQL example +export DATABASE_URL="mysql+aiomysql://user:pass@localhost/your_database_name" +``` + +### 2. 
Apply Migrations +Always run this command after installing or upgrading the SDK to ensure your database matches the required schema. + +```bash +# Bring the database to the latest version +a2a-db +``` + +### 3. Customizing Defaults +Add owner to tasks migration allows you to pass custom values for the new `owner` column. For example, to set a specific default owner for existing tasks: + +```bash +a2a-db -o "admin_user" +``` + +### 4. Rolling Back +If you need to revert a change: + +```bash +# Step back one version +a2a-db downgrade -1 +``` + +```bash +# Downgrade to a specific revision ID +a2a-db downgrade <revision_id> +``` + +```bash +# Revert all migrations +a2a-db downgrade base +``` + +--- + +## Developer Guide (For SDK Contributors) + +If you are modifying the SDK models and need to generate new migration files, use the following workflow. + +### Creating a New Migration +Developers should use the raw `alembic` command locally to generate migrations. Ensure you are in the project root. + +```bash +# Detect changes in models.py and generate a script +uv run alembic revision --autogenerate -m "describe your changes" +``` + +### Internal Layout +- `env.py`: Configures the migration engine and applies the mandatory `DATABASE_URL` check. +- `versions/`: Contains the migration history. +- `script.py.mako`: The template for all new migration files. 
+ +## Troubleshooting + +### "Duplicate column name" +If your database already has the required tables (e.g., created via `Base.metadata.create_all()` in a legacy script), you may need to "stamp" the database to tell the SDK that it is already up to date: + +```bash +# Stamp the database without running SQL commands +# (Requires raw alembic command for developer use) +uv run alembic stamp head +``` diff --git a/alembic/__init__.py b/src/a2a/migrations/__init__.py similarity index 100% rename from alembic/__init__.py rename to src/a2a/migrations/__init__.py diff --git a/alembic/env.py b/src/a2a/migrations/env.py similarity index 86% rename from alembic/env.py rename to src/a2a/migrations/env.py index 07864de4d..dbc7b8c00 100644 --- a/alembic/env.py +++ b/src/a2a/migrations/env.py @@ -1,4 +1,5 @@ import asyncio +import os from logging.config import fileConfig @@ -13,6 +14,16 @@ # access to the values within the .ini file in use. config = context.config +# Mandatory database configuration +db_url = os.getenv("DATABASE_URL") +if not db_url: + raise RuntimeError( + "DATABASE_URL environment variable is not set. " + "Please set it (e.g., export DATABASE_URL='sqlite+aiosqlite:///./my-database.db') before running migrations" + "or use the --database-url flag." + ) +config.set_main_option("sqlalchemy.url", db_url) + # Interpret the config file for Python logging. # This line sets up loggers basically. if config.config_file_name is not None: @@ -21,12 +32,6 @@ # add your model's MetaData object here for 'autogenerate' support target_metadata = Base.metadata -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") # noqa: ERA001 -# ... etc. - - def run_migrations_offline() -> None: """Run migrations in 'offline' mode. 
diff --git a/alembic/script.py.mako b/src/a2a/migrations/script.py.mako similarity index 100% rename from alembic/script.py.mako rename to src/a2a/migrations/script.py.mako diff --git a/alembic/versions/6419d2d130f6_add_owner_to_task.py b/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py similarity index 74% rename from alembic/versions/6419d2d130f6_add_owner_to_task.py rename to src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py index 6e2ede603..1539b97ba 100644 --- a/alembic/versions/6419d2d130f6_add_owner_to_task.py +++ b/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py @@ -10,7 +10,7 @@ import sqlalchemy as sa -from alembic import op +from alembic import context, op # revision identifiers, used by Alembic. @@ -22,17 +22,21 @@ def upgrade() -> None: """Upgrade schema.""" + # Get the default value from the config (passed via CLI) + owner = context.config.get_main_option("owner", "unknown") + op.add_column( 'tasks', sa.Column( 'owner', - sa.String(255), + sa.String(128), nullable=False, - server_default='unknown', # Set your desired default value here + server_default=owner, ), ) + def downgrade() -> None: """Downgrade schema.""" op.drop_column('tasks', 'owner') diff --git a/alembic/versions/__init__.py b/src/a2a/migrations/versions/__init__.py similarity index 100% rename from alembic/versions/__init__.py rename to src/a2a/migrations/versions/__init__.py diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index a7e80d81c..35f47ffff 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -148,7 +148,7 @@ class TaskMixin: kind: Mapped[str] = mapped_column( String(16), nullable=False, default='task' ) - owner: Mapped[str] = mapped_column(String(255), nullable=False) + owner: Mapped[str] = mapped_column(String(128), nullable=False) last_updated: Mapped[str] = mapped_column(String(22), nullable=True) # Properly typed Pydantic fields with automatic serialization diff --git 
a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index eb336e329..d5c0f10a4 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -31,6 +31,9 @@ def __init__( str, dict[str, list[PushNotificationConfig]] ] = defaultdict(dict) self.owner_resolver = owner_resolver + + def _get_owner_push_notification_infos(self, owner: str) -> dict[str, list[PushNotificationConfig]]: + return self._push_notification_infos.get(owner, {}) async def set_info( self, @@ -70,7 +73,7 @@ async def get_info( """Retrieves all push notification configurations for a task from memory, for the given owner.""" owner = self.owner_resolver(context) async with self.lock: - owner_infos = self._push_notification_infos.get(owner, {}) + owner_infos = self._get_owner_push_notification_infos(owner) return list(owner_infos.get(task_id, [])) async def delete_info( @@ -86,7 +89,7 @@ async def delete_info( """ owner = self.owner_resolver(context) async with self.lock: - owner_infos = self._push_notification_infos.get(owner, {}) + owner_infos = self._get_owner_push_notification_infos(owner) if task_id not in owner_infos: logger.warning( 'Attempted to delete push notification config for task %s, owner %s that does not exist.', diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index 019fd773e..b655805fc 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -32,6 +32,9 @@ def __init__( self.lock = asyncio.Lock() self.owner_resolver = owner_resolver + def _get_owner_tasks(self, owner: str) -> dict[str, Task]: + return self.tasks.get(owner, {}) + async def save( self, task: Task, context: ServerCallContext | None = None ) -> None: @@ -55,7 +58,7 @@ async def get( task_id, owner, ) - owner_tasks = self.tasks.get(owner, {}) + 
owner_tasks = self._get_owner_tasks(owner) task = owner_tasks.get(task_id) if task: logger.debug( @@ -79,7 +82,7 @@ async def list( logger.debug('Listing tasks for owner %s with params %s', owner, params) async with self.lock: - owner_tasks = self.tasks.get(owner, {}) + owner_tasks = self._get_owner_tasks(owner) tasks = list(owner_tasks.values()) # Filter tasks @@ -161,7 +164,7 @@ async def delete( owner, ) - owner_tasks = self.tasks.get(owner, {}) + owner_tasks = self._get_owner_tasks(owner) if task_id not in owner_tasks: logger.warning( 'Attempted to delete nonexistent task with id: %s for owner %s', From 44ad799867f05fb86fa63e3ca421671bde4b23d2 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 23 Feb 2026 17:31:36 +0000 Subject: [PATCH 17/29] fix: linter issues --- src/a2a/cli.py | 101 ++++++++++++------ src/a2a/migrations/env.py | 11 +- ...inmemory_push_notification_config_store.py | 2 +- 3 files changed, 76 insertions(+), 38 deletions(-) diff --git a/src/a2a/cli.py b/src/a2a/cli.py index 7b8a3680a..b5f3f7faf 100644 --- a/src/a2a/cli.py +++ b/src/a2a/cli.py @@ -1,42 +1,79 @@ import argparse import os -from alembic.config import Config -from alembic import command + from importlib.resources import files -def run_migrations(): +from alembic import command +from alembic.config import Config + + +def run_migrations() -> None: """CLI tool to manage database migrations.""" - parser = argparse.ArgumentParser(description="A2A Database Migration Tool") - + parser = argparse.ArgumentParser(description='A2A Database Migration Tool') + # Global options - parser.add_argument("-o", "--owner", help="Value for the 'owner' column (used in specific migrations). If not set defaults to 'unknown'") - parser.add_argument("-u", "--database-url", help="Database URL to use for the migrations. 
If not set, the DATABASE_URL environment variable will be used.") - - subparsers = parser.add_subparsers(dest="cmd", help="Migration command") + parser.add_argument( + '-o', + '--owner', + help="Value for the 'owner' column (used in specific migrations). If not set defaults to 'unknown'", + ) + parser.add_argument( + '-u', + '--database-url', + help='Database URL to use for the migrations. If not set, the DATABASE_URL environment variable will be used.', + ) + + subparsers = parser.add_subparsers(dest='cmd', help='Migration command') # Upgrade command - up_parser = subparsers.add_parser("upgrade", help="Upgrade to a later version") - up_parser.add_argument("revision", nargs="?", default="head", help="Revision target (default: head)") - up_parser.add_argument("-o", "--owner", dest="sub_owner", help="Alias for top-level --owner") - up_parser.add_argument("-u", "--database-url", dest="sub_database_url", help="Alias for top-level --database-url") + up_parser = subparsers.add_parser( + 'upgrade', help='Upgrade to a later version' + ) + up_parser.add_argument( + 'revision', + nargs='?', + default='head', + help='Revision target (default: head)', + ) + up_parser.add_argument( + '-o', '--owner', dest='sub_owner', help='Alias for top-level --owner' + ) + up_parser.add_argument( + '-u', + '--database-url', + dest='sub_database_url', + help='Alias for top-level --database-url', + ) # Downgrade command - down_parser = subparsers.add_parser("downgrade", help="Revert to a previous version") - down_parser.add_argument("revision", nargs="?", default="base", help="Revision target (e.g., -1, base or a specific ID)") - down_parser.add_argument("-u", "--database-url", dest="sub_database_url", help="Alias for top-level --database-url") + down_parser = subparsers.add_parser( + 'downgrade', help='Revert to a previous version' + ) + down_parser.add_argument( + 'revision', + nargs='?', + default='base', + help='Revision target (e.g., -1, base or a specific ID)', + ) + 
down_parser.add_argument( + '-u', + '--database-url', + dest='sub_database_url', + help='Alias for top-level --database-url', + ) args = parser.parse_args() # Consolidate owner value - owner = args.owner or getattr(args, "sub_owner", None) - db_url = args.database_url or getattr(args, "sub_database_url", None) + owner = args.owner or getattr(args, 'sub_owner', None) + db_url = args.database_url or getattr(args, 'sub_database_url', None) if db_url: - os.environ["DATABASE_URL"] = db_url + os.environ['DATABASE_URL'] = db_url # Default to upgrade head if no command is provided if not args.cmd: - args.cmd = "upgrade" - args.revision = "head" + args.cmd = 'upgrade' + args.revision = 'head' # 1. Locate the bundled alembic.ini ini_path = files('a2a').joinpath('alembic.ini') @@ -44,22 +81,22 @@ def run_migrations(): # 2. Dynamically set the script location migrations_path = files('a2a').joinpath('migrations') - cfg.set_main_option("script_location", str(migrations_path)) + cfg.set_main_option('script_location', str(migrations_path)) # 3. Pass custom arguments to the migration context if owner: - if args.cmd == "downgrade": - parser.error("The --owner option is not supported for the 'downgrade' command.") - cfg.set_main_option("owner", owner) - + if args.cmd == 'downgrade': + parser.error( + "The --owner option is not supported for the 'downgrade' command." + ) + cfg.set_main_option('owner', owner) # 3. 
Execute the requested command - if args.cmd == "upgrade": - print(f"Upgrading database to {args.revision}...") + if args.cmd == 'upgrade': + print(f'Upgrading database to {args.revision}...') command.upgrade(cfg, args.revision) - elif args.cmd == "downgrade": - print(f"Downgrading database to {args.revision}...") + elif args.cmd == 'downgrade': + print(f'Downgrading database to {args.revision}...') command.downgrade(cfg, args.revision) - print("Done.") - + print('Done.') diff --git a/src/a2a/migrations/env.py b/src/a2a/migrations/env.py index dbc7b8c00..43bac7cf6 100644 --- a/src/a2a/migrations/env.py +++ b/src/a2a/migrations/env.py @@ -15,14 +15,14 @@ config = context.config # Mandatory database configuration -db_url = os.getenv("DATABASE_URL") +db_url = os.getenv('DATABASE_URL') if not db_url: raise RuntimeError( - "DATABASE_URL environment variable is not set. " + 'DATABASE_URL environment variable is not set. ' "Please set it (e.g., export DATABASE_URL='sqlite+aiosqlite:///./my-database.db') before running migrations" - "or use the --database-url flag." + 'or use the --database-url flag.' ) -config.set_main_option("sqlalchemy.url", db_url) +config.set_main_option('sqlalchemy.url', db_url) # Interpret the config file for Python logging. # This line sets up loggers basically. @@ -32,6 +32,7 @@ # add your model's MetaData object here for 'autogenerate' support target_metadata = Base.metadata + def run_migrations_offline() -> None: """Run migrations in 'offline' mode. @@ -79,7 +80,7 @@ async def run_async_migrations() -> None: and associate a connection with the context. 
""" connectable = async_engine_from_config( - config.get_section(config.config_ini_section), + config.get_section(config.config_ini_section, {}), prefix='sqlalchemy.', poolclass=pool.NullPool, ) diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index d5c0f10a4..8020f64ca 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -31,7 +31,7 @@ def __init__( str, dict[str, list[PushNotificationConfig]] ] = defaultdict(dict) self.owner_resolver = owner_resolver - + def _get_owner_push_notification_infos(self, owner: str) -> dict[str, list[PushNotificationConfig]]: return self._push_notification_infos.get(owner, {}) From 5bfccf9d986c7118916a866cfce4905e257ea418 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 23 Feb 2026 17:36:59 +0000 Subject: [PATCH 18/29] fix: uv run ruff format --- .../6419d2d130f6_add_owner_to_task.py | 3 +-- .../tasks/base_push_notification_sender.py | 4 +++- ...inmemory_push_notification_config_store.py | 4 +++- .../tasks/test_push_notification_sender.py | 23 ++++++++++++++----- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py b/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py index 1539b97ba..26d192511 100644 --- a/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py +++ b/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py @@ -23,7 +23,7 @@ def upgrade() -> None: """Upgrade schema.""" # Get the default value from the config (passed via CLI) - owner = context.config.get_main_option("owner", "unknown") + owner = context.config.get_main_option('owner', 'unknown') op.add_column( 'tasks', @@ -36,7 +36,6 @@ def upgrade() -> None: ) - def downgrade() -> None: """Downgrade schema.""" op.drop_column('tasks', 'owner') diff --git 
a/src/a2a/server/tasks/base_push_notification_sender.py b/src/a2a/server/tasks/base_push_notification_sender.py index dc7ba4979..201169e6e 100644 --- a/src/a2a/server/tasks/base_push_notification_sender.py +++ b/src/a2a/server/tasks/base_push_notification_sender.py @@ -44,7 +44,9 @@ async def send_notification( self, task_id: str, event: PushNotificationEvent ) -> None: """Sends a push notification for an event if configuration exists.""" - push_configs = await self._config_store.get_info(task_id, self._call_context) + push_configs = await self._config_store.get_info( + task_id, self._call_context + ) if not push_configs: return diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index 8020f64ca..4ea93cdfe 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -32,7 +32,9 @@ def __init__( ] = defaultdict(dict) self.owner_resolver = owner_resolver - def _get_owner_push_notification_infos(self, owner: str) -> dict[str, list[PushNotificationConfig]]: + def _get_owner_push_notification_infos( + self, owner: str + ) -> dict[str, list[PushNotificationConfig]]: return self._push_notification_infos.get(owner, {}) async def set_info( diff --git a/tests/server/tasks/test_push_notification_sender.py b/tests/server/tasks/test_push_notification_sender.py index 36b12d87b..d0cc7fac5 100644 --- a/tests/server/tasks/test_push_notification_sender.py +++ b/tests/server/tasks/test_push_notification_sender.py @@ -21,6 +21,7 @@ TaskStatusUpdateEvent, ) + class SampleUser(User): """A test implementation of the User interface.""" @@ -35,8 +36,10 @@ def is_authenticated(self) -> bool: def user_name(self) -> str: return self._user_name + MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + def _create_sample_task( task_id: str = 'task123', status_state: TaskState = 
TaskState.TASK_STATE_COMPLETED, @@ -82,7 +85,9 @@ async def test_send_notification_success(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_data.id, MINIMAL_CALL_CONTEXT) + self.mock_config_store.get_info.assert_awaited_once_with( + task_data.id, MINIMAL_CALL_CONTEXT + ) # assert httpx_client post method got invoked with right parameters self.mock_httpx_client.post.assert_awaited_once_with( @@ -106,7 +111,9 @@ async def test_send_notification_with_token_success(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_data.id, MINIMAL_CALL_CONTEXT) + self.mock_config_store.get_info.assert_awaited_once_with( + task_data.id, MINIMAL_CALL_CONTEXT + ) # assert httpx_client post method got invoked with right parameters self.mock_httpx_client.post.assert_awaited_once_with( @@ -148,7 +155,7 @@ async def test_send_notification_http_status_error( await self.sender.send_notification(task_id, task_data) self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + task_id, MINIMAL_CALL_CONTEXT ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, @@ -175,7 +182,7 @@ async def test_send_notification_multiple_configs(self) -> None: await self.sender.send_notification(task_id, task_data) self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + task_id, MINIMAL_CALL_CONTEXT ) self.assertEqual(self.mock_httpx_client.post.call_count, 2) @@ -208,7 +215,9 @@ async def test_send_notification_status_update_event(self) -> None: await self.sender.send_notification(task_id, event) - self.mock_config_store.get_info.assert_awaited_once_with(task_id, MINIMAL_CALL_CONTEXT) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, 
json=MessageToDict(StreamResponse(status_update=event)), @@ -230,7 +239,9 @@ async def test_send_notification_artifact_update_event(self) -> None: await self.sender.send_notification(task_id, event) - self.mock_config_store.get_info.assert_awaited_once_with(task_id, MINIMAL_CALL_CONTEXT) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, json=MessageToDict(StreamResponse(artifact_update=event)), From bde88f14a8c7ae767cb2a6da9f7f77a1e8a4d40d Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 24 Feb 2026 08:57:12 +0000 Subject: [PATCH 19/29] refactor: add an edge case to test_database_task_store --- tests/server/tasks/test_database_task_store.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index e6b67701c..9887c9704 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -632,6 +632,7 @@ async def test_owner_resource_scoping( context_user1 = ServerCallContext(user=TestUser(user_name='user1')) context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + context_user3 = ServerCallContext(user=TestUser(user_name='user3')) # user with no tasks # Create tasks for different owners task1_user1, task2_user1, task1_user2 = Task(), Task(), Task() @@ -664,6 +665,10 @@ async def test_owner_resource_scoping( assert {t.id for t in page_user2.tasks} == {'u2-task1'} assert page_user2.total_size == 1 + page_user3 = await task_store.list(params, context_user3) + assert len(page_user3.tasks) == 0 + assert page_user3.total_size == 0 + # Test DELETE await task_store.delete('u1-task1', context_user2) # Should not delete assert await task_store.get('u1-task1', context_user1) is not None From db3d050a8ea460611da2c1f2b5e277bf66a88c00 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 24 Feb 2026 08:59:16 +0000 
Subject: [PATCH 20/29] fix: uv run ruff format --- tests/server/tasks/test_database_task_store.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index 9887c9704..bf912281f 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -632,7 +632,9 @@ async def test_owner_resource_scoping( context_user1 = ServerCallContext(user=TestUser(user_name='user1')) context_user2 = ServerCallContext(user=TestUser(user_name='user2')) - context_user3 = ServerCallContext(user=TestUser(user_name='user3')) # user with no tasks + context_user3 = ServerCallContext( + user=TestUser(user_name='user3') + ) # user with no tasks # Create tasks for different owners task1_user1, task2_user1, task1_user2 = Task(), Task(), Task() From a37efe1dcecc04be015b7bf3cf94e013f17f4520 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 24 Feb 2026 09:13:42 +0000 Subject: [PATCH 21/29] fix: typo --- src/a2a/migrations/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/a2a/migrations/env.py b/src/a2a/migrations/env.py index 43bac7cf6..554b58cbc 100644 --- a/src/a2a/migrations/env.py +++ b/src/a2a/migrations/env.py @@ -19,7 +19,7 @@ if not db_url: raise RuntimeError( 'DATABASE_URL environment variable is not set. ' - "Please set it (e.g., export DATABASE_URL='sqlite+aiosqlite:///./my-database.db') before running migrations" + "Please set it (e.g., export DATABASE_URL='sqlite+aiosqlite:///./my-database.db') before running migrations " 'or use the --database-url flag.' 
) config.set_main_option('sqlalchemy.url', db_url) From 782fcf7895950c3716ea2a190cd323fd8351a62c Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 24 Feb 2026 11:40:16 +0000 Subject: [PATCH 22/29] chore: add migration for `push_notification_configs` table --- src/a2a/migrations/README | 14 +++++++++----- .../versions/6419d2d130f6_add_owner_to_task.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/a2a/migrations/README b/src/a2a/migrations/README index ebd418b93..6ee439333 100644 --- a/src/a2a/migrations/README +++ b/src/a2a/migrations/README @@ -4,9 +4,13 @@ This directory handles the database schema updates for the A2A SDK. It uses a bu ## User Guide (For Integrators) -When you install the `a2a-sdk`, you get a built-in command `a2a-db` to manage your database schema. +When you install the `a2a-sdk`, you get a built-in command `a2a-db` which updates tables 'tasks' and 'push_notification_configs' in your database to make it compatible with the latest version of the SDK. -### 1. Set your Database URL +### 1. Recommended: Backup your database + +Before running migrations, it is recommended to backup your database. + +### 2. Set your Database URL Migrations require the `DATABASE_URL` environment variable to be set with an **async-compatible** driver or you can use the `-u` flag to specify the database URL for a single command. ```bash @@ -20,7 +24,7 @@ export DATABASE_URL="postgresql+asyncpg://user:pass@localhost/your_database_name export DATABASE_URL="mysql+aiomysql://user:pass@localhost/your_database_name" ``` -### 2. Apply Migrations +### 3. Apply Migrations Always run this command after installing or upgrading the SDK to ensure your database matches the required schema. ```bash @@ -28,14 +32,14 @@ Always run this command after installing or upgrading the SDK to ensure your dat a2a-db ``` -### 3. Customizing Defaults +### 4. 
Customizing Defaults Add owner to tasks migration allows you to pass custom values for the new `owner` column. For example, to set a specific default owner for existing tasks: ```bash a2a-db -o "admin_user" ``` -### 4. Rolling Back +### 5. Rolling Back If you need to revert a change: ```bash diff --git a/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py b/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py index 26d192511..a9e5a77d0 100644 --- a/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py +++ b/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py @@ -35,7 +35,18 @@ def upgrade() -> None: ), ) + op.add_column( + 'push_notification_configs', + sa.Column( + 'owner', + sa.String(128), + nullable=False, + server_default=owner, + ), + ) + def downgrade() -> None: """Downgrade schema.""" op.drop_column('tasks', 'owner') + op.drop_column('push_notification_configs', 'owner') From 34289566ec8c04310f00d196d0c378b83bfbba08 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 24 Feb 2026 15:24:57 +0000 Subject: [PATCH 23/29] fix: rename README to README.md --- src/a2a/migrations/{README => README.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/a2a/migrations/{README => README.md} (100%) diff --git a/src/a2a/migrations/README b/src/a2a/migrations/README.md similarity index 100% rename from src/a2a/migrations/README rename to src/a2a/migrations/README.md From d90b3603c2c738b874dad4576a2340dcc388a9c6 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 24 Feb 2026 20:47:34 +0000 Subject: [PATCH 24/29] feat: enhance migration CLI with table and verbose options; add owner and last_updated columns to tasks and push_notification_configs --- src/a2a/cli.py | 66 +++++++++++--- src/a2a/migrations/env.py | 4 + ...d2d130f6_add_columns_owner_last_updated.py | 86 +++++++++++++++++++ .../6419d2d130f6_add_owner_to_task.py | 52 ----------- 4 files changed, 146 insertions(+), 62 deletions(-) create mode 100644 
src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py delete mode 100644 src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py diff --git a/src/a2a/cli.py b/src/a2a/cli.py index b5f3f7faf..afd5b994b 100644 --- a/src/a2a/cli.py +++ b/src/a2a/cli.py @@ -22,6 +22,18 @@ def run_migrations() -> None: '--database-url', help='Database URL to use for the migrations. If not set, the DATABASE_URL environment variable will be used.', ) + parser.add_argument( + '-t', + '--table', + help="Specific table to update. If not set, both 'tasks' and 'push_notification_configs' are updated.", + action='append', + ) + parser.add_argument( + '-v', + '--verbose', + help='Enable verbose output (sets sqlalchemy.engine logging to INFO)', + action='store_true', + ) subparsers = parser.add_subparsers(dest='cmd', help='Migration command') @@ -44,6 +56,20 @@ def run_migrations() -> None: dest='sub_database_url', help='Alias for top-level --database-url', ) + up_parser.add_argument( + '-t', + '--table', + dest='sub_table', + help='Alias for top-level --table', + action='append', + ) + up_parser.add_argument( + '-v', + '--verbose', + dest='sub_verbose', + help='Enable verbose output (sets sqlalchemy.engine logging to INFO)', + action='store_true', + ) # Downgrade command down_parser = subparsers.add_parser( @@ -61,37 +87,57 @@ def run_migrations() -> None: dest='sub_database_url', help='Alias for top-level --database-url', ) + down_parser.add_argument( + '-t', + '--table', + dest='sub_table', + help='Alias for top-level --table', + action='append', + ) + down_parser.add_argument( + '-v', + '--verbose', + dest='sub_verbose', + help='Enable verbose output (sets sqlalchemy.engine logging to INFO)', + action='store_true', + ) args = parser.parse_args() - # Consolidate owner value - owner = args.owner or getattr(args, 'sub_owner', None) - db_url = args.database_url or getattr(args, 'sub_database_url', None) - if db_url: - os.environ['DATABASE_URL'] = db_url - # Default to 
upgrade head if no command is provided if not args.cmd: args.cmd = 'upgrade' args.revision = 'head' - # 1. Locate the bundled alembic.ini + # Locate the bundled alembic.ini ini_path = files('a2a').joinpath('alembic.ini') cfg = Config(str(ini_path)) - # 2. Dynamically set the script location + # Dynamically set the script location migrations_path = files('a2a').joinpath('migrations') cfg.set_main_option('script_location', str(migrations_path)) - # 3. Pass custom arguments to the migration context + # Consolidate owner, db_url, tables and verbose values + owner = args.owner or getattr(args, 'sub_owner', None) + db_url = args.database_url or getattr(args, 'sub_database_url', None) + tables = args.table or getattr(args, 'sub_table', None) + verbose = args.verbose or getattr(args, 'sub_verbose', False) + + # Pass custom arguments to the migration context if owner: if args.cmd == 'downgrade': parser.error( "The --owner option is not supported for the 'downgrade' command." ) cfg.set_main_option('owner', owner) + if db_url: + os.environ['DATABASE_URL'] = db_url + if tables: + cfg.set_main_option('tables', ','.join(tables)) + if verbose: + cfg.set_main_option('verbose', 'true') - # 3. 
Execute the requested command + # Execute the requested command if args.cmd == 'upgrade': print(f'Upgrading database to {args.revision}...') command.upgrade(cfg, args.revision) diff --git a/src/a2a/migrations/env.py b/src/a2a/migrations/env.py index 554b58cbc..9d27fb13a 100644 --- a/src/a2a/migrations/env.py +++ b/src/a2a/migrations/env.py @@ -1,4 +1,5 @@ import asyncio +import logging import os from logging.config import fileConfig @@ -29,6 +30,9 @@ if config.config_file_name is not None: fileConfig(config.config_file_name) +if config.get_main_option('verbose') == 'true': + logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) + # add your model's MetaData object here for 'autogenerate' support target_metadata = Base.metadata diff --git a/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py b/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py new file mode 100644 index 000000000..20d93bc66 --- /dev/null +++ b/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py @@ -0,0 +1,86 @@ +"""add_columns_owner_last_updated. + +Revision ID: 6419d2d130f6 +Revises: +Create Date: 2026-02-17 09:23:06.758085 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import context, op + + +# revision identifiers, used by Alembic. 
+revision: str = '6419d2d130f6' +down_revision: str | Sequence[str] | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def column_exists(table_name: str, column_name: str) -> bool: + bind = op.get_bind() + inspector = sa.inspect(bind) + columns = [c['name'] for c in inspector.get_columns(table_name)] + return column_name in columns + + +def index_exists(table_name: str, index_name: str) -> bool: + bind = op.get_bind() + inspector = sa.inspect(bind) + indexes = [i['name'] for i in inspector.get_indexes(table_name)] + return index_name in indexes + + +def upgrade() -> None: + """Upgrade schema.""" + # Get the default value from the config (passed via CLI) + owner = context.config.get_main_option('owner', 'unknown') + tables_str = context.config.get_main_option( + 'tables', 'tasks,push_notification_configs' + ) + tables = [t.strip() for t in tables_str.split(',')] + + for table in tables: + if not column_exists(table, 'owner'): + op.add_column( + table, + sa.Column( + 'owner', + sa.String(128), + nullable=False, + server_default=owner, + ), + ) + if column_exists( + table, 'kind' + ): # Check to differentiate between the tasks and push_notification_configs tables. Only the tasks table should have the last_updated column and index.
+ if not column_exists(table, 'last_updated'): + op.add_column( + table, + sa.Column('last_updated', sa.String(22), nullable=True), + ) + if not index_exists(table, f'idx_{table}_owner_last_updated'): + op.create_index( + f'idx_{table}_owner_last_updated', + table, + ['owner', 'last_updated'], + ) + + +def downgrade() -> None: + """Downgrade schema.""" + tables_str = context.config.get_main_option( + 'tables', 'tasks,push_notification_configs' + ) + tables = [t.strip() for t in tables_str.split(',')] + + for table in tables: + if index_exists(table, f'idx_{table}_owner_last_updated'): + op.drop_index(f'idx_{table}_owner_last_updated', table_name=table) + if column_exists(table, 'owner'): + op.drop_column(table, 'owner') + if column_exists(table, 'last_updated'): + op.drop_column(table, 'last_updated') diff --git a/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py b/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py deleted file mode 100644 index a9e5a77d0..000000000 --- a/src/a2a/migrations/versions/6419d2d130f6_add_owner_to_task.py +++ /dev/null @@ -1,52 +0,0 @@ -"""add_owner_to_task. - -Revision ID: 6419d2d130f6 -Revises: -Create Date: 2026-02-17 09:23:06.758085 - -""" - -from collections.abc import Sequence - -import sqlalchemy as sa - -from alembic import context, op - - -# revision identifiers, used by Alembic. 
-revision: str = '6419d2d130f6' -down_revision: str | Sequence[str] | None = None -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - """Upgrade schema.""" - # Get the default value from the config (passed via CLI) - owner = context.config.get_main_option('owner', 'unknown') - - op.add_column( - 'tasks', - sa.Column( - 'owner', - sa.String(128), - nullable=False, - server_default=owner, - ), - ) - - op.add_column( - 'push_notification_configs', - sa.Column( - 'owner', - sa.String(128), - nullable=False, - server_default=owner, - ), - ) - - -def downgrade() -> None: - """Downgrade schema.""" - op.drop_column('tasks', 'owner') - op.drop_column('push_notification_configs', 'owner') From 28b6cfc1eb8570456bdcbc34d4ecf49802bdfc7f Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 25 Feb 2026 11:33:53 +0000 Subject: [PATCH 25/29] feat: enhance migration CLI with SQL mode and logging; update README and tests --- src/a2a/alembic.ini | 4 +- src/a2a/cli.py | 44 ++++- src/a2a/migrations/README.md | 82 ++++---- src/a2a/migrations/env.py | 2 + ...d2d130f6_add_columns_owner_last_updated.py | 4 + tests/migrations/test_cli.py | 184 ++++++++++++++++++ tests/migrations/test_env.py | 127 ++++++++++++ .../versions/test_migration_6419d2d130f6.py | 176 +++++++++++++++++ .../server/tasks/test_database_task_store.py | 8 +- tests/server/test_owner_resolver.py | 6 +- 10 files changed, 585 insertions(+), 52 deletions(-) create mode 100644 tests/migrations/test_cli.py create mode 100644 tests/migrations/test_env.py create mode 100644 tests/migrations/versions/test_migration_6419d2d130f6.py diff --git a/src/a2a/alembic.ini b/src/a2a/alembic.ini index ed0e60bc2..f46511c00 100644 --- a/src/a2a/alembic.ini +++ b/src/a2a/alembic.ini @@ -10,7 +10,7 @@ keys = console keys = generic [logger_root] -level = WARNING +level = INFO handlers = console qualname = @@ -20,7 +20,7 @@ handlers = qualname = sqlalchemy.engine 
[logger_alembic] -level = INFO +level = WARNING handlers = qualname = alembic diff --git a/src/a2a/cli.py b/src/a2a/cli.py index afd5b994b..faf04e1b6 100644 --- a/src/a2a/cli.py +++ b/src/a2a/cli.py @@ -1,4 +1,5 @@ import argparse +import logging import os from importlib.resources import files @@ -7,8 +8,8 @@ from alembic.config import Config -def run_migrations() -> None: - """CLI tool to manage database migrations.""" +def create_parser() -> argparse.ArgumentParser: + """Create the argument parser for the migration tool.""" parser = argparse.ArgumentParser(description='A2A Database Migration Tool') # Global options @@ -34,6 +35,11 @@ def run_migrations() -> None: help='Enable verbose output (sets sqlalchemy.engine logging to INFO)', action='store_true', ) + parser.add_argument( + '--sql', + help='Run migrations in sql mode (generate SQL instead of executing)', + action='store_true', + ) subparsers = parser.add_subparsers(dest='cmd', help='Migration command') @@ -70,6 +76,12 @@ def run_migrations() -> None: help='Enable verbose output (sets sqlalchemy.engine logging to INFO)', action='store_true', ) + up_parser.add_argument( + '--sql', + dest='sub_sql', + help='Run migrations in sql mode (generate SQL instead of executing)', + action='store_true', + ) # Downgrade command down_parser = subparsers.add_parser( @@ -101,7 +113,22 @@ def run_migrations() -> None: help='Enable verbose output (sets sqlalchemy.engine logging to INFO)', action='store_true', ) + down_parser.add_argument( + '--sql', + dest='sub_sql', + help='Run migrations in sql mode (generate SQL instead of executing)', + action='store_true', + ) + + return parser + + +def run_migrations() -> None: + """CLI tool to manage database migrations.""" + # Configure logging to show INFO messages + logging.basicConfig(level=logging.INFO, format='%(levelname)s %(message)s') + parser = create_parser() args = parser.parse_args() # Default to upgrade head if no command is provided @@ -117,11 +144,12 @@ def 
run_migrations() -> None: migrations_path = files('a2a').joinpath('migrations') cfg.set_main_option('script_location', str(migrations_path)) - # Consolidate owner, db_url, tables and verbose values + # Consolidate owner, db_url, tables, verbose and sql values owner = args.owner or getattr(args, 'sub_owner', None) db_url = args.database_url or getattr(args, 'sub_database_url', None) tables = args.table or getattr(args, 'sub_table', None) verbose = args.verbose or getattr(args, 'sub_verbose', False) + sql = args.sql or getattr(args, 'sub_sql', False) # Pass custom arguments to the migration context if owner: @@ -139,10 +167,10 @@ def run_migrations() -> None: # Execute the requested command if args.cmd == 'upgrade': - print(f'Upgrading database to {args.revision}...') - command.upgrade(cfg, args.revision) + logging.info('Upgrading database to %s', args.revision) + command.upgrade(cfg, args.revision, sql=sql) elif args.cmd == 'downgrade': - print(f'Downgrading database to {args.revision}...') - command.downgrade(cfg, args.revision) + logging.info('Downgrading database to %s', args.revision) + command.downgrade(cfg, args.revision, sql=sql) - print('Done.') + logging.info('Done.') diff --git a/src/a2a/migrations/README.md b/src/a2a/migrations/README.md index 6ee439333..5e0f4e4c0 100644 --- a/src/a2a/migrations/README.md +++ b/src/a2a/migrations/README.md @@ -1,65 +1,88 @@ # A2A SDK Database Migrations -This directory handles the database schema updates for the A2A SDK. It uses a bundled CLI tool to simplify the migration process for both users and developers of the SDK. +This directory handles the database schema updates for the A2A SDK. It uses a bundled CLI tool to simplify the migration process. -## User Guide (For Integrators) +## User Guide for Integrators -When you install the `a2a-sdk`, you get a built-in command `a2a-db` which updates tables 'tasks' and 'push_notification_configs' in your database to make it compatible with the latest version of the SDK. 
+When you install the `a2a-sdk`, you get a built-in command `a2a-db` which updates your database to make it compatible with the latest version of the SDK. -### 1. Recommended: Backup your database +### 1. Recommended: Back up your database -Before running migrations, it is recommended to backup your database. +Before running migrations, it is recommended to back up your database. ### 2. Set your Database URL -Migrations require the `DATABASE_URL` environment variable to be set with an **async-compatible** driver or you can use the `-u` flag to specify the database URL for a single command. +Migrations require the `DATABASE_URL` environment variable to be set with an `async-compatible` driver. +You can set it globally with `export DATABASE_URL`. Examples for SQLite, PostgreSQL and MySQL, respectively: ```bash -# SQLite example export DATABASE_URL="sqlite+aiosqlite://user:pass@host:port/your_database_name" -# PostgreSQL example export DATABASE_URL="postgresql+asyncpg://user:pass@localhost/your_database_name" -# MySQL example export DATABASE_URL="mysql+aiomysql://user:pass@localhost/your_database_name" ``` +Or you can use the `-u` flag to specify the database URL for a single command. + + ### 3. Apply Migrations -Always run this command after installing or upgrading the SDK to ensure your database matches the required schema. +Always run this command after installing or upgrading the SDK to ensure your database matches the required schema. This will upgrade the tables `tasks` and `push_notification_configs` in your database by adding columns `owner` and `last_updated` and an index `(owner, last_updated)` to the `tasks` table and a column `owner` to the `push_notification_configs` table. ```bash -# Bring the database to the latest version -a2a-db +uv run a2a-db ``` -### 4. Customizing Defaults -Add owner to tasks migration allows you to pass custom values for the new `owner` column. For example, to set a specific default owner for existing tasks: +### 4. 
Customizing Defaults with Flags +#### -o +Allows you to pass custom values for the new `owner` column. If not set, it will default to the value `unknown`. ```bash -a2a-db -o "admin_user" +uv run a2a-db -o "admin_user" ``` +#### -u +You can use the `-u` flag to specify the database URL for a single command. -### 5. Rolling Back -If you need to revert a change: +```bash +uv run a2a-db -u "sqlite+aiosqlite:///my_database.db" +``` +#### -t +By default, tables `tasks` and `push_notification_configs` are updated. Using `-t` flag allows you to choose which tables to update. ```bash -# Step back one version -a2a-db downgrade -1 +uv run a2a-db -t "my_table1" -t "my_table2" ``` +#### -v +Enables verbose output by setting `sqlalchemy.engine` logging to `INFO`. ```bash -# Downgrade to a specific revision ID -a2a-db downgrade +uv run a2a-db -v +``` +#### --sql +Enables running migrations in `offline` mode. Migrations are generated as SQL scripts instead of being run against the database. + +```bash +uv run a2a-db --sql ``` +### 5. Rolling Back +If you need to revert a change: + ```bash +# Step back one version +uv run a2a-db downgrade -1 + +# Downgrade to a specific revision ID +uv run a2a-db downgrade + # Revert all migrations -a2a-db downgrade base +uv run a2a-db downgrade base ``` +All flags except the `-o` flag can be used during rollbacks. + --- -## Developer Guide (For SDK Contributors) +## Developer Guide for SDK Contributors If you are modifying the SDK models and need to generate new migration files, use the following workflow. @@ -68,21 +91,10 @@ Developers should use the raw `alembic` command locally to generate migrations. ```bash # Detect changes in models.py and generate a script -uv run alembic revision --autogenerate -m "describe your changes" +uv run alembic revision --autogenerate -m "description of changes" ``` ### Internal Layout - `env.py`: Configures the migration engine and applies the mandatory `DATABASE_URL` check. 
- `versions/`: Contains the migration history. - `script.py.mako`: The template for all new migration files. - -## Troubleshooting - -### "Duplicate column name" -If your database already has the required tables (e.g., created via `Base.metadata.create_all()` in a legacy script), you may need to "stamp" the database to tell the SDK that it is already up to date: - -```bash -# Stamp the database without running SQL commands -# (Requires raw alembic command for developer use) -uv run alembic stamp head -``` diff --git a/src/a2a/migrations/env.py b/src/a2a/migrations/env.py index 9d27fb13a..33db469b3 100644 --- a/src/a2a/migrations/env.py +++ b/src/a2a/migrations/env.py @@ -101,6 +101,8 @@ def run_migrations_online() -> None: if context.is_offline_mode(): + logging.info('Running migrations in offline mode.') run_migrations_offline() else: + logging.info('Running migrations in online mode.') run_migrations_online() diff --git a/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py b/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py index 20d93bc66..c727aca3f 100644 --- a/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py +++ b/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py @@ -21,6 +21,8 @@ def column_exists(table_name: str, column_name: str) -> bool: + if context.is_offline_mode(): + return False bind = op.get_bind() inspector = sa.inspect(bind) columns = [c['name'] for c in inspector.get_columns(table_name)] @@ -28,6 +30,8 @@ def column_exists(table_name: str, column_name: str) -> bool: def index_exists(table_name: str, index_name: str) -> bool: + if context.is_offline_mode(): + return False bind = op.get_bind() inspector = sa.inspect(bind) indexes = [i['name'] for i in inspector.get_indexes(table_name)] diff --git a/tests/migrations/test_cli.py b/tests/migrations/test_cli.py new file mode 100644 index 000000000..c3dec2560 --- /dev/null +++ b/tests/migrations/test_cli.py 
@@ -0,0 +1,184 @@ +import os +import argparse +from unittest.mock import MagicMock, patch +import pytest +from a2a.cli import run_migrations + + +@pytest.fixture +def mock_alembic_command(): + with ( + patch('alembic.command.upgrade') as mock_upgrade, + patch('alembic.command.downgrade') as mock_downgrade, + ): + yield mock_upgrade, mock_downgrade + + +@pytest.fixture +def mock_alembic_config(): + with patch('a2a.cli.Config') as mock_config: + yield mock_config + + +def test_cli_upgrade_offline(mock_alembic_command, mock_alembic_config): + mock_upgrade, _ = mock_alembic_command + custom_owner = 'test-owner' + target_tables = ['table1', 'table2'] + + # Simulate: a2a-db upgrade head --sql -o test-owner -t table1 -t table2 -v + test_args = [ + 'a2a-db', + 'upgrade', + 'head', + '--sql', + '-o', + custom_owner, + '-t', + target_tables[0], + '-t', + target_tables[1], + '-v', + ] + with patch('sys.argv', test_args): + with patch.dict(os.environ, {'DATABASE_URL': 'sqlite:///test.db'}): + run_migrations() + + # Verify upgrade parameters + args, kwargs = mock_upgrade.call_args + assert kwargs['sql'] is True + assert args[1] == 'head' + + # Verify options were set in config instance + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'owner', custom_owner + ) + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'tables', ','.join(target_tables) + ) + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'verbose', 'true' + ) + + +def test_cli_downgrade_offline(mock_alembic_command, mock_alembic_config): + _, mock_downgrade = mock_alembic_command + target_table = 'only_tasks' + + # Simulate: a2a-db downgrade base --sql -t only_tasks -v + test_args = [ + 'a2a-db', + 'downgrade', + 'base', + '--sql', + '-t', + target_table, + '-v', + ] + with patch('sys.argv', test_args): + with patch.dict(os.environ, {'DATABASE_URL': 'sqlite:///test.db'}): + run_migrations() + + args, kwargs = mock_downgrade.call_args + assert kwargs['sql'] is 
True + assert args[1] == 'base' + + # Verify options + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'tables', target_table + ) + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'verbose', 'true' + ) + + +def test_cli_global_offline(mock_alembic_command, mock_alembic_config): + mock_upgrade, _ = mock_alembic_command + + # Simulate: a2a-db --sql -v (defaults to upgrade head) + test_args = ['a2a-db', '--sql', '-v'] + with patch('sys.argv', test_args): + with patch.dict(os.environ, {'DATABASE_URL': 'sqlite:///test.db'}): + run_migrations() + + # Verify upgrade was called with sql=True + args, kwargs = mock_upgrade.call_args + assert kwargs['sql'] is True + + # Verify verbose option + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'verbose', 'true' + ) + + +def test_cli_upgrade_online(mock_alembic_command, mock_alembic_config): + mock_upgrade, _ = mock_alembic_command + custom_owner = 'test-owner' + target_table = 'specific_table' + + # Simulate: a2a-db upgrade head -o test-owner -t specific_table -v + test_args = [ + 'a2a-db', + 'upgrade', + 'head', + '-o', + custom_owner, + '-t', + target_table, + '-v', + ] + with patch('sys.argv', test_args): + with patch.dict(os.environ, {'DATABASE_URL': 'sqlite:///test.db'}): + run_migrations() + + # Verify upgrade was called with sql=False + args, kwargs = mock_upgrade.call_args + assert kwargs['sql'] is False + + # Verify options were set in config instance + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'owner', custom_owner + ) + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'tables', target_table + ) + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'verbose', 'true' + ) + + + +def test_cli_downgrade_online(mock_alembic_command, mock_alembic_config): + _, mock_downgrade = mock_alembic_command + target_table = 'other_table' + + # Simulate: a2a-db downgrade base -t other_table + test_args = ['a2a-db', 
'downgrade', 'base', '-t', target_table] + with patch('sys.argv', test_args): + with patch.dict(os.environ, {'DATABASE_URL': 'sqlite:///test.db'}): + run_migrations() + + # Verify downgrade was called with sql=False + args, kwargs = mock_downgrade.call_args + assert kwargs['sql'] is False + + # Verify tables option + mock_alembic_config.return_value.set_main_option.assert_any_call( + 'tables', target_table + ) + + + +def test_cli_database_url_flag(mock_alembic_command, mock_alembic_config): + mock_upgrade, _ = mock_alembic_command + custom_db = 'sqlite:///custom_cli.db' + + # Simulate: a2a-db -u sqlite:///custom_cli.db + test_args = ['a2a-db', '-u', custom_db] + with patch('sys.argv', test_args): + # Clear environment to ensure it picks up the CLI flag + with patch.dict(os.environ, {}, clear=True): + run_migrations() + # Verify the CLI tool set the environment variable for env.py + assert os.environ['DATABASE_URL'] == custom_db + + mock_upgrade.assert_called() diff --git a/tests/migrations/test_env.py b/tests/migrations/test_env.py new file mode 100644 index 000000000..13d414aa1 --- /dev/null +++ b/tests/migrations/test_env.py @@ -0,0 +1,127 @@ +import asyncio +import logging +import os +import runpy +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_alembic(): + """Fixture to mock alembic context and config.""" + with patch('alembic.context') as mock_context: + mock_config = MagicMock() + mock_context.config = mock_config + yield mock_context, mock_config + + +def test_env_py_missing_db_url(mock_alembic): + """Test that env.py raises RuntimeError when DATABASE_URL is missing.""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises( + RuntimeError, match='DATABASE_URL environment variable is not set' + ): + # run_path executes the script in a fresh namespace + runpy.run_path('src/a2a/migrations/env.py') + + +@patch('logging.config.fileConfig') +@patch('a2a.server.models.Base.metadata') +def 
test_env_py_offline_mode(mock_metadata, mock_file_config, mock_alembic): + """Test env.py logic in offline mode.""" + mock_context, mock_config = mock_alembic + db_url = 'sqlite+aiosqlite:///test.db' + + mock_config.config_file_name = 'alembic.ini' + + # Mock get_main_option to return db_url for 'sqlalchemy.url' + def get_opt(key, default=None): + if key == 'sqlalchemy.url': + return db_url + return default + + mock_config.get_main_option.side_effect = get_opt + + mock_context.is_offline_mode.return_value = True + + with patch.dict(os.environ, {'DATABASE_URL': db_url}): + runpy.run_path('src/a2a/migrations/env.py') + + # Verify sqlalchemy.url was set from env var + mock_config.set_main_option.assert_any_call('sqlalchemy.url', db_url) + + # Verify logging was configured + mock_file_config.assert_called_with('alembic.ini') + + # Verify context.configure was called for offline mode + mock_context.configure.assert_called() + # Check if url was passed to configure + args, kwargs = mock_context.configure.call_args + assert kwargs['url'] == db_url + assert kwargs['target_metadata'] == mock_metadata + + +@patch('logging.config.fileConfig') +@patch('a2a.server.models.Base.metadata') +@patch('alembic.context.run_migrations') +@patch('sqlalchemy.ext.asyncio.async_engine_from_config') +@patch('asyncio.run') +def test_env_py_online_mode( + mock_asyncio_run, + mock_async_engine, + mock_run_migrations, + mock_metadata, + mock_file_config, + mock_alembic, +): + """Test env.py logic in online mode.""" + mock_context, mock_config = mock_alembic + db_url = 'sqlite+aiosqlite:///test.db' + + mock_config.config_file_name = None + mock_context.is_offline_mode.return_value = False + + # Prevent "coroutine never awaited" warning by closing the coro passed to asyncio.run + def close_coro(coro): + if asyncio.iscoroutine(coro): + coro.close() + + mock_asyncio_run.side_effect = close_coro + + with patch.dict(os.environ, {'DATABASE_URL': db_url}): + runpy.run_path('src/a2a/migrations/env.py') + 
+ # Verify sqlalchemy.url was set + mock_config.set_main_option.assert_any_call('sqlalchemy.url', db_url) + + # Verify asyncio.run was called to start online migrations + mock_asyncio_run.assert_called() + + +def test_env_py_verbose_logging(mock_alembic): + """Test that env.py enables verbose logging when 'verbose' option is set.""" + mock_context, mock_config = mock_alembic + db_url = 'sqlite+aiosqlite:///test.db' + + # Use a real side_effect to simulate config.get_main_option + def get_opt(key, default=None): + if key == 'verbose': + return 'true' + return default + + mock_config.get_main_option.side_effect = get_opt + mock_config.config_file_name = None + mock_context.is_offline_mode.return_value = True + + with patch('logging.getLogger') as mock_get_logger: + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + + with patch.dict(os.environ, {'DATABASE_URL': db_url}): + with patch('a2a.server.models.Base.metadata'): + runpy.run_path('src/a2a/migrations/env.py') + + # Check if sqlalchemy.engine logger level was set to INFO + mock_get_logger.assert_called_with('sqlalchemy.engine') + mock_logger.setLevel.assert_called_with(logging.INFO) diff --git a/tests/migrations/versions/test_migration_6419d2d130f6.py b/tests/migrations/versions/test_migration_6419d2d130f6.py new file mode 100644 index 000000000..dd815f418 --- /dev/null +++ b/tests/migrations/versions/test_migration_6419d2d130f6.py @@ -0,0 +1,176 @@ +import os +import sqlite3 +import subprocess +import tempfile +from typing import Generator + +import pytest + + +@pytest.fixture +def temp_db() -> Generator[str, None, None]: + """Create a temporary SQLite database for testing.""" + fd, path = tempfile.mkstemp(suffix='.db') + os.close(fd) + yield path + if os.path.exists(path): + os.remove(path) + + +def test_migration_6419d2d130f6_full_cycle(temp_db: str) -> None: + """Test the full upgrade/downgrade cycle for migration 6419d2d130f6.""" + db_url = f'sqlite+aiosqlite:///{temp_db}' + + # 1. 
Setup initial schema without the new columns + conn = sqlite3.connect(temp_db) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE tasks ( + id VARCHAR(36) PRIMARY KEY, + context_id VARCHAR(36) NOT NULL, + kind VARCHAR(16) NOT NULL, + status TEXT, + artifacts TEXT, + history TEXT, + metadata TEXT + ) + """) + cursor.execute(""" + CREATE TABLE push_notification_configs ( + task_id VARCHAR(36), + config_id VARCHAR(255), + config_data BLOB NOT NULL, + PRIMARY KEY (task_id, config_id) + ) + """) + conn.commit() + conn.close() + + # 2. Run Upgrade via CLI with a custom owner + custom_owner = 'test_owner_123' + env = os.environ.copy() + env['DATABASE_URL'] = db_url + + # We use the CLI tool to perform the upgrade + result = subprocess.run( + [ + 'uv', + 'run', + 'a2a-db', + '--owner', + custom_owner, + 'upgrade', + '6419d2d130f6', + ], + capture_output=True, + text=True, + env=env, + check=False, + ) + + assert result.returncode == 0, f"""Upgrade failed: {result.stderr} +{result.stdout}""" + + # 3. Verify columns and index exist + conn = sqlite3.connect(temp_db) + cursor = conn.cursor() + + # Check tasks table + cursor.execute('PRAGMA table_info(tasks)') + tasks_columns = {row[1]: row for row in cursor.fetchall()} + assert 'owner' in tasks_columns + assert 'last_updated' in tasks_columns + + # Check default value for owner in tasks + # row[4] is dflt_value in PRAGMA table_info + assert tasks_columns['owner'][4] == f"'{custom_owner}'" + + # Check index on tasks + cursor.execute('PRAGMA index_list(tasks)') + tasks_indexes = {row[1] for row in cursor.fetchall()} + assert 'idx_tasks_owner_last_updated' in tasks_indexes + + # Check push_notification_configs table + cursor.execute('PRAGMA table_info(push_notification_configs)') + pnc_columns = {row[1]: row for row in cursor.fetchall()} + assert 'owner' in pnc_columns + assert ( + 'last_updated' not in pnc_columns + ) # Only for tables with 'kind' column + + conn.close() + + # 4. 
Run Downgrade via CLI + result = subprocess.run( + ['uv', 'run', 'a2a-db', 'downgrade', 'base'], + capture_output=True, + text=True, + env=env, + check=False, + ) + + assert result.returncode == 0, f"""Downgrade failed: {result.stderr} +{result.stdout}""" + + # 5. Verify columns are gone + conn = sqlite3.connect(temp_db) + cursor = conn.cursor() + + # Check tasks table + cursor.execute('PRAGMA table_info(tasks)') + tasks_columns_post = {row[1] for row in cursor.fetchall()} + assert 'owner' not in tasks_columns_post + assert 'last_updated' not in tasks_columns_post + + # Check index on tasks + cursor.execute('PRAGMA index_list(tasks)') + tasks_indexes_post = {row[1] for row in cursor.fetchall()} + assert 'idx_tasks_owner_last_updated' not in tasks_indexes_post + + # Check push_notification_configs table + cursor.execute('PRAGMA table_info(push_notification_configs)') + pnc_columns_post = {row[1] for row in cursor.fetchall()} + assert 'owner' not in pnc_columns_post + + conn.close() + + +def test_migration_6419d2d130f6_idempotency(temp_db: str) -> None: + """Test that the migration is idempotent (can be run multiple times).""" + db_url = f'sqlite+aiosqlite:///{temp_db}' + + # 1. Setup initial schema - must include both tables expected by the migration + conn = sqlite3.connect(temp_db) + cursor = conn.cursor() + cursor.execute( + 'CREATE TABLE tasks (id VARCHAR(36) PRIMARY KEY, kind VARCHAR(16))' + ) + cursor.execute( + 'CREATE TABLE push_notification_configs (task_id VARCHAR(36), config_id VARCHAR(255), PRIMARY KEY (task_id, config_id))' + ) + conn.commit() + conn.close() + + env = os.environ.copy() + env['DATABASE_URL'] = db_url + + # 2. Run Upgrade first time + result = subprocess.run( + ['uv', 'run', 'a2a-db', 'upgrade', '6419d2d130f6'], + capture_output=True, + text=True, + env=env, + check=False, + ) + assert result.returncode == 0 + + # 3. 
Run Upgrade second time - should not fail even if columns already exist + # (The migration script has 'if not column_exists' checks) + result = subprocess.run( + ['uv', 'run', 'a2a-db', 'upgrade', '6419d2d130f6'], + capture_output=True, + text=True, + env=env, + check=False, + ) + assert result.returncode == 0 diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index bf912281f..b71fd709b 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -36,7 +36,7 @@ from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE -class TestUser(User): +class SampleUser(User): """A test implementation of the User interface.""" def __init__(self, user_name: str): @@ -630,10 +630,10 @@ async def test_owner_resource_scoping( """Test that operations are scoped to the correct owner.""" task_store = db_store_parameterized - context_user1 = ServerCallContext(user=TestUser(user_name='user1')) - context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) context_user3 = ServerCallContext( - user=TestUser(user_name='user3') + user=SampleUser(user_name='user3') ) # user with no tasks # Create tasks for different owners diff --git a/tests/server/test_owner_resolver.py b/tests/server/test_owner_resolver.py index 8a0686865..5bac5c605 100644 --- a/tests/server/test_owner_resolver.py +++ b/tests/server/test_owner_resolver.py @@ -4,7 +4,7 @@ from a2a.server.owner_resolver import resolve_user_scope -class TestUser(User): +class SampleUser(User): """A test implementation of the User interface.""" def __init__(self, user_name: str): @@ -21,9 +21,9 @@ def user_name(self) -> str: def test_resolve_user_scope_valid_user(): """Test resolve_user_scope with a valid user in the context.""" - user = TestUser(user_name='testuser') + user = 
SampleUser(user_name='SampleUser') context = ServerCallContext(user=user) - assert resolve_user_scope(context) == 'testuser' + assert resolve_user_scope(context) == 'SampleUser' def test_resolve_user_scope_no_context(): From 94f129d4d003422da3c14abeea613e716e59b7c7 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 25 Feb 2026 11:45:50 +0000 Subject: [PATCH 26/29] fix: uv run ruff format --- tests/migrations/test_cli.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/migrations/test_cli.py b/tests/migrations/test_cli.py index c3dec2560..8a4e82be3 100644 --- a/tests/migrations/test_cli.py +++ b/tests/migrations/test_cli.py @@ -146,7 +146,6 @@ def test_cli_upgrade_online(mock_alembic_command, mock_alembic_config): ) - def test_cli_downgrade_online(mock_alembic_command, mock_alembic_config): _, mock_downgrade = mock_alembic_command target_table = 'other_table' @@ -167,7 +166,6 @@ def test_cli_downgrade_online(mock_alembic_command, mock_alembic_config): ) - def test_cli_database_url_flag(mock_alembic_command, mock_alembic_config): mock_upgrade, _ = mock_alembic_command custom_db = 'sqlite:///custom_cli.db' From 7f9e2bc81c9481c685a7bbce6ed903757c677856 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 25 Feb 2026 12:22:33 +0000 Subject: [PATCH 27/29] fix: remove user existence check in resolve_user_scope function --- src/a2a/server/owner_resolver.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/a2a/server/owner_resolver.py b/src/a2a/server/owner_resolver.py index 4fa310b92..798eb8c9b 100644 --- a/src/a2a/server/owner_resolver.py +++ b/src/a2a/server/owner_resolver.py @@ -12,7 +12,5 @@ def resolve_user_scope(context: ServerCallContext | None) -> str: """Resolves the owner scope based on the user in the context.""" if not context: return 'unknown' - if not context.user: - raise ValueError('User not found in context.') # Example: Basic user name. Adapt as needed for your user model. 
return context.user.user_name From 59f1fd426a0ca973a00d94df3d0b6b7b3e9595e7 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 25 Feb 2026 13:35:18 +0000 Subject: [PATCH 28/29] feat: add CLI support for alembic in migration tool and update dependencies --- pyproject.toml | 3 +- src/a2a/cli.py | 84 +++++++++++++------------------------------------- uv.lock | 16 ++++++---- 3 files changed, 34 insertions(+), 69 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8caa7c70d..12ecef8ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"] mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] signing = ["PyJWT>=2.0.0"] sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] +cli = ["alembic>=1.14.0"] sql = ["a2a-sdk[postgresql,mysql,sqlite]"] @@ -49,6 +50,7 @@ all = [ "a2a-sdk[grpc]", "a2a-sdk[telemetry]", "a2a-sdk[signing]", + "a2a-sdk[cli]", ] [project.urls] @@ -106,7 +108,6 @@ style = "pep440" [dependency-groups] dev = [ - "alembic>=1.14.0", "mypy>=1.15.0", "PyJWT>=2.0.0", "pytest>=8.3.5", diff --git a/src/a2a/cli.py b/src/a2a/cli.py index faf04e1b6..8fc1389e8 100644 --- a/src/a2a/cli.py +++ b/src/a2a/cli.py @@ -8,39 +8,49 @@ from alembic.config import Config -def create_parser() -> argparse.ArgumentParser: - """Create the argument parser for the migration tool.""" - parser = argparse.ArgumentParser(description='A2A Database Migration Tool') - - # Global options - parser.add_argument( - '-o', - '--owner', - help="Value for the 'owner' column (used in specific migrations). If not set defaults to 'unknown'", - ) +def _add_shared_args(parser: argparse.ArgumentParser, is_sub: bool = False) -> None: + """Add common arguments to the given parser.""" + prefix = 'sub_' if is_sub else '' parser.add_argument( '-u', '--database-url', + dest=f'{prefix}database_url', help='Database URL to use for the migrations. 
If not set, the DATABASE_URL environment variable will be used.', ) parser.add_argument( '-t', '--table', + dest=f'{prefix}table', help="Specific table to update. If not set, both 'tasks' and 'push_notification_configs' are updated.", action='append', ) parser.add_argument( '-v', '--verbose', + dest=f'{prefix}verbose', help='Enable verbose output (sets sqlalchemy.engine logging to INFO)', action='store_true', ) parser.add_argument( '--sql', + dest=f'{prefix}sql', help='Run migrations in sql mode (generate SQL instead of executing)', action='store_true', ) + +def create_parser() -> argparse.ArgumentParser: + """Create the argument parser for the migration tool.""" + parser = argparse.ArgumentParser(description='A2A Database Migration Tool') + + # Global options + parser.add_argument( + '-o', + '--owner', + help="Value for the 'owner' column (used in specific migrations). If not set defaults to 'unknown'", + ) + _add_shared_args(parser) + subparsers = parser.add_subparsers(dest='cmd', help='Migration command') # Upgrade command @@ -56,32 +66,7 @@ def create_parser() -> argparse.ArgumentParser: up_parser.add_argument( '-o', '--owner', dest='sub_owner', help='Alias for top-level --owner' ) - up_parser.add_argument( - '-u', - '--database-url', - dest='sub_database_url', - help='Alias for top-level --database-url', - ) - up_parser.add_argument( - '-t', - '--table', - dest='sub_table', - help='Alias for top-level --table', - action='append', - ) - up_parser.add_argument( - '-v', - '--verbose', - dest='sub_verbose', - help='Enable verbose output (sets sqlalchemy.engine logging to INFO)', - action='store_true', - ) - up_parser.add_argument( - '--sql', - dest='sub_sql', - help='Run migrations in sql mode (generate SQL instead of executing)', - action='store_true', - ) + _add_shared_args(up_parser, is_sub=True) # Downgrade command down_parser = subparsers.add_parser( @@ -93,32 +78,7 @@ def create_parser() -> argparse.ArgumentParser: default='base', help='Revision target 
(e.g., -1, base or a specific ID)', ) - down_parser.add_argument( - '-u', - '--database-url', - dest='sub_database_url', - help='Alias for top-level --database-url', - ) - down_parser.add_argument( - '-t', - '--table', - dest='sub_table', - help='Alias for top-level --table', - action='append', - ) - down_parser.add_argument( - '-v', - '--verbose', - dest='sub_verbose', - help='Enable verbose output (sets sqlalchemy.engine logging to INFO)', - action='store_true', - ) - down_parser.add_argument( - '--sql', - dest='sub_sql', - help='Run migrations in sql mode (generate SQL instead of executing)', - action='store_true', - ) + _add_shared_args(down_parser, is_sub=True) return parser diff --git a/uv.lock b/uv.lock index 748ef3ee6..afa911af4 100644 --- a/uv.lock +++ b/uv.lock @@ -22,6 +22,7 @@ dependencies = [ [package.optional-dependencies] all = [ + { name = "alembic" }, { name = "cryptography" }, { name = "fastapi" }, { name = "grpcio" }, @@ -34,6 +35,9 @@ all = [ { name = "sse-starlette" }, { name = "starlette" }, ] +cli = [ + { name = "alembic" }, +] encryption = [ { name = "cryptography" }, ] @@ -70,7 +74,6 @@ telemetry = [ [package.dev-dependencies] dev = [ { name = "a2a-sdk", extra = ["all"] }, - { name = "alembic" }, { name = "autoflake" }, { name = "mypy" }, { name = "no-implicit-optional" }, @@ -94,6 +97,8 @@ dev = [ [package.metadata] requires-dist = [ + { name = "alembic", marker = "extra == 'all'", specifier = ">=1.14.0" }, + { name = "alembic", marker = "extra == 'cli'", specifier = ">=1.14.0" }, { name = "cryptography", marker = "extra == 'all'", specifier = ">=43.0.0" }, { name = "cryptography", marker = "extra == 'encryption'", specifier = ">=43.0.0" }, { name = "fastapi", marker = "extra == 'all'", specifier = ">=0.115.2" }, @@ -131,12 +136,11 @@ requires-dist = [ { name = "starlette", marker = "extra == 'all'" }, { name = "starlette", marker = "extra == 'http-server'" }, ] -provides-extras = ["all", "encryption", "grpc", "http-server", "mysql", 
"postgresql", "signing", "sql", "sqlite", "telemetry"] +provides-extras = ["all", "cli", "encryption", "grpc", "http-server", "mysql", "postgresql", "signing", "sql", "sqlite", "telemetry"] [package.metadata.requires-dev] dev = [ { name = "a2a-sdk", extras = ["all"], editable = "." }, - { name = "alembic", specifier = ">=1.14.0" }, { name = "autoflake" }, { name = "mypy", specifier = ">=1.15.0" }, { name = "no-implicit-optional" }, @@ -2352,7 +2356,7 @@ wheels = [ [[package]] name = "virtualenv" -version = "20.38.0" +version = "20.39.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, @@ -2360,9 +2364,9 @@ dependencies = [ { name = "platformdirs" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d2/03/a94d404ca09a89a7301a7008467aed525d4cdeb9186d262154dd23208709/virtualenv-20.38.0.tar.gz", hash = "sha256:94f39b1abaea5185bf7ea5a46702b56f1d0c9aa2f41a6c2b8b0af4ddc74c10a7", size = 5864558, upload-time = "2026-02-19T07:48:02.385Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ed/54/809199edc537dbace273495ac0884d13df26436e910a5ed4d0ec0a69806b/virtualenv-20.39.0.tar.gz", hash = "sha256:a15f0cebd00d50074fd336a169d53422436a12dfe15149efec7072cfe817df8b", size = 5869141, upload-time = "2026-02-23T18:09:13.349Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/d7/394801755d4c8684b655d35c665aea7836ec68320304f62ab3c94395b442/virtualenv-20.38.0-py3-none-any.whl", hash = "sha256:d6e78e5889de3a4742df2d3d44e779366325a90cf356f15621fddace82431794", size = 5837778, upload-time = "2026-02-19T07:47:59.778Z" }, + { url = "https://files.pythonhosted.org/packages/f7/b4/8268da45f26f4fe84f6eae80a6ca1485ffb490a926afecff75fc48f61979/virtualenv-20.39.0-py3-none-any.whl", hash = "sha256:44888bba3775990a152ea1f73f8e5f566d49f11bbd1de61d426fd7732770043e", size = 5839121, upload-time = "2026-02-23T18:09:11.173Z" }, ] [[package]] From 
31fd9c104c874dc3fe73831d3bd315383cab73a2 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 25 Feb 2026 13:44:34 +0000 Subject: [PATCH 29/29] fix: add ImportError handling for Alembic in CLI and migration files --- src/a2a/cli.py | 15 ++++++++++++--- src/a2a/migrations/env.py | 8 +++++++- ...6419d2d130f6_add_columns_owner_last_updated.py | 7 ++++++- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/a2a/cli.py b/src/a2a/cli.py index 8fc1389e8..47c81fa26 100644 --- a/src/a2a/cli.py +++ b/src/a2a/cli.py @@ -4,11 +4,20 @@ from importlib.resources import files -from alembic import command -from alembic.config import Config +try: + from alembic import command + from alembic.config import Config -def _add_shared_args(parser: argparse.ArgumentParser, is_sub: bool = False) -> None: +except ImportError as e: + raise ImportError( + "CLI requires Alembic. Install with: 'pip install a2a-sdk[cli]'." + ) from e + + +def _add_shared_args( + parser: argparse.ArgumentParser, is_sub: bool = False +) -> None: """Add common arguments to the given parser.""" prefix = 'sub_' if is_sub else '' parser.add_argument( diff --git a/src/a2a/migrations/env.py b/src/a2a/migrations/env.py index 33db469b3..08d0e0d6c 100644 --- a/src/a2a/migrations/env.py +++ b/src/a2a/migrations/env.py @@ -8,7 +8,13 @@ from sqlalchemy.ext.asyncio import async_engine_from_config from a2a.server.models import Base -from alembic import context + +try: + from alembic import context +except ImportError as e: + raise ImportError( + "Migrations require Alembic. Install with: 'pip install a2a-sdk[cli]'." 
+ ) from e # this is the Alembic Config object, which provides diff --git a/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py b/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py index c727aca3f..d12c98c09 100644 --- a/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py +++ b/src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py @@ -10,7 +10,12 @@ import sqlalchemy as sa -from alembic import context, op +try: + from alembic import context, op +except ImportError as e: + raise ImportError( + "Add columns to database tables migration requires Alembic. Install with: 'pip install a2a-sdk[cli]'." + ) from e # revision identifiers, used by Alembic.