# Source code for ucp.core

# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020       UT-Battelle, LLC. All rights reserved.
# See file LICENSE for terms.

import array
import asyncio
import gc
import logging
import os
import re
import struct
import weakref
from functools import partial
from os import close as close_fd

from . import comm
from ._libs import ucx_api
from ._libs.arr import Array
from .continuous_ucx_progress import BlockingMode, NonBlockingMode
from .exceptions import UCXCanceled, UCXCloseError, UCXError
from .utils import get_event_loop, hash64bits

# Logger used throughout this module.
logger = logging.getLogger("ucx")

# The module should only instantiate one instance of the application context.
# However, the init of CUDA must happen after all process forks, thus we delay
# the instantiation of the application context to the first use of the API
# (see `_get_ctx()`, which fills in this global on demand).
_ctx = None


def _get_ctx():
    """Return the module-wide `ApplicationContext`, creating it on first use."""
    global _ctx
    if _ctx is not None:
        return _ctx
    # First API use: build the context with default settings.
    _ctx = ApplicationContext()
    return _ctx


async def exchange_peer_info(endpoint, msg_tag, ctrl_tag, listener):
    """Swap message/control tags with the peer over the stream API.

    Each side sends three unsigned 64-bit integers: its message tag, its
    control tag and a checksum computed with `hash64bits` over the two tags.

    Parameters
    ----------
    endpoint:
        The endpoint handed to `comm.stream_send` / `comm.stream_recv`.
    msg_tag: int
        Local message tag to advertise to the peer.
    ctrl_tag: int
        Local control tag to advertise to the peer.
    listener: bool
        If exactly `True` this side sends first and then receives; otherwise
        it receives first and then sends.

    Returns
    -------
    dict
        The peer's ``"msg_tag"``, ``"ctrl_tag"`` and ``"checksum"``.

    Raises
    ------
    RuntimeError
        If the received checksum does not match the received tags.
    """
    layout = "QQQ"
    outgoing = struct.pack(layout, msg_tag, ctrl_tag, hash64bits(msg_tag, ctrl_tag))
    incoming = bytearray(len(outgoing))
    outgoing_arr = Array(outgoing)
    incoming_arr = Array(incoming)

    # The two streaming calls are awaited one after the other on purpose
    # (see <https://github.com/rapidsai/ucx-py/pull/509>).
    if listener is True:
        await comm.stream_send(endpoint, outgoing_arr, outgoing_arr.nbytes)
        await comm.stream_recv(endpoint, incoming_arr, incoming_arr.nbytes)
    else:
        await comm.stream_recv(endpoint, incoming_arr, incoming_arr.nbytes)
        await comm.stream_send(endpoint, outgoing_arr, outgoing_arr.nbytes)

    # Decode what the peer sent and verify its checksum
    peer_msg_tag, peer_ctrl_tag, peer_checksum = struct.unpack(layout, incoming)
    expected_checksum = hash64bits(peer_msg_tag, peer_ctrl_tag)
    if expected_checksum != peer_checksum:
        raise RuntimeError(
            f"Checksum invalid! {hex(expected_checksum)} != {hex(peer_checksum)}"
        )

    return {
        "msg_tag": peer_msg_tag,
        "ctrl_tag": peer_ctrl_tag,
        "checksum": peer_checksum,
    }


class CtrlMsg:
    """Control messages exchanged between two endpoints.

    A message is two unsigned 64-bit integers: an opcode and its argument.
    Only opcode `1` (shutdown) is handled; its argument is
    `close_after_n_recv`, the number of messages to receive before the
    receiving endpoint closes.
    """

    fmt = "QQ"
    nbytes = struct.calcsize(fmt)

    @staticmethod
    def serialize(opcode, close_after_n_recv):
        """Pack `opcode` and `close_after_n_recv` (coerced to int) into bytes."""
        fields = (int(opcode), int(close_after_n_recv))
        return struct.pack(CtrlMsg.fmt, *fields)

    @staticmethod
    def deserialize(serialized_bytes):
        """Unpack bytes from `serialize()` into `(opcode, close_after_n_recv)`."""
        return struct.unpack(CtrlMsg.fmt, serialized_bytes)

    @staticmethod
    def handle_ctrl_msg(ep_weakref, log, msg, future):
        """Done-callback invoked when the control message receive finishes.

        Parameters
        ----------
        ep_weakref: weakref.ref
            Weak reference to the owning Endpoint.
        log: str
            Text emitted at debug level once the receive has succeeded.
        msg: bytes-like
            Buffer that the control message was received into.
        future:
            The finished receive future; a `UCXCanceled` result is ignored.

        Raises
        ------
        UCXError
            If the opcode is not `1`.
        """
        try:
            future.result()
        except UCXCanceled:
            # The control receive was canceled; nothing to act on
            return
        logger.debug(log)

        ep = ep_weakref()
        if ep is None:
            return  # The endpoint has been garbage collected
        if ep.closed():
            ep.abort()
            return  # The endpoint is closed

        opcode, close_after_n_recv = CtrlMsg.deserialize(msg)
        if opcode != 1:
            raise UCXError("Received unknown control opcode: %s" % opcode)
        ep.close_after_n_recv(close_after_n_recv, count_from_ep_creation=True)

    @staticmethod
    def setup_ctrl_recv(ep):
        """Post the receive for the control message of `ep`.

        The receive is matched on the endpoint's ``"ctrl_recv"`` tag and
        `handle_ctrl_msg` is attached as its done-callback. Only a weak
        reference to `ep` is bound into the callback.
        """
        uid_hex = hex(ep.uid)
        recv_tag = ep._tags["ctrl_recv"]
        log = "[Recv shutdown] ep: %s, tag: %s" % (uid_hex, hex(recv_tag))

        msg = bytearray(CtrlMsg.nbytes)
        msg_arr = Array(msg)
        shutdown_fut = comm.tag_recv(
            ep._ep, msg_arr, msg_arr.nbytes, recv_tag, name=log
        )

        on_done = partial(CtrlMsg.handle_ctrl_msg, weakref.ref(ep), log, msg)
        shutdown_fut.add_done_callback(on_done)


