diff --git a/python/README.md b/python/README.md index 7aa11e1a..0b8a5289 100644 --- a/python/README.md +++ b/python/README.md @@ -14,6 +14,45 @@ uv pip install -e ".[dev]" ## Quick Start +### Using Context Managers (Recommended) + +The SDK supports Python's async context manager protocol for automatic resource cleanup: + +```python +import asyncio +from copilot import CopilotClient + +async def main(): + # Client automatically starts on enter and cleans up on exit + async with CopilotClient() as client: + # Create a session with automatic cleanup + async with await client.create_session({"model": "gpt-4"}) as session: + # Wait for response using session.idle event + done = asyncio.Event() + + def on_event(event): + if event.type.value == "assistant.message": + print(event.data.content) + elif event.type.value == "session.idle": + done.set() + + session.on(on_event) + + # Send a message and wait for completion + await session.send({"prompt": "What is 2+2?"}) + await done.wait() + + # Session automatically destroyed here + + # Client automatically stopped here + +asyncio.run(main()) +``` + +### Manual Resource Management + +You can also manage resources manually: + ```python import asyncio from copilot import CopilotClient @@ -56,6 +95,7 @@ asyncio.run(main()) - ✅ Session history with `get_messages()` - ✅ Type hints throughout - ✅ Async/await native +- ✅ Async context manager support for automatic resource cleanup ## API Reference @@ -140,6 +180,44 @@ unsubscribe() - `session.foreground` - A session became the foreground session in TUI - `session.background` - A session is no longer the foreground session +### Context Manager Support + +Both `CopilotClient` and `CopilotSession` support Python's async context manager protocol for automatic resource cleanup. This is the recommended pattern as it ensures resources are properly cleaned up even if exceptions occur. + +**CopilotClient Context Manager:** + +```python +async with CopilotClient() as client: + # Client automatically starts on enter + session = await client.create_session() + await session.send({"prompt": "Hello!"}) + # Client automatically stops on exit, cleaning up all sessions +``` + +**CopilotSession Context Manager:** + +```python +async with await client.create_session() as session: + await session.send({"prompt": "Hello!"}) + # Session automatically destroyed on exit +``` + +**Nested Context Managers:** + +```python +async with CopilotClient() as client: + async with await client.create_session() as session: + await session.send({"prompt": "Hello!"}) + # Session destroyed here +# Client stopped here +``` + +**Benefits:** +- Prevents resource leaks by ensuring cleanup even if exceptions occur +- Eliminates the need to manually call `stop()` and `destroy()` +- Follows Python best practices for resource management +- Particularly useful in batch operations and evaluations to prevent process accumulation + ### Tools Define tools with automatic JSON schema generation using the `@define_tool` decorator and Pydantic models: diff --git a/python/copilot/client.py b/python/copilot/client.py index 11669ddc..6574913d 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -14,6 +14,7 @@ import asyncio import inspect +import logging import os import re import subprocess @@ -21,6 +22,7 @@ import threading from dataclasses import asdict, is_dataclass from pathlib import Path +from types import TracebackType from typing import Any, Callable, Optional, cast from .generated.rpc import ServerRpc @@ -206,6 +208,56 @@ def __init__(self, options: Optional[CopilotClientOptions] = None): self._lifecycle_handlers_lock = threading.Lock() self._rpc: Optional[ServerRpc] = None + async def __aenter__(self) -> "CopilotClient": + """ + Enter the async context manager. + + Automatically starts the CLI server and establishes a connection if not + already connected. + + Returns: + The CopilotClient instance. + + Example: + >>> async with CopilotClient() as client: + ... session = await client.create_session() + ... await session.send({"prompt": "Hello!"}) + """ + await self.start() + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + """ + Exit the async context manager. + + Performs graceful cleanup by destroying all active sessions and stopping + the CLI server. If cleanup errors occur, they are logged but do not + prevent the context from exiting. + + Args: + exc_type: The type of exception that occurred, if any. + exc_val: The exception instance that occurred, if any. + exc_tb: The traceback of the exception that occurred, if any. + + Returns: + False to propagate any exception that occurred in the context. + """ + try: + stop_errors = await self.stop() + # Log any session destruction errors + if stop_errors: + for error in stop_errors: + logging.warning("Error during session cleanup in CopilotClient: %s", error) + except Exception: + # Log the error but don't raise - we want cleanup to always complete + logging.warning("Error during CopilotClient cleanup", exc_info=True) + return False + @property def rpc(self) -> ServerRpc: """Typed server-scoped RPC methods.""" diff --git a/python/copilot/session.py b/python/copilot/session.py index d7bd1a3f..1734da9c 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -7,7 +7,9 @@ import asyncio import inspect +import logging import threading +from types import TracebackType from typing import Any, Callable, Optional from .generated.rpc import SessionRpc @@ -70,6 +72,7 @@ def __init__(self, session_id: str, client: Any, workspace_path: Optional[str] = self.session_id = session_id self._client = client self._workspace_path = workspace_path + self._destroyed = False self._event_handlers: set[Callable[[SessionEvent], None]] = set() self._event_handlers_lock = threading.Lock() self._tool_handlers: dict[str, ToolHandler] = {} @@ -82,6 +85,50 @@ def __init__(self, session_id: str, client: Any, workspace_path: Optional[str] = self._hooks_lock = threading.Lock() self._rpc: Optional[SessionRpc] = None + async def __aenter__(self) -> "CopilotSession": + """ + Enter the async context manager. + + Returns the session instance, ready for use. The session must already be + created (via CopilotClient.create_session or resume_session). + + Returns: + The CopilotSession instance. + + Example: + >>> async with await client.create_session() as session: + ... await session.send({"prompt": "Hello!"}) + """ + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + """ + Exit the async context manager. + + Automatically destroys the session and releases all associated resources. + If an error occurs during cleanup, it is logged but does not prevent the + context from exiting. + + Args: + exc_type: The type of exception that occurred, if any. + exc_val: The exception instance that occurred, if any. + exc_tb: The traceback of the exception that occurred, if any. + + Returns: + False to propagate any exception that occurred in the context. + """ + try: + await self.destroy() + except Exception: + # Log the error but don't raise - we want cleanup to always complete + logging.warning("Error during CopilotSession cleanup", exc_info=True) + return False + @property def rpc(self) -> SessionRpc: """Typed session-scoped RPC methods.""" @@ -479,20 +526,33 @@ async def destroy(self) -> None: handlers and tool handlers are cleared. To continue the conversation, use :meth:`CopilotClient.resume_session` with the session ID. + This method is idempotent—calling it multiple times is safe and will + not raise an error if the session is already destroyed. + Raises: - Exception: If the connection fails. + Exception: If the connection fails (on first destroy call). Example: >>> # Clean up when done >>> await session.destroy() """ - await self._client.request("session.destroy", {"sessionId": self.session_id}) + # Ensure that the check and update of _destroyed are atomic so that + # only the first caller proceeds to send the destroy RPC. with self._event_handlers_lock: - self._event_handlers.clear() - with self._tool_handlers_lock: - self._tool_handlers.clear() - with self._permission_handler_lock: - self._permission_handler = None + if self._destroyed: + return + self._destroyed = True + + try: + await self._client.request("session.destroy", {"sessionId": self.session_id}) + finally: + # Clear handlers even if the request fails + with self._event_handlers_lock: + self._event_handlers.clear() + with self._tool_handlers_lock: + self._tool_handlers.clear() + with self._permission_handler_lock: + self._permission_handler = None async def abort(self) -> None: """ diff --git a/python/e2e/test_context_managers.py b/python/e2e/test_context_managers.py new file mode 100644 index 00000000..4a798f00 --- /dev/null +++ b/python/e2e/test_context_managers.py @@ -0,0 +1,156 @@ +"""E2E Context Manager Tests""" + +import os + +import pytest + +from copilot import CopilotClient + +from .testharness import CLI_PATH + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +def create_test_client(ctx) -> CopilotClient: + """Create a fresh CopilotClient configured with the test harness proxy.""" + github_token = "fake-token-for-e2e-tests" if os.environ.get("CI") == "true" else None + return CopilotClient( + { + "cli_path": ctx.cli_path, + "cwd": ctx.work_dir, + "env": ctx.get_env(), + "github_token": github_token, + } + ) + + +class TestCopilotClientContextManager: + async def test_should_auto_start_and_cleanup_with_context_manager(self, ctx): + """Test that CopilotClient context manager auto-starts and cleans up.""" + client = create_test_client(ctx) + async with client: + assert client.get_state() == "connected" + # Verify we can use the client + pong = await client.ping("test") + assert pong.message == "pong: test" + + # After exiting context, client should be disconnected + assert client.get_state() == "disconnected" + + async def test_should_create_session_in_context(self, ctx): + """Test creating and using a session within client context.""" + client = create_test_client(ctx) + async with client: + session = await client.create_session({"model": "fake-test-model"}) + assert session.session_id + + # Verify session is usable + messages = await session.get_messages() + assert len(messages) > 0 + assert messages[0].type.value == "session.start" + + # After exiting context, verify cleanup happened + assert client.get_state() == "disconnected" + + async def test_should_cleanup_multiple_sessions(self, ctx): + """Test that all sessions are cleaned up when client context exits.""" + client = create_test_client(ctx) + async with client: + session1 = await client.create_session() + session2 = await client.create_session() + session3 = await client.create_session() + + assert session1.session_id + assert session2.session_id + assert session3.session_id + + # All sessions should be cleaned up + assert client.get_state() == "disconnected" + + async def test_should_propagate_exceptions(self, ctx): + """Test that exceptions within context are propagated.""" + client = create_test_client(ctx) + with pytest.raises(ValueError, match="test error"): + async with client: + assert client.get_state() == "connected" + raise ValueError("test error") + + # Client should still be cleaned up even after exception + assert client.get_state() == "disconnected" + + async def test_should_handle_cleanup_errors_gracefully(self, ctx): + """Test that cleanup errors don't prevent context from exiting.""" + client = create_test_client(ctx) + async with client: + await client.create_session() + + # Kill the process to force cleanup to fail + if client._process: + client._process.kill() + + # Context should still exit successfully despite cleanup errors + assert client.get_state() == "disconnected" + + +class TestCopilotSessionContextManager: + async def test_should_cleanup_session_with_context_manager(self, ctx): + """Test that CopilotSession context manager cleans up session.""" + client = create_test_client(ctx) + async with client: + async with await client.create_session() as session: + assert session.session_id + # Send a message to verify session is working + await session.send({"prompt": "Hello!"}) + + # After exiting context, session should be destroyed + with pytest.raises(Exception, match="Session not found"): + await session.get_messages() + + async def test_should_propagate_exceptions_in_session_context(self, ctx): + """Test that exceptions within session context are propagated.""" + client = create_test_client(ctx) + async with client: + with pytest.raises(ValueError, match="test session error"): + async with await client.create_session() as session: + assert session.session_id + raise ValueError("test session error") + + # Session should still be cleaned up after exception + with pytest.raises(Exception, match="Session not found"): + await session.get_messages() + + async def test_nested_context_managers(self, ctx): + """Test using nested context managers for client and session.""" + client = create_test_client(ctx) + async with client: + async with await client.create_session() as session: + assert session.session_id + await session.send({"prompt": "Test message"}) + + # Session should be cleaned up + with pytest.raises(Exception, match="Session not found"): + await session.get_messages() + + # Client should be cleaned up + assert client.get_state() == "disconnected" + + async def test_multiple_sequential_session_contexts(self, ctx): + """Test creating multiple sessions sequentially with context managers.""" + client = create_test_client(ctx) + async with client: + # First session + async with await client.create_session() as session1: + session1_id = session1.session_id + await session1.send({"prompt": "First session"}) + + # Second session (after first is cleaned up) + async with await client.create_session() as session2: + session2_id = session2.session_id + await session2.send({"prompt": "Second session"}) + + # Both sessions should be different + assert session1_id != session2_id + + # First session should be destroyed + with pytest.raises(Exception, match="Session not found"): + await session1.get_messages() diff --git a/python/test_client.py b/python/test_client.py index 7b4af8c0..17c585ae 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -147,3 +147,46 @@ def test_use_logged_in_user_with_cli_url_raises(self): CopilotClient( {"cli_url": "localhost:8080", "use_logged_in_user": False, "log_level": "error"} ) + + +class TestContextManager: + @pytest.mark.asyncio + async def test_client_context_manager_returns_self(self): + """Test that __aenter__ returns the client instance.""" + client = CopilotClient({"cli_path": CLI_PATH}) + returned_client = await client.__aenter__() + assert returned_client is client + await client.force_stop() + + @pytest.mark.asyncio + async def test_client_aexit_returns_false(self): + """Test that __aexit__ returns False to propagate exceptions.""" + client = CopilotClient({"cli_path": CLI_PATH}) + await client.start() + result = await client.__aexit__(None, None, None) + assert result is False + + @pytest.mark.asyncio + async def test_session_context_manager_returns_self(self): + """Test that session __aenter__ returns the session instance.""" + client = CopilotClient({"cli_path": CLI_PATH}) + await client.start() + try: + session = await client.create_session() + returned_session = await session.__aenter__() + assert returned_session is session + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_session_aexit_returns_false(self): + """Test that session __aexit__ returns False to propagate exceptions.""" + client = CopilotClient({"cli_path": CLI_PATH}) + await client.start() + try: + session = await client.create_session() + result = await session.__aexit__(None, None, None) + assert result is False + finally: + await client.force_stop() +