# Source code for grpc.aio._channel

# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""

import asyncio
import sys
from typing import Any, Iterable, List, Optional, Sequence

import grpc
from grpc import _common
from grpc import _compression
from grpc import _grpcio_metadata
from grpc._cython import cygrpc

from . import _base_call
from . import _base_channel
from ._call import StreamStreamCall
from ._call import StreamUnaryCall
from ._call import UnaryStreamCall
from ._call import UnaryUnaryCall
from ._interceptor import ClientInterceptor
from ._interceptor import InterceptedStreamStreamCall
from ._interceptor import InterceptedStreamUnaryCall
from ._interceptor import InterceptedUnaryStreamCall
from ._interceptor import InterceptedUnaryUnaryCall
from ._interceptor import StreamStreamClientInterceptor
from ._interceptor import StreamUnaryClientInterceptor
from ._interceptor import UnaryStreamClientInterceptor
from ._interceptor import UnaryUnaryClientInterceptor
from ._metadata import Metadata
from ._typing import ChannelArgumentType
from ._typing import DeserializingFunction
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction
from ._utils import _timeout_to_deadline

# User-agent string sent as the channel's primary user agent; it is appended
# to every channel's arguments by ``_augment_channel_arguments``.
_USER_AGENT = f"grpc-python-asyncio/{_grpcio_metadata.__version__}"

# ``asyncio.all_tasks`` only exists from Python 3.7 on; older interpreters
# expose the (since removed) ``asyncio.Task.all_tasks`` class method instead.
# Compare the full version tuple rather than only the minor component, so the
# check remains correct regardless of the major version number.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the tasks of the current event loop (Python < 3.7 path)."""
        return asyncio.Task.all_tasks()  # pylint: disable=no-member

else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Return the not-yet-finished tasks of the running event loop."""
        return asyncio.all_tasks()


def _augment_channel_arguments(
    base_options: ChannelArgumentType, compression: Optional[grpc.Compression]
):
    """Extend the caller's channel arguments with the ones this module adds.

    Args:
      base_options: The channel arguments supplied by the caller.
      compression: The channel-wide compression setting, or None.

    Returns:
      A tuple holding ``base_options``, then whatever
      ``_compression.create_channel_option`` produced for ``compression``,
      then the ``(primary_user_agent_string, _USER_AGENT)`` pair.
    """
    extra_compression = _compression.create_channel_option(compression)
    user_agent_pair = (
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    )
    return (*base_options, *extra_compression, user_agent_pair)


class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores the attributes shared by the
    unary-unary, unary-stream, stream-unary and stream-stream multi callables.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]
    _references: List[Any]

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        channel: cygrpc.AioChannel,
        method: bytes,
        request_serializer: SerializingFunction,
        response_deserializer: DeserializingFunction,
        interceptors: Optional[Sequence[ClientInterceptor]],
        references: List[Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        """Store the parameters every call issued by this object needs.

        Args:
          channel: The Cython channel the calls are issued on.
          method: The RPC method name, already encoded to bytes.
          request_serializer: Serializer for outgoing requests (``Channel``
            may pass None here).
          response_deserializer: Deserializer for incoming responses
            (``Channel`` may pass None here).
          interceptors: Interceptors applied to each call; when empty or None
            the calls are created without interception.
          references: Objects held for the lifetime of this callable.
            ``Channel`` passes ``[self]`` — presumably so the Python channel
            is not garbage-collected (and closed by its ``__del__``) while a
            multi callable derived from it is still in use.
          loop: The event loop the calls are bound to.
        """
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors
        self._references = references

    @staticmethod
    def _init_metadata(
        metadata: Optional[MetadataType] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Metadata:
        """Build the final metadata to send with the current call.

        Args:
          metadata: Caller-supplied metadata, a ``Metadata`` instance or a
            tuple of key/value pairs. None (or any falsy value) is replaced
            by an empty ``Metadata``.
          compression: Optional per-call compression. When set, the metadata
            is rebuilt from ``_compression.augment_metadata`` (presumably
            adding the compression-algorithm entry).

        Returns:
          The metadata to use for the call.
        """
        metadata = metadata or Metadata()
        # NOTE(review): only tuples are converted to ``Metadata``; any other
        # sequence (e.g. a list of pairs) is returned unchanged despite the
        # ``Metadata`` return annotation. Confirm whether callers rely on this.
        if not isinstance(metadata, Metadata) and isinstance(metadata, tuple):
            metadata = Metadata.from_tuple(metadata)
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression)
            )
        return metadata


class UnaryUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryUnaryMultiCallable
):
    """Creates unary-request / unary-response calls for one method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
        """Start a unary-unary RPC and return its call object.

        With interceptors configured, an ``InterceptedUnaryUnaryCall`` is
        returned and receives the raw ``timeout``; otherwise a plain
        ``UnaryUnaryCall`` is returned, given ``_timeout_to_deadline(timeout)``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Arguments that both call flavours take, in the same trailing order.
        common_tail = (
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, call_metadata, *common_tail
            )
        return UnaryUnaryCall(
            request, _timeout_to_deadline(timeout), call_metadata, *common_tail
        )


class UnaryStreamMultiCallable(
    _BaseMultiCallable, _base_channel.UnaryStreamMultiCallable
):
    """Creates unary-request / streaming-response calls for one method."""

    def __call__(
        self,
        request: RequestType,
        *,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
        """Start a unary-stream RPC and return its call object.

        With interceptors configured, an ``InterceptedUnaryStreamCall`` is
        returned and receives the raw ``timeout``; otherwise a plain
        ``UnaryStreamCall`` is returned, given ``_timeout_to_deadline(timeout)``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Arguments that both call flavours take, in the same trailing order.
        common_tail = (
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, call_metadata, *common_tail
            )
        return UnaryStreamCall(
            request, _timeout_to_deadline(timeout), call_metadata, *common_tail
        )