async def _listener_handler_coroutine(conn_request, ctx, func, endpoint_error_handling):
    """Turn an incoming connection request into an Endpoint and run `func`.

    Steps, in order:
      1) Create the UCX endpoint from `conn_request`
      2) Generate local message/control tags
      3) Exchange tags with the peer (this side sends first)
      4) Wrap everything in an `Endpoint` and post the control receive
      5) Call `func(ep)`, awaiting it when it is a coroutine function
    """
    ucx_ep = ucx_api.UCXEndpoint.create_from_conn_request(
        ctx.worker, conn_request, endpoint_error_handling
    )

    # Local tags are derived from fresh random bytes and the endpoint handle
    entropy = os.urandom(16)
    local_msg_tag = hash64bits("msg_tag", entropy, ucx_ep.handle)
    local_ctrl_tag = hash64bits("ctrl_tag", entropy, ucx_ep.handle)

    remote = await exchange_peer_info(
        endpoint=ucx_ep,
        msg_tag=local_msg_tag,
        ctrl_tag=local_ctrl_tag,
        listener=True,
    )

    # We send on the peer's tags and receive on our own
    tags = {
        "msg_send": remote["msg_tag"],
        "msg_recv": local_msg_tag,
        "ctrl_send": remote["ctrl_tag"],
        "ctrl_recv": local_ctrl_tag,
    }
    ep = Endpoint(endpoint=ucx_ep, ctx=ctx, tags=tags)

    debug_fields = (
        hex(ucx_ep.handle),
        endpoint_error_handling,
        hex(ep._tags["msg_send"]),
        hex(ep._tags["msg_recv"]),
        hex(ep._tags["ctrl_send"]),
        hex(ep._tags["ctrl_recv"]),
    )
    logger.debug(
        "_listener_handler() server: %s, error handling: %s, msg-tag-send: %s, "
        "msg-tag-recv: %s, ctrl-tag-send: %s, ctrl-tag-recv: %s" % debug_fields
    )

    # Post the receive for the peer's shutdown message
    CtrlMsg.setup_ctrl_recv(ep)

    # Drop our reference to the context to avoid delayed clean up
    del ctx

    # Hand the endpoint to the user callback
    if not asyncio.iscoroutinefunction(func):
        func(ep)
        return
    await func(ep)


def _listener_handler(conn_request, callback_func, ctx, endpoint_error_handling):
    """Schedule `_listener_handler_coroutine` for a new connection request.

    Note the argument order differs from the coroutine: `ctx` is passed
    before `callback_func` there.
    """
    handler = _listener_handler_coroutine(
        conn_request,
        ctx,
        callback_func,
        endpoint_error_handling,
    )
    asyncio.ensure_future(handler)


def _epoll_fd_finalizer(epoll_fd, progress_tasks):
    """Empty `progress_tasks` in place, then close `epoll_fd`.

    Parameters
    ----------
    epoll_fd: int
        File descriptor to close; asserted to be non-negative.
    progress_tasks: list
        Container of progress tasks; it is cleared, not replaced.
    """
    assert epoll_fd >= 0
    # Order matters: the progress tasks must be dropped before the
    # descriptor is closed.
    progress_tasks.clear()
    close_fd(epoll_fd)


