# Source code for cerebras.pytorch.backend

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""Directory containing the implementations of the various API backends."""
import inspect
from enum import Enum, auto
from typing import Optional
from warnings import warn

import torch


class BackendType(Enum):
    """
    The enum class used to distinguish which Cerebras backend to use.
    """

    CPU = auto()
    GPU = auto()
    # CSX and WSE are treated as synonyms (see ``is_csx``); they are distinct
    # members with distinct values, not Enum aliases.
    CSX = auto()
    WSE = auto()  # deprecated

    @property
    def is_cpu(self):
        """Returns True if the backend is for the CPU."""
        return self == BackendType.CPU

    @property
    def is_gpu(self):
        """Returns True if the backend is for the GPU."""
        return self == BackendType.GPU

    @property
    def is_csx(self):
        """Returns True if the backend is for the Cerebras wafer-scale cluster.

        Both ``CSX`` and its deprecated synonym ``WSE`` count as CSX.
        """
        return self in (BackendType.CSX, BackendType.WSE)

    @staticmethod
    def from_str(backend_type: str):
        """Parses a backend type from its case-insensitive member name.

        Args:
            backend_type: Name of a BackendType member, e.g. "cpu" or "CSX".

        Returns:
            The matching BackendType member.

        Raises:
            TypeError: If ``backend_type`` is not a string.
            ValueError: If ``backend_type`` does not name a BackendType member.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(backend_type, str):
            raise TypeError(
                f"Expected backend_type to be a string. "
                f"Got: {type(backend_type)}"
            )
        backend_type = backend_type.upper()
        if backend_type not in BackendType.__members__:
            raise ValueError(
                f"Invalid Cerebras PyTorch backend type specified. "
                f"Expected one of {list(BackendType.__members__)}. "
                f"Got {backend_type}. "
            )
        return BackendType[backend_type]


class BackendMeta(type):
    """
    Metaclass that allows each class using it to be instantiated at most once.

    Every successfully constructed object is recorded in ``instance``, keyed by
    its class. Any later attempt to construct the same class raises a
    RuntimeError rather than returning the recorded object.
    """

    # Registry of constructed objects, keyed by the class that was instantiated.
    instance = {}

    def __call__(cls, *args, **kwargs):
        registry = cls.instance

        if cls in registry:
            existing = registry[cls]
            raise RuntimeError(
                f"Cannot instantiate multiple backends. "
                f"A backend with type {existing.backend_type.name} "
                f"has already been instantiated.\n"
                f"Use cstorch.backend() to access the existing backend, "
                f"or use cstorch.backend().torch_device to access the "
                f"current backend's torch device."
            )

        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]