class StreamUnaryMultiCallable(
    _BaseMultiCallable, _base_channel.StreamUnaryMultiCallable
):
    """Creates streaming-request / unary-response calls for one method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamUnaryCall:
        """Start a stream-unary RPC and return its call object.

        ``request_iterator`` may be None. With interceptors configured, an
        ``InterceptedStreamUnaryCall`` is returned and receives the raw
        ``timeout``; otherwise a plain ``StreamUnaryCall`` is returned, given
        ``_timeout_to_deadline(timeout)``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Arguments that both call flavours take, in the same trailing order.
        common_tail = (
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedStreamUnaryCall(
                self._interceptors,
                request_iterator,
                timeout,
                call_metadata,
                *common_tail,
            )
        return StreamUnaryCall(
            request_iterator,
            _timeout_to_deadline(timeout),
            call_metadata,
            *common_tail,
        )


class StreamStreamMultiCallable(
    _BaseMultiCallable, _base_channel.StreamStreamMultiCallable
):
    """Creates streaming-request / streaming-response calls for one method."""

    def __call__(
        self,
        request_iterator: Optional[RequestIterableType] = None,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> _base_call.StreamStreamCall:
        """Start a stream-stream RPC and return its call object.

        ``request_iterator`` may be None. With interceptors configured, an
        ``InterceptedStreamStreamCall`` is returned and receives the raw
        ``timeout``; otherwise a plain ``StreamStreamCall`` is returned, given
        ``_timeout_to_deadline(timeout)``.
        """
        call_metadata = self._init_metadata(metadata, compression)
        # Arguments that both call flavours take, in the same trailing order.
        common_tail = (
            credentials,
            wait_for_ready,
            self._channel,
            self._method,
            self._request_serializer,
            self._response_deserializer,
            self._loop,
        )
        if self._interceptors:
            return InterceptedStreamStreamCall(
                self._interceptors,
                request_iterator,
                timeout,
                call_metadata,
                *common_tail,
            )
        return StreamStreamCall(
            request_iterator,
            _timeout_to_deadline(timeout),
            call_metadata,
            *common_tail,
        )


class Channel(_base_channel.Channel):
    """An asyncio gRPC channel backed by a ``cygrpc.AioChannel``.

    Usable as an async context manager: leaving the ``async with`` block
    closes the channel without a grace period.
    """

    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(
        self,
        target: str,
        options: ChannelArgumentType,
        credentials: Optional[grpc.ChannelCredentials],
        compression: Optional[grpc.Compression],
        interceptors: Optional[Sequence[ClientInterceptor]],
    ):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None. (NOTE(review):
            ``secure_channel`` passes ``credentials._credentials`` here, so the
            ``grpc.ChannelCredentials`` annotation looks inaccurate — confirm.)
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel. Each interceptor
            is placed in exactly one per-arity list: the first of unary-unary,
            unary-stream, stream-unary, stream-stream it is an instance of.

        Raises:
          ValueError: If an interceptor is not an instance of any of the four
            supported client interceptor types.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor)
                        + "{} or ".format(UnaryUnaryClientInterceptor.__name__)
                        + "{} or ".format(UnaryStreamClientInterceptor.__name__)
                        + "{} or ".format(StreamUnaryClientInterceptor.__name__)
                        + "{}. ".format(StreamStreamClientInterceptor.__name__)
                    )

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression),
            credentials,
            self._loop,
        )

    async def __aenter__(self):
        """Return the channel itself for use inside ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the channel immediately (no grace period)."""
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Close the channel, cancelling calls still running on it.

        Args:
          grace: If truthy, the number of seconds to wait for in-flight calls
            on this channel to finish before they are cancelled. If None (or
            any falsy value) the calls are cancelled without waiting.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks. For each one, inspect the frame
        # returned by ``get_stack(limit=1)`` and keep tasks whose ``self``
        # local is a Call object issued on this channel.
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if "frame" in str(attribute_error):
                    continue
                else:
                    raise

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get("self")
            if candidate:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, "_channel"):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, "_cython_call"):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f"Unrecognized call object: {candidate}"
                        )

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Close the channel.

        Args:
          grace: Optional number of seconds to let in-flight calls on this
            channel finish before they are cancelled.
        """
        await self._close(grace)

    def __del__(self):
        # Best-effort cleanup for channels never closed explicitly. The
        # ``hasattr`` guard covers instances whose ``__init__`` raised (e.g.
        # on an invalid interceptor) before ``_channel`` was assigned.
        if hasattr(self, "_channel"):
            if not self._channel.closed():
                self._channel.close()

    def get_state(
        self, try_to_connect: bool = False
    ) -> grpc.ChannelConnectivity:
        """Return the channel's current connectivity state.

        Args:
          try_to_connect: Forwarded to the Cython channel's connectivity
            check (presumably asks it to start connecting if idle).
        """
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
        self,
        last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Wait, without a deadline, until the state differs from the given one.

        Args:
          last_observed_state: The state the caller last saw.
        """
        # The await must not live inside the ``assert``: under ``python -O``
        # assert statements are removed entirely, which would skip the wait
        # and make ``channel_ready`` spin without ever yielding.
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None
        )
        assert state_changed

    async def channel_ready(self) -> None:
        """Wait until the channel reports ``READY``, requesting connection."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
    ) -> UnaryUnaryMultiCallable:
        """Create a multi callable for a unary-unary method."""
        return UnaryUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_unary_interceptors,
            [self],
            self._loop,
        )

    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
    ) -> UnaryStreamMultiCallable:
        """Create a multi callable for a unary-stream method."""
        return UnaryStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._unary_stream_interceptors,
            [self],
            self._loop,
        )

    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
    ) -> StreamUnaryMultiCallable:
        """Create a multi callable for a stream-unary method."""
        return StreamUnaryMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_unary_interceptors,
            [self],
            self._loop,
        )

    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
    ) -> StreamStreamMultiCallable:
        """Create a multi callable for a stream-stream method."""
        return StreamStreamMultiCallable(
            self._channel,
            _common.encode(method),
            request_serializer,
            response_deserializer,
            self._stream_stream_interceptors,
            [self],
            self._loop,
        )


def insecure_channel(
    target: str,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      A Channel.
    """
    channel_options = options if options is not None else ()
    # "Insecure" simply means the channel is built without credentials.
    return Channel(
        target,
        options=channel_options,
        credentials=None,
        compression=compression,
        interceptors=interceptors,
    )
def secure_channel(
    target: str,
    credentials: grpc.ChannelCredentials,
    options: Optional[ChannelArgumentType] = None,
    compression: Optional[grpc.Compression] = None,
    interceptors: Optional[Sequence[ClientInterceptor]] = None,
):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel.
      interceptors: An optional sequence of interceptors that will be executed
        for any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    channel_options = options if options is not None else ()
    # The Channel constructor takes the wrapped low-level credentials object,
    # not the public grpc.ChannelCredentials wrapper.
    return Channel(
        target,
        options=channel_options,
        credentials=credentials._credentials,
        compression=compression,
        interceptors=interceptors,
    )