class ApplicationContext:
    """
    The context of the Asyncio interface of UCX.

    Owns one `UCXContext` and one `UCXWorker`, and keeps the list of progress
    tasks that drive UCX progress from asyncio event loops.
    """

    def __init__(self, config_dict=None, blocking_progress_mode=None):
        """Create the UCX context and worker and start UCX progress.

        Parameters
        ----------
        config_dict: dict, optional
            UCX options passed to `ucx_api.UCXContext`. Defaults to an
            empty dict.
        blocking_progress_mode: bool, optional
            If None, blocking progress mode is used unless the environment
            variable `UCXPY_NON_BLOCKING_MODE` is defined. Otherwise the
            given value selects blocking (True) or non-blocking (False) mode.
        """
        # A fresh dict per instance; a `{}` default in the signature would be
        # one object shared by every call.
        if config_dict is None:
            config_dict = {}

        self.progress_tasks = []

        # For now, an application context only has one worker
        self.context = ucx_api.UCXContext(config_dict)
        self.worker = ucx_api.UCXWorker(self.context)

        if blocking_progress_mode is not None:
            self.blocking_progress_mode = blocking_progress_mode
        elif "UCXPY_NON_BLOCKING_MODE" in os.environ:
            self.blocking_progress_mode = False
        else:
            self.blocking_progress_mode = True

        if self.blocking_progress_mode:
            self.epoll_fd = self.worker.init_blocking_progress_mode()
            # Clear the progress tasks and close the epoll fd once this
            # context is garbage collected.
            weakref.finalize(
                self, _epoll_fd_finalizer, self.epoll_fd, self.progress_tasks
            )

        # Ensure progress even before Endpoints get created, for example to
        # receive messages directly on a worker after a remote endpoint
        # connected with `create_endpoint_from_worker_address`.
        self.continuous_ucx_progress()

    def create_listener(
        self,
        callback_func,
        port=0,
        endpoint_error_handling=True,
    ):
        """Create and start a listener to accept incoming connections

        callback_func is the function or coroutine that takes one
        argument -- the Endpoint connected to the client.

        Notice, the listening is closed when the returned Listener
        goes out of scope thus remember to keep a reference to the object.

        Parameters
        ----------
        callback_func: function or coroutine
            A callback function that gets invoked when an incoming
            connection is accepted
        port: int, optional
            An unused port number for listening, or `0` to let UCX assign
            an unused port. `None` is treated as `0`.
        endpoint_error_handling: boolean, optional
            If `True` (default) enable endpoint error handling raising
            exceptions when an error occurs, may incur in performance penalties
            but prevents a process from terminating unexpectedly that may
            happen when disabled. If `False` endpoint error handling
            is disabled.

        Returns
        -------
        Listener
            The new listener. When this object is deleted, the listening stops
        """
        self.continuous_ucx_progress()
        if port is None:
            port = 0

        logger.info("create_listener() - Start listening on port %d" % port)
        ret = Listener(
            ucx_api.UCXListener(
                worker=self.worker,
                port=port,
                cb_func=_listener_handler,
                cb_args=(callback_func, self, endpoint_error_handling),
            )
        )
        return ret

    async def create_endpoint(self, ip_address, port, endpoint_error_handling=True):
        """Create a new endpoint to a server

        Parameters
        ----------
        ip_address: str
            IP address of the server the endpoint should connect to
        port: int
            Port of the server the endpoint should connect to
        endpoint_error_handling: boolean, optional
            If `True` (default) enable endpoint error handling raising
            exceptions when an error occurs, may incur in performance penalties
            but prevents a process from terminating unexpectedly that may
            happen when disabled. If `False` endpoint error handling
            is disabled.

        Returns
        -------
        Endpoint
            The new endpoint
        """
        self.continuous_ucx_progress()

        ucx_ep = ucx_api.UCXEndpoint.create(
            self.worker, ip_address, port, endpoint_error_handling
        )
        self.worker.progress()

        # We create the Endpoint in three steps:
        #  1) Generate unique IDs to use as tags
        #  2) Exchange endpoint info such as tags
        #  3) Use the info to create an endpoint
        seed = os.urandom(16)
        msg_tag = hash64bits("msg_tag", seed, ucx_ep.handle)
        ctrl_tag = hash64bits("ctrl_tag", seed, ucx_ep.handle)
        peer_info = await exchange_peer_info(
            endpoint=ucx_ep,
            msg_tag=msg_tag,
            ctrl_tag=ctrl_tag,
            listener=False,
        )
        # We send on the peer's tags and receive on our own
        tags = {
            "msg_send": peer_info["msg_tag"],
            "msg_recv": msg_tag,
            "ctrl_send": peer_info["ctrl_tag"],
            "ctrl_recv": ctrl_tag,
        }
        ep = Endpoint(endpoint=ucx_ep, ctx=self, tags=tags)

        logger.debug(
            "create_endpoint() client: %s, error handling: %s, msg-tag-send: %s, "
            "msg-tag-recv: %s, ctrl-tag-send: %s, ctrl-tag-recv: %s"
            % (
                hex(ep._ep.handle),
                endpoint_error_handling,
                hex(ep._tags["msg_send"]),
                hex(ep._tags["msg_recv"]),
                hex(ep._tags["ctrl_send"]),
                hex(ep._tags["ctrl_recv"]),
            )
        )

        # Setup the control receive
        CtrlMsg.setup_ctrl_recv(ep)
        return ep

    async def create_endpoint_from_worker_address(
        self,
        address,
        endpoint_error_handling=True,
    ):
        """Create a new endpoint to a remote worker given its address

        Unlike `create_endpoint()`, no tags are exchanged with the peer and
        no control receive is set up: the returned Endpoint has `tags=None`.

        Parameters
        ----------
        address: UCXAddress
            Address of the remote worker to connect to
        endpoint_error_handling: boolean, optional
            If `True` (default) enable endpoint error handling raising
            exceptions when an error occurs, may incur in performance penalties
            but prevents a process from terminating unexpectedly that may
            happen when disabled. If `False` endpoint error handling
            is disabled.

        Returns
        -------
        Endpoint
            The new endpoint
        """
        self.continuous_ucx_progress()

        ucx_ep = ucx_api.UCXEndpoint.create_from_worker_address(
            self.worker,
            address,
            endpoint_error_handling,
        )
        self.worker.progress()

        ep = Endpoint(endpoint=ucx_ep, ctx=self, tags=None)

        logger.debug(
            "create_endpoint() client: %s, error handling: %s"
            % (hex(ep._ep.handle), endpoint_error_handling)
        )

        return ep

    def continuous_ucx_progress(self, event_loop=None):
        """Guarantees continuous UCX progress

        Use this function to associate UCX progress with an event loop.
        Notice, multiple event loops can be associated with UCX progress.

        This function is automatically called when calling
        `create_listener()` or `create_endpoint()`.

        Parameters
        ----------
        event_loop: asyncio.event_loop, optional
            The event loop to evoke UCX progress. If None,
            `ucp.utils.get_event_loop()` is used.
        """
        loop = event_loop or get_event_loop()
        # NOTE(review): `progress_tasks` holds task objects, not loops; this
        # membership test presumably relies on the task classes comparing
        # equal to their event loop -- confirm in continuous_ucx_progress.
        if loop in self.progress_tasks:
            return  # Progress has already been guaranteed for the current event loop

        if self.blocking_progress_mode:
            task = BlockingMode(self.worker, loop, self.epoll_fd)
        else:
            task = NonBlockingMode(self.worker, loop)
        self.progress_tasks.append(task)

    def get_ucp_worker(self):
        """Returns the underlying UCP worker handle (ucp_worker_h)
        as a Python integer.
        """
        return self.worker.handle

    def get_config(self):
        """Returns all UCX configuration options as a dict.

        Returns
        -------
        dict
            The current UCX configuration options
        """
        return self.context.get_config()

    def ucp_context_info(self):
        """Return low-level UCX info about this context as a string"""
        return self.context.info()

    def ucp_worker_info(self):
        """Return low-level UCX info about this worker as a string"""
        return self.worker.info()

    def fence(self):
        """Call `fence()` on the underlying worker and return its result"""
        return self.worker.fence()

    async def flush(self):
        """Await `comm.flush_worker()` on the underlying worker"""
        return await comm.flush_worker(self.worker)

    def get_worker_address(self):
        """Return the address of the underlying worker"""
        return self.worker.get_address()

    def register_am_allocator(self, allocator, allocator_type):
        """Register an allocator for received Active Messages.

        The allocator registered by this function is always called by the
        active message receive callback when an incoming message is
        available. The appropriate allocator is called depending on whether
        the message received is a host message or CUDA message.
        Note that CUDA messages can only be received via rendezvous, all
        eager messages are received on a host object.

        By default, the host allocator is `bytearray`. There is no default
        CUDA allocator and one must always be registered if CUDA is used.

        Parameters
        ----------
        allocator: callable
            An allocation function accepting exactly one argument, the
            size of the message receives.
        allocator_type: str
            The type of allocator, currently supports "host" and "cuda".
            Any other value is forwarded to the worker as
            `AllocatorType.UNSUPPORTED`.
        """
        if allocator_type == "host":
            allocator_type = ucx_api.AllocatorType.HOST
        elif allocator_type == "cuda":
            allocator_type = ucx_api.AllocatorType.CUDA
        else:
            allocator_type = ucx_api.AllocatorType.UNSUPPORTED
        self.worker.register_am_allocator(allocator, allocator_type)

    @ucx_api.nvtx_annotate("UCXPY_WORKER_RECV", color="red", domain="ucxpy")
    async def recv(self, buffer, tag):
        """Receive directly on worker without a local Endpoint into `buffer`.

        Parameters
        ----------
        buffer: exposing the buffer protocol or array/cuda interface
            The buffer to receive into. It is wrapped in `Array` when needed
            and its full size (`nbytes`) is requested.
        tag: int
            The tag that must match the received message. It is passed to
            `hex()` for logging, so it must be an integer.
        """
        if not isinstance(buffer, Array):
            buffer = Array(buffer)
        nbytes = buffer.nbytes
        log = "[Worker Recv] worker: %s, tag: %s, nbytes: %d, type: %s" % (
            hex(self.worker.handle),
            hex(tag),
            nbytes,
            type(buffer.obj),
        )
        logger.debug(log)

        return await comm.tag_recv(self.worker, buffer, nbytes, tag, name=log)