class Backend(metaclass=BackendMeta):
    """Externally facing Cerebras backend class.

    Selects a backend implementation based on ``backend_type``, stores it as
    ``_impl``, and exposes a subset of its attributes through properties.
    ``BackendMeta`` prevents more than one instance from being created.
    """

    # Only if True, initialize the backend implementation
    _init_impl: bool = True

    def __init__(self, backend_type: BackendType, *args, **kwargs):
        """
        Args:
            backend_type: Which backend to create. ``BackendType.WSE`` is a
                deprecated synonym and is normalized to ``BackendType.CSX``.
            args: Positional arguments forwarded to the implementation.
            kwargs: Keyword arguments forwarded to the implementation.

        Raises:
            TypeError: If ``backend_type`` is not a BackendType.
            ValueError: If ``backend_type`` has no implementation.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(backend_type, BackendType):
            raise TypeError(
                f"Expected backend_type to be of type BackendType. "
                f"Got: {type(backend_type)}"
            )

        if backend_type == BackendType.WSE:
            # WSE is a deprecated synonym of CSX (see BackendType.is_csx);
            # normalize it so the implementation only ever sees CSX.
            warn(
                "BackendType.WSE is deprecated. Use BackendType.CSX instead.",
                DeprecationWarning,
            )
            backend_type = BackendType.CSX

        self.backend_type = backend_type

        if not self._init_impl:
            return

        if self.backend_type.is_csx:
            from .ltc_backend import PyTorchLtcBackendImpl

            self._impl = PyTorchLtcBackendImpl(
                self.backend_type, *args, **kwargs
            )

        elif self.backend_type.is_cpu:
            from .cpu_backend import CpuBackendImpl

            self._impl = CpuBackendImpl(self.backend_type, *args, **kwargs)

        elif self.backend_type.is_gpu:
            from .gpu_backend import GpuBackendImpl

            self._impl = GpuBackendImpl(self.backend_type, *args, **kwargs)

        else:
            # Defensive: only reachable if a new BackendType member is added
            # without a matching implementation.
            raise ValueError(
                f"{self.backend_type.name} backend not yet supported. "
                f"Supported backends include: CSX, CPU, GPU"
            )

    @property
    def artifact_dir(self):
        """Returns the artifact directory being used by the backend."""
        return self._impl.artifact_dir

    @artifact_dir.setter
    def artifact_dir(self, value):
        """Sets the artifact directory for the backend."""
        self._impl.artifact_dir = value

    @property
    def device(self):
        """Returns the Cerebras device being used by the backend."""
        return self._impl.device

    @property
    def torch_device(self):
        """Returns the underlying PyTorch device being used by the backend."""
        return self._impl.device.torch_device

    @property
    def is_tracing(self):
        """Returns True if the backend is currently tracing a model."""
        return self._impl.is_tracing

    @property
    def is_e2e_execution(self):
        """Returns the backend implementation's ``is_e2e_execution`` flag.

        NOTE(review): presumably True when running end-to-end execution;
        confirm against the implementation.
        """
        return self._impl.is_e2e_execution

    @property
    def cluster_config(self):
        """Returns the cluster config if the backend is a CSX backend, otherwise None."""
        return getattr(self._impl, "cluster_config", None)

    # alias properties from backend type
    is_cpu = property(lambda self: self.backend_type.is_cpu)
    is_gpu = property(lambda self: self.backend_type.is_gpu)
    is_csx = property(lambda self: self.backend_type.is_csx)


def get_backend_args(backend_type: str):
    """
    Get the arguments for the backend class with the given backend type.

    Args:
        backend_type: The type of backend to get the arguments for.
            Either a BackendType or one of "CSX", "CPU", "GPU"
            (case-insensitive; "WSE" is accepted as a deprecated synonym
            of "CSX").

    Returns:
        The ``parameters`` mapping of ``inspect.signature`` applied to the
        selected implementation's ``__init__`` (this includes ``self``).

    Raises:
        TypeError: If ``backend_type`` is neither a string nor a BackendType.
        ValueError: If ``backend_type`` is not a supported backend.
    """
    if isinstance(backend_type, str):
        backend_type = BackendType.from_str(backend_type)
    elif not isinstance(backend_type, BackendType):
        # Without this check, the fallback error below would crash with an
        # AttributeError on ``.name`` instead of a meaningful message.
        raise TypeError(
            f"Expected backend_type to be of type BackendType, "
            "or a string representing the backend type. "
            f"Got: {type(backend_type)}"
        )

    # Use the BackendType predicates so that WSE resolves to the CSX backend.
    if backend_type.is_csx:
        from .ltc_backend import PyTorchLtcBackendImpl as BackendImpl
    elif backend_type.is_cpu:
        from .cpu_backend import CpuBackendImpl as BackendImpl
    elif backend_type.is_gpu:
        from .gpu_backend import GpuBackendImpl as BackendImpl
    else:
        raise ValueError(
            f"{backend_type.name} backend not yet supported. "
            f"Supported backends include: CSX, CPU, GPU"
        )

    return inspect.signature(BackendImpl.__init__).parameters


# backend() queries the current backend, while backend(str, ...) creates a new one
def backend(backend_type: Optional[str] = None, *args, **kwargs):
    """
    Instantiates a backend with the given type, or returns the current one.

    Args:
        backend_type: The type of backend to instantiate. Either a
            BackendType or one of "CSX", "CPU", "GPU" (case-insensitive).
            If no backend_type is provided, returns the current backend.
        args: Positional arguments to pass to the backend implementation
        kwargs: Keyword arguments to pass to the backend implementation

    Returns:
        The newly created Backend, or the existing Backend when called
        without arguments.

    Raises:
        TypeError: If args/kwargs are given without a backend_type, or if
            backend_type is neither a string nor a BackendType.
        ValueError: If backend_type is a string that does not name a
            backend type.
        RuntimeError: If called without arguments before any backend has
            been instantiated, or if a backend already exists and a new one
            is requested.
    """
    if backend_type is None:
        # if other args are given return a type error
        if len(args) != 0 or len(kwargs) != 0:
            raise TypeError(
                "Expected backend_type when constructing a backend. "
                "Either provide one of \"CSX\", \"CPU\", \"GPU\" to create a new backend, "
                "or call with no arguments to get the instance of the current backend, "
                "if it exists."
            )
        # Raises RuntimeError (raise_exception defaults to True) if no
        # backend has been instantiated yet.
        return current_backend(raise_warning=False)

    if isinstance(backend_type, str):
        backend_type = BackendType.from_str(backend_type)
    elif not isinstance(backend_type, BackendType):
        raise TypeError(
            f"Expected backend_type to be of type BackendType, "
            "or a string representing the backend type. "
            f"Got: {type(backend_type)}"
        )

    return Backend(backend_type, *args, **kwargs)
def current_backend(raise_exception: bool = True, raise_warning: bool = True):
    """DEPRECATED: Use cstorch.backend() instead.

    Gets instance of the current backend.

    Args:
        raise_exception: If True, raise an exception if no backend has been
            instantiated. Otherwise return None.
        raise_warning: If True, emit a DeprecationWarning about this
            function. Internal callers in this module pass False.

    Returns:
        The current Backend instance, or None if no backend has been
        instantiated and ``raise_exception`` is False.

    Raises:
        RuntimeError: If no backend has been instantiated and
            ``raise_exception`` is True.
    """
    if raise_warning:
        warn(
            "cstorch.current_backend() is deprecated and will be removed in a future release. "
            "Use cstorch.backend() instead to access the current backend.",
            DeprecationWarning,
            # Attribute the warning to the code calling current_backend(),
            # which is what needs to be migrated.
            stacklevel=2,
        )

    if Backend not in BackendMeta.instance:
        if raise_exception:
            raise RuntimeError(
                "No active Cerebras backend found. Please make sure that "
                "your model has been prepared for compilation.\n"
                "You can do this using a call to:\n\n"
                "\tcompiled_model = cstorch.compile(model, backend=...)\n\n"
                "Or by explicitly instantiating a backend, e.g.\n\n"
                "\tbackend = cstorch.backend(...)"
            )
        return None

    return BackendMeta.instance[Backend]
def current_torch_device():
    """
    Returns the torch device associated with the active backend.

    When no backend has been initialized yet, ``torch.device("cpu")`` is
    returned instead.
    """
    active = current_backend(raise_exception=False, raise_warning=False)
    if active is not None:
        # NOTE(review): reads ``_impl.torch_device`` whereas
        # ``Backend.torch_device`` reads ``_impl.device.torch_device``;
        # confirm both attributes agree.
        # pylint: disable=protected-access
        return active._impl.torch_device
    return torch.device("cpu")
def current_backend_impl(raise_exception: bool = True):
    """Fetches the implementation object behind the current backend.

    Args:
        raise_exception: When True, a missing backend is an error; when False,
            a missing backend yields None.

    Returns:
        The ``_impl`` of the current backend, or None if there is no backend
        and ``raise_exception`` is False.

    Raises:
        RuntimeError: If there is no backend and `raise_exception` is True.
    """
    active = current_backend(
        raise_exception=raise_exception, raise_warning=False
    )
    # pylint: disable=protected-access
    return None if active is None else active._impl
def use_cs():
    """Returns True if the active device is a CSX device.

    False is returned when no backend has been instantiated.
    """
    active = current_backend(raise_exception=False, raise_warning=False)
    if active is None:
        return False
    return active.is_csx