diff --git a/src/a2a/client/transports/__init__.py b/src/a2a/client/transports/__init__.py index af7c60f62..359c29be8 100644 --- a/src/a2a/client/transports/__init__.py +++ b/src/a2a/client/transports/__init__.py @@ -3,6 +3,12 @@ from a2a.client.transports.base import ClientTransport from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport +from a2a.client.transports.retry import ( + OnRetryCallback, + RetryPredicate, + RetryTransport, + default_retry_predicate, +) try: @@ -15,5 +21,9 @@ 'ClientTransport', 'GrpcTransport', 'JsonRpcTransport', + 'OnRetryCallback', 'RestTransport', + 'RetryPredicate', + 'RetryTransport', + 'default_retry_predicate', ] diff --git a/src/a2a/client/transports/retry.py b/src/a2a/client/transports/retry.py new file mode 100644 index 000000000..5e17c57e7 --- /dev/null +++ b/src/a2a/client/transports/retry.py @@ -0,0 +1,376 @@ +import asyncio +import inspect +import logging +import random + +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import Any, TypeVar + +import httpx + +from a2a.client.client import ClientCallContext +from a2a.client.errors import A2AClientError, A2AClientTimeoutError +from a2a.client.transports.base import ClientTransport +from a2a.types.a2a_pb2 import ( + AgentCard, + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + ListTasksResponse, + SendMessageRequest, + SendMessageResponse, + StreamResponse, + SubscribeToTaskRequest, + Task, + TaskPushNotificationConfig, +) + + +logger = logging.getLogger(__name__) + +T = TypeVar('T') + +RetryPredicate = Callable[[Exception], bool] +OnRetryCallback = Callable[[int, Exception, float], Awaitable[None] | None] + +_RETRYABLE_HTTP_STATUS: frozenset[int] = frozenset({408, 429, 502, 503, 504}) + +# grpc is an optional dependency. +try: + import grpc as _grpc + + _AioRpcError: Any = _grpc.aio.AioRpcError + _RETRYABLE_GRPC_CODES: frozenset[Any] = frozenset( + { + _grpc.StatusCode.UNAVAILABLE, + _grpc.StatusCode.RESOURCE_EXHAUSTED, + } + ) +except ImportError: # pragma: no cover + _AioRpcError = None + _RETRYABLE_GRPC_CODES = frozenset() + + +def default_retry_predicate(error: Exception) -> bool: # noqa: PLR0911 + """Returns True for transient errors, False otherwise. + + Retried: A2AClientTimeoutError; A2AClientError caused by httpx network + errors, HTTP 408/429/502/503/504, or gRPC UNAVAILABLE/RESOURCE_EXHAUSTED. + + Not retried: domain errors (TaskNotFoundError, etc.), HTTP 5xx other than + 502/503/504 (replaying server bugs is not safe), JSON decode / SSE errors. + + The cause is read from ``__cause__`` first (set by ``raise … from e``), + falling back to ``__context__`` for callers that don't chain explicitly. + """ + if isinstance(error, A2AClientTimeoutError): + return True + if not isinstance(error, A2AClientError): + return False + + cause = error.__cause__ or error.__context__ + if cause is None: + return False + if isinstance(cause, httpx.HTTPStatusError): + return cause.response.status_code in _RETRYABLE_HTTP_STATUS + if isinstance(cause, httpx.RequestError): + return True + if _AioRpcError is not None and isinstance(cause, _AioRpcError): + return cause.code() in _RETRYABLE_GRPC_CODES # pyright: ignore[reportAttributeAccessIssue] + return False + + +class RetryTransport(ClientTransport): + """A transport decorator that retries transient failures with exponential backoff. + + Streaming methods only retry before the first event is yielded. + """ + + def __init__( # noqa: PLR0913 + self, + base: ClientTransport, + *, + max_retries: int = 3, + base_delay: float = 1.0, + max_delay: float = 30.0, + jitter: bool = True, + retry_predicate: RetryPredicate | None = None, + on_retry: OnRetryCallback | None = None, + ) -> None: + if max_retries < 0: + raise ValueError('max_retries must be >= 0') + if base_delay <= 0: + raise ValueError('base_delay must be > 0') + if max_delay <= 0: + raise ValueError('max_delay must be > 0') + self._base = base + self._max_retries = max_retries + self._base_delay = base_delay + self._max_delay = max_delay + self._jitter = jitter + self._retry_predicate = retry_predicate or default_retry_predicate + self._on_retry = on_retry + + def _calculate_delay(self, attempt_index: int) -> float: + delay = min(self._base_delay * (2**attempt_index), self._max_delay) + if self._jitter: + delay = random.uniform(0, delay) # noqa: S311 + return delay + + async def _delay_and_notify( + self, + attempt_index: int, + error: Exception, + method_name: str, + ) -> None: + retry_number = attempt_index + 1 + delay = self._calculate_delay(attempt_index) + logger.warning( + 'Retry %d/%d for %s after %.2fs: %s', + retry_number, + self._max_retries, + method_name, + delay, + error, + ) + if self._on_retry is not None: + try: + result: Any = self._on_retry(retry_number, error, delay) + if inspect.isawaitable(result): + await result + except asyncio.CancelledError: + raise + except Exception: + # A buggy callback must not break the retry loop. + logger.exception( + 'on_retry callback raised for %s; continuing retry', + method_name, + ) + await asyncio.sleep(delay) + + @staticmethod + async def _safe_aclose(stream: Any) -> None: + aclose = getattr(stream, 'aclose', None) + if aclose is None: + return + try: + await aclose() + except asyncio.CancelledError: + raise + except Exception: + logger.debug( + 'Ignoring error while closing stream during retry cleanup', + exc_info=True, + ) + + async def _execute_with_retry( + self, + operation: Callable[[], Awaitable[T]], + method_name: str, + ) -> T: + attempt = 0 + while True: + try: + return await operation() + except asyncio.CancelledError: # noqa: PERF203 + raise + except Exception as e: + if attempt >= self._max_retries or not self._retry_predicate(e): + raise + await self._delay_and_notify(attempt, e, method_name) + attempt += 1 + + async def _execute_streaming_with_retry( + self, + operation: Callable[[], AsyncGenerator[StreamResponse]], + method_name: str, + ) -> AsyncGenerator[StreamResponse]: + # Retry only pre-stream failures. The inner finally closes the inner + # generator on every exit path (success, retry, exception, consumer + # break) so transport resources are not leaked. + attempt = 0 + while True: + first = True + stream: Any = None + try: + stream = operation() + try: + async for event in stream: + first = False + yield event + finally: + await self._safe_aclose(stream) + except asyncio.CancelledError: + raise + except Exception as e: + if ( + not first + or attempt >= self._max_retries + or not self._retry_predicate(e) + ): + raise + await self._delay_and_notify(attempt, e, method_name) + attempt += 1 + else: + return + + async def send_message( + self, + request: SendMessageRequest, + *, + context: ClientCallContext | None = None, + ) -> SendMessageResponse: + """Sends a non-streaming message request to the agent.""" + return await self._execute_with_retry( + lambda: self._base.send_message(request, context=context), + 'send_message', + ) + + async def send_message_streaming( + self, + request: SendMessageRequest, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[StreamResponse]: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + inner = self._execute_streaming_with_retry( + lambda: self._base.send_message_streaming(request, context=context), + 'send_message_streaming', + ) + try: + async for event in inner: + yield event + finally: + await inner.aclose() + + async def get_task( + self, + request: GetTaskRequest, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + return await self._execute_with_retry( + lambda: self._base.get_task(request, context=context), + 'get_task', + ) + + async def list_tasks( + self, + request: ListTasksRequest, + *, + context: ClientCallContext | None = None, + ) -> ListTasksResponse: + """Retrieves tasks for an agent.""" + return await self._execute_with_retry( + lambda: self._base.list_tasks(request, context=context), + 'list_tasks', + ) + + async def cancel_task( + self, + request: CancelTaskRequest, + *, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + return await self._execute_with_retry( + lambda: self._base.cancel_task(request, context=context), + 'cancel_task', + ) + + async def create_task_push_notification_config( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + return await self._execute_with_retry( + lambda: self._base.create_task_push_notification_config( + request, context=context + ), + 'create_task_push_notification_config', + ) + + async def get_task_push_notification_config( + self, + request: GetTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + return await self._execute_with_retry( + lambda: self._base.get_task_push_notification_config( + request, context=context + ), + 'get_task_push_notification_config', + ) + + async def list_task_push_notification_configs( + self, + request: ListTaskPushNotificationConfigsRequest, + *, + context: ClientCallContext | None = None, + ) -> ListTaskPushNotificationConfigsResponse: + """Lists push notification configurations for a specific task.""" + return await self._execute_with_retry( + lambda: self._base.list_task_push_notification_configs( + request, context=context + ), + 'list_task_push_notification_configs', + ) + + async def delete_task_push_notification_config( + self, + request: DeleteTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + ) -> None: + """Deletes the push notification configuration for a specific task.""" + await self._execute_with_retry( + lambda: self._base.delete_task_push_notification_config( + request, context=context + ), + 'delete_task_push_notification_config', + ) + + async def subscribe( + self, + request: SubscribeToTaskRequest, + *, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[StreamResponse]: + """Reconnects to get task updates.""" + inner = self._execute_streaming_with_retry( + lambda: self._base.subscribe(request, context=context), + 'subscribe', + ) + try: + async for event in inner: + yield event + finally: + await inner.aclose() + + async def get_extended_agent_card( + self, + request: GetExtendedAgentCardRequest, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the Extended AgentCard.""" + return await self._execute_with_retry( + lambda: self._base.get_extended_agent_card( + request, context=context + ), + 'get_extended_agent_card', + ) + + async def close(self) -> None: + """Closes the transport.""" + await self._base.close() diff --git a/tests/client/transports/test_retry.py b/tests/client/transports/test_retry.py new file mode 100644 index 000000000..511df45d6 --- /dev/null +++ b/tests/client/transports/test_retry.py @@ -0,0 +1,786 @@ +import asyncio +import json + +from collections.abc import AsyncGenerator +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from a2a.client.client import ClientCallContext +from a2a.client.errors import A2AClientError, A2AClientTimeoutError +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.retry import ( + RetryTransport, + default_retry_predicate, +) +from a2a.types.a2a_pb2 import ( + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTasksRequest, + Message, + Part, + SendMessageRequest, + SendMessageResponse, + StreamResponse, + SubscribeToTaskRequest, + Task, + TaskPushNotificationConfig, +) +from a2a.utils.errors import InternalError, TaskNotFoundError + + +@pytest.fixture +def mock_transport() -> AsyncMock: + return AsyncMock(spec=ClientTransport) + + +@pytest.fixture +def retry_transport(mock_transport: AsyncMock) -> RetryTransport: + return RetryTransport( + mock_transport, + max_retries=3, + base_delay=0.01, + max_delay=0.1, + jitter=False, + ) + + +class TestDefaultRetryPredicate: + def test_timeout_error_is_retryable(self) -> None: + error = A2AClientTimeoutError('timeout') + assert default_retry_predicate(error) is True + + def test_network_error_is_retryable(self) -> None: + cause = httpx.ConnectError('connection refused') + error = A2AClientError( + 'Network communication error: connection refused' + ) + error.__cause__ = cause + assert default_retry_predicate(error) is True + + @pytest.mark.parametrize('status_code', [408, 429, 502, 503, 504]) + def test_retryable_http_status_codes(self, status_code: int) -> None: + request = httpx.Request('POST', 'http://example.com') + response = httpx.Response(status_code, request=request) + cause = httpx.HTTPStatusError( + 'error', request=request, response=response + ) + error = A2AClientError(f'HTTP Error {status_code}') + error.__cause__ = cause + assert default_retry_predicate(error) is True + + @pytest.mark.parametrize('status_code', [400, 401, 403, 404, 500]) + def test_non_retryable_http_status_codes(self, status_code: int) -> None: + request = httpx.Request('POST', 'http://example.com') + response = httpx.Response(status_code, request=request) + cause = httpx.HTTPStatusError( + 'error', request=request, response=response + ) + error = A2AClientError(f'HTTP Error {status_code}') + error.__cause__ = cause + assert default_retry_predicate(error) is False + + def test_json_decode_error_is_not_retryable(self) -> None: + cause = json.JSONDecodeError('msg', 'doc', 0) + error = A2AClientError('JSON Decode Error') + error.__cause__ = cause + assert default_retry_predicate(error) is False + + def test_domain_error_is_not_retryable(self) -> None: + error = TaskNotFoundError() + assert default_retry_predicate(error) is False + + def test_internal_error_is_not_retryable(self) -> None: + error = InternalError() + assert default_retry_predicate(error) is False + + def test_client_error_without_cause_is_not_retryable(self) -> None: + error = A2AClientError('some error') + assert default_retry_predicate(error) is False + + def test_non_a2a_error_is_not_retryable(self) -> None: + error = ValueError('not an A2A error') + assert default_retry_predicate(error) is False + + @pytest.mark.parametrize( + 'status_code, expected', + [ + ('UNAVAILABLE', True), + ('RESOURCE_EXHAUSTED', True), + ('NOT_FOUND', False), + ], + ) + def test_grpc_error_retryability( + self, status_code: str, expected: bool + ) -> None: + grpc = pytest.importorskip('grpc') + + class FakeAioRpcError(grpc.aio.AioRpcError): + def __init__(self, code: object) -> None: + self._code = code + + def code(self) -> object: + return self._code + + cause = FakeAioRpcError(getattr(grpc.StatusCode, status_code)) + error = A2AClientError(f'gRPC Error {status_code}') + error.__cause__ = cause + assert default_retry_predicate(error) is expected + + +class TestRetryTransport: + @pytest.mark.parametrize( + 'method_name, request_obj', + [ + ( + 'send_message', + SendMessageRequest(message=Message(parts=[Part(text='hello')])), + ), + ('get_task', GetTaskRequest(id='t1')), + ('list_tasks', ListTasksRequest()), + ('cancel_task', CancelTaskRequest(id='t1')), + ( + 'create_task_push_notification_config', + TaskPushNotificationConfig(task_id='t1'), + ), + ( + 'get_task_push_notification_config', + GetTaskPushNotificationConfigRequest(task_id='t1', id='c1'), + ), + ( + 'list_task_push_notification_configs', + ListTaskPushNotificationConfigsRequest(task_id='t1'), + ), + ( + 'delete_task_push_notification_config', + DeleteTaskPushNotificationConfigRequest(task_id='t1', id='c1'), + ), + ('get_extended_agent_card', GetExtendedAgentCardRequest()), + ], + ) + @pytest.mark.asyncio + async def test_delegates_to_base_transport( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + method_name: str, + request_obj: object, + ) -> None: + await getattr(retry_transport, method_name)(request_obj) + getattr(mock_transport, method_name).assert_called_once_with( + request_obj, context=None + ) + + @pytest.mark.asyncio + async def test_retries_on_network_error( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + cause = httpx.ConnectError('refused') + error = A2AClientError('Network communication error: refused') + error.__cause__ = cause + + expected = Task() + mock_transport.get_task.side_effect = [error, expected] + result = await retry_transport.get_task(GetTaskRequest(id='t1')) + assert result == expected + assert mock_transport.get_task.call_count == 2 + + @pytest.mark.asyncio + async def test_no_retry_on_domain_error( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + mock_transport.get_task.side_effect = TaskNotFoundError() + with pytest.raises(TaskNotFoundError): + await retry_transport.get_task(GetTaskRequest(id='t1')) + assert mock_transport.get_task.call_count == 1 + + @pytest.mark.asyncio + async def test_no_retry_on_non_retryable_http_status( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + request = httpx.Request('POST', 'http://example.com') + response = httpx.Response(400, request=request) + cause = httpx.HTTPStatusError( + 'bad request', request=request, response=response + ) + error = A2AClientError('HTTP Error 400: bad request') + error.__cause__ = cause + + mock_transport.send_message.side_effect = error + with pytest.raises(A2AClientError): + await retry_transport.send_message(SendMessageRequest()) + assert mock_transport.send_message.call_count == 1 + + @pytest.mark.asyncio + async def test_exponential_backoff_timing( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=3, + base_delay=1.0, + max_delay=30.0, + jitter=False, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + + with patch( + 'a2a.client.transports.retry.asyncio.sleep', + new_callable=AsyncMock, + ) as mock_sleep: + with pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + + assert mock_sleep.call_count == 3 + mock_sleep.assert_any_call(1.0) + mock_sleep.assert_any_call(2.0) + mock_sleep.assert_any_call(4.0) + + @pytest.mark.asyncio + async def test_max_delay_cap(self, mock_transport: AsyncMock) -> None: + transport = RetryTransport( + mock_transport, + max_retries=5, + base_delay=10.0, + max_delay=20.0, + jitter=False, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + + with patch( + 'a2a.client.transports.retry.asyncio.sleep', + new_callable=AsyncMock, + ) as mock_sleep: + with pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + + for call_args in mock_sleep.call_args_list: + assert call_args[0][0] <= 20.0 + + @pytest.mark.asyncio + async def test_jitter_produces_randomized_delays( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=3, + base_delay=1.0, + max_delay=30.0, + jitter=True, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + + with patch( + 'a2a.client.transports.retry.asyncio.sleep', + new_callable=AsyncMock, + ) as mock_sleep: + with pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + + for i, call_args in enumerate(mock_sleep.call_args_list): + delay = call_args[0][0] + max_possible = min(1.0 * (2**i), 30.0) + assert 0 <= delay <= max_possible + + @pytest.mark.asyncio + async def test_streaming_retries_pre_stream_failure( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + async def success_stream(*args: object, **kwargs: object) -> object: + yield StreamResponse() + yield StreamResponse() + + mock_transport.send_message_streaming.side_effect = [ + A2AClientTimeoutError('timeout'), + success_stream(), + ] + events = [ + event + async for event in retry_transport.send_message_streaming( + SendMessageRequest() + ) + ] + + assert len(events) == 2 + assert mock_transport.send_message_streaming.call_count == 2 + + @pytest.mark.asyncio + async def test_streaming_no_retry_mid_stream( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + async def failing_mid_stream(*args: object, **kwargs: object) -> object: + yield StreamResponse() + raise A2AClientTimeoutError('mid-stream timeout') + + mock_transport.send_message_streaming.return_value = ( + failing_mid_stream() + ) + + events: list[StreamResponse] = [] + with pytest.raises(A2AClientTimeoutError): + async for event in retry_transport.send_message_streaming( + SendMessageRequest() + ): + events.append(event) # noqa: PERF401 + + assert len(events) == 1 + assert mock_transport.send_message_streaming.call_count == 1 + + @pytest.mark.asyncio + async def test_subscribe_streaming_retries( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + async def success_stream(*args: object, **kwargs: object) -> object: + yield StreamResponse() + + mock_transport.subscribe.side_effect = [ + A2AClientTimeoutError('timeout'), + success_stream(), + ] + events = [ + event + async for event in retry_transport.subscribe( + SubscribeToTaskRequest(id='t1') + ) + ] + + assert len(events) == 1 + assert mock_transport.subscribe.call_count == 2 + + @pytest.mark.asyncio + async def test_streaming_max_retries_exhausted( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + mock_transport.send_message_streaming.side_effect = ( + A2AClientTimeoutError('timeout') + ) + with pytest.raises(A2AClientTimeoutError): + async for _ in retry_transport.send_message_streaming( + SendMessageRequest() + ): + pass + assert mock_transport.send_message_streaming.call_count == 4 + + @pytest.mark.asyncio + async def test_custom_retry_predicate( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + retry_predicate=lambda e: isinstance(e, TaskNotFoundError), + ) + expected = Task() + mock_transport.get_task.side_effect = [ + TaskNotFoundError(), + expected, + ] + result = await transport.get_task(GetTaskRequest(id='t1')) + assert result == expected + assert mock_transport.get_task.call_count == 2 + + @pytest.mark.asyncio + async def test_custom_predicate_rejects_normally_retryable( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=3, + base_delay=0.01, + retry_predicate=lambda e: False, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + with pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + assert mock_transport.send_message.call_count == 1 + + @pytest.mark.asyncio + async def test_on_retry_async_callback( + self, mock_transport: AsyncMock + ) -> None: + on_retry_mock = AsyncMock() + transport = RetryTransport( + mock_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + on_retry=on_retry_mock, + ) + error = A2AClientTimeoutError('timeout') + expected = SendMessageResponse() + mock_transport.send_message.side_effect = [error, expected] + + await transport.send_message(SendMessageRequest()) + + on_retry_mock.assert_called_once_with(1, error, 0.01) + + @pytest.mark.asyncio + async def test_on_retry_sync_callback( + self, mock_transport: AsyncMock + ) -> None: + calls: list[tuple[int, Exception, float]] = [] + + def sync_on_retry(attempt: int, error: Exception, delay: float) -> None: + calls.append((attempt, error, delay)) + + transport = RetryTransport( + mock_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + on_retry=sync_on_retry, + ) + error = A2AClientTimeoutError('timeout') + expected = SendMessageResponse() + mock_transport.send_message.side_effect = [error, expected] + + await transport.send_message(SendMessageRequest()) + + assert len(calls) == 1 + assert calls[0][0] == 1 + + @pytest.mark.asyncio + async def test_close_delegates_without_retry( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + await retry_transport.close() + mock_transport.close.assert_called_once() + + @pytest.mark.asyncio + async def test_context_passed_through( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + context = ClientCallContext(timeout=5.0) + request = SendMessageRequest( + message=Message(parts=[Part(text='hello')]) + ) + await retry_transport.send_message(request, context=context) + mock_transport.send_message.assert_called_once_with( + request, context=context + ) + + @pytest.mark.asyncio + async def test_streaming_delegates( + self, + mock_transport: AsyncMock, + retry_transport: RetryTransport, + ) -> None: + async def mock_stream(*args: object, **kwargs: object) -> object: + yield StreamResponse() + + mock_transport.send_message_streaming.return_value = mock_stream() + request = SendMessageRequest() + events = [ + event + async for event in retry_transport.send_message_streaming(request) + ] + + assert len(events) == 1 + mock_transport.send_message_streaming.assert_called_once_with( + request, context=None + ) + + @pytest.mark.asyncio + async def test_end_to_end_retry_within_context_manager( + self, mock_transport: AsyncMock + ) -> None: + expected = SendMessageResponse() + mock_transport.send_message.side_effect = [ + A2AClientTimeoutError('timeout'), + expected, + ] + + async with RetryTransport( + mock_transport, max_retries=2, base_delay=0.01, jitter=False + ) as t: + assert t is not mock_transport + result = await t.send_message( + SendMessageRequest(message=Message(parts=[Part(text='hello')])) + ) + assert result == expected + assert mock_transport.send_message.call_count == 2 + mock_transport.close.assert_not_awaited() + + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_end_to_end_retry_exhaustion_within_context_manager( + self, mock_transport: AsyncMock + ) -> None: + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + + with pytest.raises(A2AClientTimeoutError): + async with RetryTransport( + mock_transport, max_retries=2, base_delay=0.01, jitter=False + ) as t: + await t.send_message(SendMessageRequest()) + + assert mock_transport.send_message.call_count == 3 + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_max_retries_zero_disables_retry( + self, mock_transport: AsyncMock + ) -> None: + transport = RetryTransport( + mock_transport, + max_retries=0, + base_delay=0.01, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + with pytest.raises(A2AClientTimeoutError): + await transport.send_message(SendMessageRequest()) + assert mock_transport.send_message.call_count == 1 + + def test_invalid_max_retries_raises_value_error( + self, mock_transport: AsyncMock + ) -> None: + with pytest.raises(ValueError, match='max_retries must be >= 0'): + RetryTransport(mock_transport, max_retries=-1) + + def test_invalid_base_delay_raises_value_error( + self, mock_transport: AsyncMock + ) -> None: + with pytest.raises(ValueError, match='base_delay must be > 0'): + RetryTransport(mock_transport, base_delay=0) + + def test_invalid_max_delay_raises_value_error( + self, mock_transport: AsyncMock + ) -> None: + with pytest.raises(ValueError, match='max_delay must be > 0'): + RetryTransport(mock_transport, max_delay=-1) + + @pytest.mark.asyncio + async def test_on_retry_exception_does_not_break_retry_loop( + self, mock_transport: AsyncMock + ) -> None: + """A buggy on_retry must not replace the original error or stop retries.""" + + def bad_callback(attempt: int, error: Exception, delay: float) -> None: + raise RuntimeError('on_retry blew up') + + transport = RetryTransport( + mock_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + on_retry=bad_callback, + ) + expected = SendMessageResponse() + mock_transport.send_message.side_effect = [ + A2AClientTimeoutError('timeout'), + expected, + ] + + result = await transport.send_message(SendMessageRequest()) + + assert result is expected + assert mock_transport.send_message.call_count == 2 + + @pytest.mark.asyncio + async def test_cancelled_error_during_sleep_propagates( + self, mock_transport: AsyncMock + ) -> None: + """Cancellation during the retry backoff must not be swallowed.""" + transport = RetryTransport( + mock_transport, + max_retries=3, + base_delay=1.0, + jitter=False, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + + async def cancelling_sleep(*_args: object, **_kwargs: object) -> None: + raise asyncio.CancelledError + + with ( + patch( + 'a2a.client.transports.retry.asyncio.sleep', + side_effect=cancelling_sleep, + ), + pytest.raises(asyncio.CancelledError), + ): + await transport.send_message(SendMessageRequest()) + + # First attempt ran; cancel hit on the sleep before the second. + assert mock_transport.send_message.call_count == 1 + + @pytest.mark.asyncio + async def test_cancelled_error_from_transport_propagates( + self, mock_transport: AsyncMock + ) -> None: + """CancelledError raised by the inner transport bypasses retry.""" + mock_transport.send_message.side_effect = asyncio.CancelledError + transport = RetryTransport( + mock_transport, max_retries=3, base_delay=0.01, jitter=False + ) + with pytest.raises(asyncio.CancelledError): + await transport.send_message(SendMessageRequest()) + assert mock_transport.send_message.call_count == 1 + + @pytest.mark.asyncio + async def test_cancelled_error_from_streaming_transport_propagates( + self, mock_transport: AsyncMock + ) -> None: + """CancelledError raised by the streaming transport bypasses retry.""" + mock_transport.send_message_streaming.side_effect = ( + asyncio.CancelledError + ) + transport = RetryTransport( + mock_transport, max_retries=3, base_delay=0.01, jitter=False + ) + with pytest.raises(asyncio.CancelledError): + async for _event in transport.send_message_streaming( + SendMessageRequest() + ): + pass + assert mock_transport.send_message_streaming.call_count == 1 + + @pytest.mark.asyncio + async def test_on_retry_cancelled_error_propagates( + self, mock_transport: AsyncMock + ) -> None: + """CancelledError from on_retry must not be swallowed by the catch-all.""" + + async def cancelling_callback( + *_args: object, **_kwargs: object + ) -> None: + raise asyncio.CancelledError + + transport = RetryTransport( + mock_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + on_retry=cancelling_callback, + ) + mock_transport.send_message.side_effect = A2AClientTimeoutError( + 'timeout' + ) + with pytest.raises(asyncio.CancelledError): + await transport.send_message(SendMessageRequest()) + assert mock_transport.send_message.call_count == 1 + + @pytest.mark.asyncio + async def test_streaming_inner_generator_closed_on_consumer_break( + self, mock_transport: AsyncMock + ) -> None: + """A consumer that breaks mid-stream must not leak the inner generator.""" + closed: list[bool] = [] + + async def long_stream() -> AsyncGenerator[StreamResponse]: + try: + yield StreamResponse() + yield StreamResponse() + yield StreamResponse() + finally: + closed.append(True) + + mock_transport.send_message_streaming.return_value = long_stream() + transport = RetryTransport( + mock_transport, + max_retries=3, + base_delay=0.01, + jitter=False, + ) + + outer = transport.send_message_streaming(SendMessageRequest()) + count = 0 + async for _event in outer: + count += 1 + if count == 1: + break + await outer.aclose() + + assert closed == [True] + + @pytest.mark.asyncio + async def test_streaming_inner_generator_closed_on_retry( + self, mock_transport: AsyncMock + ) -> None: + """Pre-stream failures must aclose() the inner generator before retry.""" + closed: list[int] = [] + + def make_failing_stream(idx: int) -> AsyncGenerator[StreamResponse]: + async def gen() -> AsyncGenerator[StreamResponse]: + try: + raise A2AClientTimeoutError(f'pre-stream failure {idx}') + yield StreamResponse() # pragma: no cover + finally: + closed.append(idx) + + return gen() + + async def success() -> AsyncGenerator[StreamResponse]: + yield StreamResponse() + + mock_transport.send_message_streaming.side_effect = [ + make_failing_stream(0), + make_failing_stream(1), + success(), + ] + + transport = RetryTransport( + mock_transport, + max_retries=3, + base_delay=0.001, + jitter=False, + ) + events = [ + event + async for event in transport.send_message_streaming( + SendMessageRequest() + ) + ] + assert len(events) == 1 + assert closed == [0, 1] + + def test_retry_predicate_and_on_retry_callback_aliases_are_exported( + self, + ) -> None: + from a2a.client.transports import ( # noqa: PLC0415 + OnRetryCallback, + RetryPredicate, + ) + + assert RetryPredicate is not None + assert OnRetryCallback is not None diff --git a/tests/integration/test_retry_integration.py b/tests/integration/test_retry_integration.py new file mode 100644 index 000000000..681bed93c --- /dev/null +++ b/tests/integration/test_retry_integration.py @@ -0,0 +1,226 @@ +from unittest.mock import ANY, AsyncMock + +import httpx +import pytest + +from starlette.applications import Starlette + +from a2a.client.base_client import BaseClient +from a2a.client.client import ClientConfig +from a2a.client.client_factory import ClientFactory +from a2a.client.transports.retry import RetryTransport +from a2a.server.request_handlers import RequestHandler +from a2a.server.routes import create_jsonrpc_routes, create_rest_routes +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + AgentInterface, + GetTaskRequest, + Message, + Part, + Role, + SendMessageRequest, + Task, + TaskState, + TaskStatus, +) +from a2a.utils.constants import TransportProtocol + + +TASK_RESPONSE = Task( + id='task-retry-integration', + context_id='ctx-retry-integration', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), +) + + +def _wrap_with_transient_503(app, fail_count: int = 2): + state = {'count': 0} + + async def middleware(scope, receive, send): + if scope['type'] == 'http' and state['count'] < fail_count: + state['count'] += 1 + await send( + { + 'type': 'http.response.start', + 'status': 503, + 'headers': [[b'content-type', b'text/plain']], + } + ) + await send( + { + 'type': 'http.response.body', + 'body': b'Service Unavailable', + } + ) + return + await app(scope, receive, send) + + return middleware, state + + +@pytest.fixture +def mock_request_handler() -> AsyncMock: + handler = AsyncMock(spec=RequestHandler) + handler.on_get_task.return_value = TASK_RESPONSE + handler.on_message_send.return_value = TASK_RESPONSE + return handler + + +@pytest.fixture +def agent_card() -> AgentCard: + return AgentCard( + name='Retry Integration Agent', + description='Agent for retry integration testing.', + version='1.0.0', + capabilities=AgentCapabilities(streaming=False), + skills=[], + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + supported_interfaces=[ + AgentInterface( + protocol_binding=TransportProtocol.HTTP_JSON, + url='http://testserver', + ), + AgentInterface( + protocol_binding=TransportProtocol.JSONRPC, + url='http://testserver', + ), + ], + ) + + +@pytest.mark.asyncio +async def test_retry_with_client_factory_rest( + mock_request_handler: AsyncMock, + agent_card: AgentCard, +) -> None: + rest_routes = create_rest_routes(mock_request_handler) + app = Starlette(routes=[*rest_routes]) + failing_app, state = _wrap_with_transient_503(app, fail_count=2) + + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=failing_app), + ) + + factory = ClientFactory( + config=ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ) + client = factory.create(agent_card) + + assert isinstance(client, BaseClient) + original_transport = client._transport + client._transport = RetryTransport( + original_transport, + max_retries=3, + base_delay=0.01, + max_delay=0.1, + jitter=False, + ) + + params = GetTaskRequest(id=TASK_RESPONSE.id) + result = await client.get_task(request=params) + + assert result.id == TASK_RESPONSE.id + assert state['count'] == 2 + mock_request_handler.on_get_task.assert_awaited_once_with(params, ANY) + + await client.close() + + +@pytest.mark.asyncio +async def test_retry_with_client_factory_jsonrpc( + mock_request_handler: AsyncMock, + agent_card: AgentCard, +) -> None: + jsonrpc_routes = create_jsonrpc_routes( + request_handler=mock_request_handler, + rpc_url='/', + ) + app = Starlette(routes=[*jsonrpc_routes]) + failing_app, state = _wrap_with_transient_503(app, fail_count=2) + + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=failing_app), + ) + + factory = ClientFactory( + config=ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.JSONRPC], + ) + ) + client = factory.create(agent_card) + + assert isinstance(client, BaseClient) + original_transport = client._transport + client._transport = RetryTransport( + original_transport, + max_retries=3, + base_delay=0.01, + max_delay=0.1, + jitter=False, + ) + + params = GetTaskRequest(id=TASK_RESPONSE.id) + result = await client.get_task(request=params) + + assert result.id == TASK_RESPONSE.id + assert state['count'] == 2 + mock_request_handler.on_get_task.assert_awaited_once_with(params, ANY) + + await client.close() + + +@pytest.mark.asyncio +async def test_retry_send_message_blocking( + mock_request_handler: AsyncMock, + agent_card: AgentCard, +) -> None: + rest_routes = create_rest_routes(mock_request_handler) + app = Starlette(routes=[*rest_routes]) + failing_app, state = _wrap_with_transient_503(app, fail_count=1) + + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=failing_app), + ) + + factory = ClientFactory( + config=ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ) + client = factory.create(agent_card) + + assert isinstance(client, BaseClient) + # Disable streaming to force a single non-streaming call. + client._config.streaming = False + original_transport = client._transport + client._transport = RetryTransport( + original_transport, + max_retries=2, + base_delay=0.01, + jitter=False, + ) + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-retry-test', + parts=[Part(text='Hello retry')], + ) + params = SendMessageRequest(message=message_to_send) + + events = [event async for event in client.send_message(request=params)] + + assert len(events) == 1 + stream_response = events[0] + assert stream_response.HasField('task') + assert stream_response.task.id == TASK_RESPONSE.id + assert state['count'] == 1 + mock_request_handler.on_message_send.assert_awaited_once_with(params, ANY) + + await client.close()