class Listener:
    """A handle to the listening service started by `create_listener()`

    The listening continues as long as this object exists or until `.close()`
    is called. Please use `create_listener()` to create a Listener.
    """

    def __init__(self, backend):
        # The wrapped backend must already be initialized
        assert backend.initialized
        self._b = backend

    def closed(self):
        """Is the listener closed?"""
        return not self._b.initialized

    @property
    def ip(self):
        """The listening network IP address"""
        return self._b.ip

    @property
    def port(self):
        """The listening network port"""
        return self._b.port

    def close(self):
        """Closing the listener"""
        self._b.close()
class Endpoint:
    """An endpoint represents a connection to a peer

    Please use `create_listener()` and `create_endpoint()`
    to create an Endpoint.
    """

    def __init__(self, endpoint, ctx, tags=None):
        self._ep = endpoint
        self._ctx = ctx
        self._send_count = 0  # Number of calls to self.send()
        self._recv_count = 0  # Number of calls to self.recv()
        self._finished_recv_count = 0  # Number of returned (finished) self.recv() calls
        self._shutting_down_peer = False  # Told peer to shutdown
        self._close_after_n_recv = None
        # Dict with "msg_send", "msg_recv", "ctrl_send" and "ctrl_recv" tags,
        # or None when no tags were exchanged with the peer.
        self._tags = tags

    @property
    def uid(self):
        """The unique ID of the underlying UCX endpoint"""
        return self._ep.handle

    def closed(self):
        """Is this endpoint closed?"""
        return self._ep is None or not self._ep.initialized or not self._ep.is_alive()

    def abort(self):
        """Close the communication immediately and abruptly.
        Useful in destructors or generators' ``finally`` blocks.

        Notice, this functions doesn't signal the connected peer to close.
        To do that, use `Endpoint.close()`
        """
        if self._ep is not None:
            logger.debug("Endpoint.abort(): %s" % hex(self.uid))
            self._ep.close()
        self._ep = None
        self._ctx = None

    async def close(self):
        """Close the endpoint cleanly.
        This will attempt to flush outgoing buffers before actually
        closing the underlying UCX endpoint.
        """
        if self.closed():
            self.abort()
            return
        try:
            # Making sure we only tell peer to shutdown once
            if self._shutting_down_peer:
                return
            self._shutting_down_peer = True

            # Endpoints created without tags (`tags=None`) have no control
            # tag to send on, so there is no shutdown message to deliver.
            if self._tags is not None:
                # Send a shutdown message to the peer
                msg = CtrlMsg.serialize(opcode=1, close_after_n_recv=self._send_count)
                msg_arr = Array(msg)
                log = "[Send shutdown] ep: %s, tag: %s, close_after_n_recv: %d" % (
                    hex(self.uid),
                    hex(self._tags["ctrl_send"]),
                    self._send_count,
                )
                logger.debug(log)
                try:
                    await comm.tag_send(
                        self._ep,
                        msg_arr,
                        msg_arr.nbytes,
                        self._tags["ctrl_send"],
                        name=log,
                    )
                # The peer might already be shutting down thus we can ignore
                # any send errors
                except UCXError as e:
                    # Use the module logger ("ucx"), like every other message
                    # in this file, rather than the root logger.
                    logger.warning(
                        "UCX failed closing worker %s (probably already closed): %s"
                        % (hex(self.uid), repr(e))
                    )
        finally:
            if not self.closed():
                # Give all current outstanding send() calls a chance to return
                self._ctx.worker.progress()
                await asyncio.sleep(0)
                self.abort()

    @ucx_api.nvtx_annotate("UCXPY_SEND", color="green", domain="ucxpy")
    async def send(self, buffer, tag=None, force_tag=False):
        """Send `buffer` to connected peer.

        Parameters
        ----------
        buffer: exposing the buffer protocol or array/cuda interface
            The buffer to send. Raise ValueError if buffer is smaller
            than nbytes.
        tag: hashable, optional
            Set a tag that the receiver must match. Currently the tag
            is hashed together with the internal Endpoint tag that is
            agreed with the remote end at connection time. To enforce
            using the user tag, make sure to specify `force_tag=True`.
        force_tag: bool
            If true, force using `tag` as is, otherwise the value
            specified with `tag` (if any) will be hashed with the
            internal Endpoint tag.

        Raises
        ------
        UCXCloseError
            If the endpoint is closed.
        """
        # After abort() `_ep` is None; report that as a closed endpoint
        # instead of failing with AttributeError on the next line.
        if self._ep is None:
            raise UCXCloseError("Endpoint closed")
        self._ep.raise_on_error()
        if self.closed():
            raise UCXCloseError("Endpoint closed")
        if not isinstance(buffer, Array):
            buffer = Array(buffer)
        if tag is None:
            tag = self._tags["msg_send"]
        elif not force_tag:
            tag = hash64bits(self._tags["msg_send"], hash(tag))
        nbytes = buffer.nbytes
        log = "[Send #%03d] ep: %s, tag: %s, nbytes: %d, type: %s" % (
            self._send_count,
            hex(self.uid),
            hex(tag),
            nbytes,
            type(buffer.obj),
        )
        logger.debug(log)
        self._send_count += 1

        try:
            return await comm.tag_send(self._ep, buffer, nbytes, tag, name=log)
        except UCXCanceled:
            # If self._ep has already been closed and destroyed, we reraise the
            # UCXCanceled exception. Otherwise the cancellation is swallowed
            # and None is returned.
            if self._ep is None:
                raise

    @ucx_api.nvtx_annotate("UCXPY_AM_SEND", color="green", domain="ucxpy")
    async def am_send(self, buffer):
        """Send `buffer` to connected peer.

        Parameters
        ----------
        buffer: exposing the buffer protocol or array/cuda interface
            The buffer to send. Raise ValueError if buffer is smaller
            than nbytes.

        Raises
        ------
        UCXCloseError
            If the endpoint is closed.
        """
        if self.closed():
            raise UCXCloseError("Endpoint closed")
        if not isinstance(buffer, Array):
            buffer = Array(buffer)
        nbytes = buffer.nbytes
        log = "[AM Send #%03d] ep: %s, nbytes: %d, type: %s" % (
            self._send_count,
            hex(self.uid),
            nbytes,
            type(buffer.obj),
        )
        logger.debug(log)
        self._send_count += 1
        return await comm.am_send(self._ep, buffer, nbytes, name=log)

    @ucx_api.nvtx_annotate("UCXPY_RECV", color="red", domain="ucxpy")
    async def recv(self, buffer, tag=None, force_tag=False):
        """Receive from connected peer into `buffer`.

        Parameters
        ----------
        buffer: exposing the buffer protocol or array/cuda interface
            The buffer to receive into. Raise ValueError if buffer
            is smaller than nbytes or read-only.
        tag: hashable, optional
            Set a tag that must match the received message. Currently
            the tag is hashed together with the internal Endpoint tag
            that is agreed with the remote end at connection time.
            To enforce using the user tag, make sure to specify
            `force_tag=True`.
        force_tag: bool
            If true, force using `tag` as is, otherwise the value
            specified with `tag` (if any) will be hashed with the
            internal Endpoint tag.

        Raises
        ------
        UCXCloseError
            If the endpoint is closed and no matching message is pending.
        """
        # After abort() both `_ep` and `_ctx` are None; report that as a
        # closed endpoint instead of failing with AttributeError below.
        if self._ep is None or self._ctx is None:
            raise UCXCloseError("Endpoint closed")

        if tag is None:
            tag = self._tags["msg_recv"]
        elif not force_tag:
            tag = hash64bits(self._tags["msg_recv"], hash(tag))

        # Only check for errors/closure when no matching message is already
        # waiting on the worker.
        if not self._ctx.worker.tag_probe(tag):
            self._ep.raise_on_error()
            if self.closed():
                raise UCXCloseError("Endpoint closed")

        if not isinstance(buffer, Array):
            buffer = Array(buffer)
        nbytes = buffer.nbytes
        log = "[Recv #%03d] ep: %s, tag: %s, nbytes: %d, type: %s" % (
            self._recv_count,
            hex(self.uid),
            hex(tag),
            nbytes,
            type(buffer.obj),
        )
        logger.debug(log)
        self._recv_count += 1

        ret = await comm.tag_recv(self._ep, buffer, nbytes, tag, name=log)

        self._finished_recv_count += 1
        if (
            self._close_after_n_recv is not None
            and self._finished_recv_count >= self._close_after_n_recv
        ):
            self.abort()
        return ret

    @ucx_api.nvtx_annotate("UCXPY_AM_RECV", color="red", domain="ucxpy")
    async def am_recv(self):
        """Receive from connected peer.

        Raises
        ------
        UCXCloseError
            If the endpoint is closed and no active message is pending.
        """
        # After abort() `_ep` is None; report that as a closed endpoint
        # instead of failing with AttributeError on the probe below.
        if self._ep is None:
            raise UCXCloseError("Endpoint closed")

        if not self._ep.am_probe():
            self._ep.raise_on_error()
            if self.closed():
                raise UCXCloseError("Endpoint closed")

        log = "[AM Recv #%03d] ep: %s" % (self._recv_count, hex(self.uid))
        logger.debug(log)
        self._recv_count += 1

        ret = await comm.am_recv(self._ep, name=log)

        self._finished_recv_count += 1
        if (
            self._close_after_n_recv is not None
            and self._finished_recv_count >= self._close_after_n_recv
        ):
            self.abort()
        return ret

    def cuda_support(self):
        """Return whether UCX is configured with CUDA support or not"""
        return self._ctx.context.cuda_support

    def get_ucp_worker(self):
        """Returns the underlying UCP worker handle (ucp_worker_h)
        as a Python integer.
        """
        return self._ctx.worker.handle

    def get_ucp_endpoint(self):
        """Returns the underlying UCP endpoint handle (ucp_ep_h)
        as a Python integer.
        """
        return self._ep.handle

    def ucx_info(self):
        """Return low-level UCX info about this endpoint as a string"""
        return self._ep.info()

    def close_after_n_recv(self, n, count_from_ep_creation=False):
        """Close the endpoint after `n` received messages.

        Parameters
        ----------
        n: int
            Number of messages to received before closing the endpoint.
        count_from_ep_creation: bool, optional
            Whether to count `n` from this function call (default) or
            from the creation of the endpoint.

        Raises
        ------
        UCXError
            If a limit has already been set, or if `n` (absolute) is less
            than the number of receives that have already finished.
        """
        if not count_from_ep_creation:
            n += self._finished_recv_count  # Make `n` absolute
        if self._close_after_n_recv is not None:
            raise UCXError(
                "close_after_n_recv has already been set to: %d (abs)"
                % self._close_after_n_recv
            )
        if n == self._finished_recv_count:
            self.abort()
        elif n > self._finished_recv_count:
            self._close_after_n_recv = n
        else:
            raise UCXError(
                "`n` cannot be less than current recv_count: %d (abs) < %d (abs)"
                % (n, self._finished_recv_count)
            )

    async def send_obj(self, obj, tag=None):
        """Send `obj` to connected peer that calls `recv_obj()`.

        The transfer includes an extra message containing the size of `obj`,
        which increases the overhead slightly.

        Parameters
        ----------
        obj: exposing the buffer protocol or array/cuda interface
            The object to send.
        tag: hashable, optional
            Set a tag that the receiver must match.

        Example
        -------
        >>> await ep.send_obj(pickle.dumps([1,2,3]))
        """
        if not isinstance(obj, Array):
            obj = Array(obj)
        # First message: the payload size as one unsigned 64-bit integer
        nbytes = Array(array.array("Q", [obj.nbytes]))
        await self.send(nbytes, tag=tag)
        await self.send(obj, tag=tag)

    async def recv_obj(self, tag=None, allocator=bytearray):
        """Receive from connected peer that calls `send_obj()`.

        As opposed to `recv()`, this function returns the received object.
        Data is received into a buffer allocated by `allocator`.

        The transfer includes an extra message containing the size of `obj`,
        which increases the overhead slightly.

        Parameters
        ----------
        tag: hashable, optional
            Set a tag that must match the received message. Notice, currently
            UCX-Py doesn't support a "any tag" thus `tag=None` only matches a
            send that also sets `tag=None`.
        allocator: callable, optional
            Function to allocate the received object. The function should
            take the number of bytes to allocate as input and return a new
            buffer of that size as output.

        Example
        -------
        >>> pickle.loads(await ep.recv_obj())
        """
        # First message: the payload size as one unsigned 64-bit integer
        nbytes = array.array("Q", [0])
        await self.recv(nbytes, tag=tag)
        nbytes = nbytes[0]
        ret = allocator(nbytes)
        await self.recv(ret, tag=tag)
        return ret

    async def flush(self):
        """Await `comm.flush_ep()` on the underlying UCX endpoint"""
        logger.debug("[Flush] ep: %s" % (hex(self.uid)))
        return await comm.flush_ep(self._ep)

    def set_close_callback(self, callback_func):
        """Register a user callback function to be called on Endpoint's closing.

        Allows the user to register a callback function to be called when the
        Endpoint's error callback is called, or during its finalizer if the error
        callback is never called.

        Once the callback is called, it's not possible to send any more messages.
        However, receiving messages may still be possible, as UCP may still have
        incoming messages in transit.

        Parameters
        ----------
        callback_func: callable
            The callback function to be called when the Endpoint's error callback
            is called, otherwise called on its finalizer.

        Example
        -------
        >>> ep.set_close_callback(lambda: print("Executing close callback"))
        """
        self._ep.set_close_callback(callback_func)
# The following functions initialize and use a single ApplicationContext instance
def init(options=None, env_takes_precedence=False, blocking_progress_mode=None):
    """Initiate UCX.

    Usually this is done automatically at the first API call
    but this function makes it possible to set UCX options programmable.
    Alternatively, UCX options can be specified through environment variables.

    Parameters
    ----------
    options: dict, optional
        UCX options send to the underlying UCX library. The given dict is
        copied, never modified.
    env_takes_precedence: bool, optional
        Whether environment variables takes precedence over the `options`
        specified here.
    blocking_progress_mode: bool, optional
        If None, blocking UCX progress mode is used unless the environment variable
        `UCXPY_NON_BLOCKING_MODE` is defined.
        Otherwise, if True blocking mode is used and if False non-blocking mode
        is used.

    Raises
    ------
    RuntimeError
        If UCX has already been initiated.
    """
    global _ctx
    if _ctx is not None:
        raise RuntimeError(
            "UCX is already initiated. Call reset() and init() "
            "in order to re-initate UCX with new options."
        )
    # Work on a private copy; `None` stands for "no options".
    options = {} if options is None else options.copy()
    # Iterate over a snapshot since values may be replaced inside the loop.
    for k, v in list(options.items()):
        env_k = f"UCX_{k}"
        env_v = os.environ.get(env_k)
        if env_v is not None:
            if env_takes_precedence:
                options[k] = env_v
                logger.debug(
                    f"Ignoring option {k}={v}; using environment {env_k}={env_v}"
                )
            else:
                logger.debug(
                    f"Ignoring environment {env_k}={env_v}; using option {k}={v}"
                )

    _ctx = ApplicationContext(options, blocking_progress_mode=blocking_progress_mode)
def reset():
    """Resets the UCX library by shutting down all of UCX.

    The library is initiated at next API call.

    Raises
    ------
    UCXError
        If the ApplicationContext is still referenced after dropping the
        module-level reference and running the garbage collector. The
        message lists the objects returned by `gc.get_referrers()`.
    """
    global _ctx
    if _ctx is not None:
        # Keep only a weak reference so we can tell whether anything else
        # still holds on to the context.
        weakref_ctx = weakref.ref(_ctx)
        _ctx = None
        gc.collect()
        if weakref_ctx() is not None:
            msg = (
                "Trying to reset UCX but not all Endpoints and/or Listeners "
                "are closed(). The following objects are still referencing "
                "ApplicationContext: "
            )
            for o in gc.get_referrers(weakref_ctx()):
                msg += "\n %s" % str(o)
            raise UCXError(msg)
def get_ucx_version():
    """Return the version of the underlying UCX installation

    Notice, this function doesn't initialize UCX.

    Returns
    -------
    tuple
        The version as a tuple e.g. (1, 7, 0)
    """
    return ucx_api.get_ucx_version()
def progress():
    """Try to progress the communication layer

    Warning, it is illegal to call this from a call-back function such as
    the call-back function given to create_listener.
    """
    return _get_ctx().worker.progress()
def get_config():
    """Returns all UCX configuration options as a dict.

    If UCX is uninitialized, the options returned are the
    options used if UCX were to be initialized now.
    Notice, this function doesn't initialize UCX.

    Returns
    -------
    dict
        The current UCX configuration options
    """
    if _ctx is None:
        return ucx_api.get_current_options()
    else:
        return _get_ctx().get_config()
def register_am_allocator(allocator, allocator_type):
    """Forward to `register_am_allocator()` of the module-wide context."""
    return _get_ctx().register_am_allocator(allocator, allocator_type)
# Forwards to the module-wide context; `__doc__` is assigned at the end of
# the module from `ApplicationContext.create_listener`.
def create_listener(callback_func, port=None, endpoint_error_handling=True):
    return _get_ctx().create_listener(
        callback_func,
        port,
        endpoint_error_handling=endpoint_error_handling,
    )
# Forwards to the module-wide context; `__doc__` is assigned at the end of
# the module from `ApplicationContext.create_endpoint`.
async def create_endpoint(ip_address, port, endpoint_error_handling=True):
    return await _get_ctx().create_endpoint(
        ip_address,
        port,
        endpoint_error_handling=endpoint_error_handling,
    )
async def create_endpoint_from_worker_address(
    address,
    endpoint_error_handling=True,
):
    """Forward to `create_endpoint_from_worker_address()` of the module-wide
    context and return the resulting Endpoint.
    """
    return await _get_ctx().create_endpoint_from_worker_address(
        address,
        endpoint_error_handling=endpoint_error_handling,
    )


# Forwards to the module-wide context; `__doc__` is assigned at the end of
# the module from `ApplicationContext.continuous_ucx_progress`.
def continuous_ucx_progress(event_loop=None):
    _get_ctx().continuous_ucx_progress(event_loop=event_loop)
# Forwards to the module-wide context; `__doc__` is assigned at the end of
# the module from `ApplicationContext.get_ucp_worker`.
def get_ucp_worker():
    return _get_ctx().get_ucp_worker()
def get_worker_address():
    """Return the worker address of the module-wide context."""
    return _get_ctx().get_worker_address()


def get_ucx_address_from_buffer(buffer):
    """Build a `UCXAddress` from `buffer` via `UCXAddress.from_buffer()`."""
    return ucx_api.UCXAddress.from_buffer(buffer)


async def recv(buffer, tag):
    """Receive into `buffer` directly on the worker of the module-wide
    context, matching on `tag`.
    """
    return await _get_ctx().recv(buffer, tag=tag)


def get_ucp_context_info():
    """Gets information on the current UCX context, obtained from
    `ucp_context_print_info`.
    """
    return _get_ctx().ucp_context_info()


def get_ucp_worker_info():
    """Gets information on the current UCX worker, obtained from
    `ucp_worker_print_info`.
    """
    return _get_ctx().ucp_worker_info()


def get_active_transports():
    """Returns a list of all transports that are available and are currently
    active in UCX, meaning UCX **may** use them depending on the type of
    transfers and how it is configured but is not required to do so.

    The result is a `set` built from the last whitespace-separated field of
    every matching resource line in the context info, keeping the part
    before the first ``/``.
    """
    info = get_ucp_context_info()
    resources = re.findall("^#.*resource.*md.*dev.*flags.*$", info, re.MULTILINE)
    return {r.split()[-1].split("/")[0] for r in resources}


async def flush():
    """Flushes outstanding AMO and RMA operations. This ensures that the
    operations issued on this worker have completed both locally and remotely.
    This function does not guarantee ordering.
    """
    if _ctx is not None:
        return await _get_ctx().flush()
    else:
        # If ctx is not initialized we still want to do the right thing by asyncio
        return await asyncio.sleep(0)


def fence():
    """Ensures ordering of non-blocking communication operations on the UCP worker.
    This function returns nothing, but will raise an error if it cannot make
    this guarantee. This function does not ensure any operations have completed.
    """
    if _ctx is not None:
        _get_ctx().fence()


# Setting the __doc__
create_listener.__doc__ = ApplicationContext.create_listener.__doc__
create_endpoint.__doc__ = ApplicationContext.create_endpoint.__doc__
continuous_ucx_progress.__doc__ = ApplicationContext.continuous_ucx_progress.__doc__
get_ucp_worker.__doc__ = ApplicationContext.get_ucp_worker.__doc__