# Source code for morphocut.torch

from contextlib import ExitStack
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

from morphocut import Node, Output, RawOrVariable, ReturnOutputs
from morphocut._optional import UnavailableObject
from morphocut.batch import Batch


if TYPE_CHECKING:  # pragma: no cover
    import torch
else:
    try:
        import torch
    except ImportError:  # pragma: no cover
        torch = UnavailableObject("torch")


def _stack_pin(tensors: Union[Tuple["torch.Tensor"], List["torch.Tensor"]]):
    """Stack ``tensors`` along a new leading dimension into pinned memory.

    The destination buffer is allocated up front with ``pin_memory=True`` and
    ``torch.stack`` writes directly into it.

    Args:
        tensors: Non-empty tuple or list of tensors. The size and dtype of the
            first element determine the size and dtype of the result.

    Returns:
        A tensor of size ``(len(tensors), *tensors[0].size())`` allocated in
        pinned memory.

    Raises:
        IndexError: If ``tensors`` is empty.
    """
    count = len(tensors)
    template = tensors[0]

    # Allocate the pinned destination once, then let torch.stack fill it.
    pinned = torch.empty(
        (count, *template.size()),
        dtype=template.dtype,
        pin_memory=True,
    )
    torch.stack(tensors, out=pinned)

    return pinned


@ReturnOutputs
@Output("output")
class PyTorch(Node):
    """
    Apply a PyTorch module to the input.

    The module is moved to ``device`` and switched to evaluation mode on
    construction. Inference runs under ``torch.no_grad()`` and the result is
    returned as a NumPy array on the CPU.

    Args:
        module (torch.nn.Module): PyTorch module.
        input (input, Variable): Input.
        device (str or torch.device, optional): Device.
        is_batch (bool, optional): Assume that input is a batch.
            If None, the input is treated as a batch if it is a
            :py:class:`~morphocut.batch.Batch`.
        output_key (optional): If the module has multiple outputs,
            output_key selects one of them.
        pin_memory (bool, optional): Use pinned memory for faster CPU-GPU transfer.
            Only applicable for CUDA devices.
            If None, enabled by default for CUDA devices.
            Only used when assembling a batch.
        pre_transform (callable, optional): Transformation to apply to the
            individual input values.
        autocast (bool, optional): Enable automatic mixed precision inference
            to improve performance. Requires ``device``.

    Raises:
        ValueError: If ``autocast`` is enabled without a ``device``.

    Note:
        NumPy has no bfloat16 dtype, so bfloat16 outputs (e.g. produced by
        autocast on CPU) are converted to float32 before being returned.

    Example:
        .. code-block:: python

            module = ...
            with Pipeline() as pipeline:
                input = ...
                output = PyTorch(module, input)
    """

    def __init__(
        self,
        module: "torch.nn.Module",
        input: RawOrVariable,
        device: Union[None, str, "torch.device"] = None,
        is_batch=None,
        output_key=None,
        pin_memory=None,
        pre_transform: Optional[Callable] = None,
        autocast=False,
    ):
        super().__init__()

        if device is not None:
            device = torch.device(device)

        # Validate before touching the module, so that invalid arguments do not
        # leave the caller's module moved or switched to evaluation mode.
        if autocast and device is None:  # pragma: no cover
            raise ValueError("Supply a device when using autocast.")

        # Default: pin memory only when the target is a CUDA device.
        if pin_memory is None and device is not None:
            pin_memory = device.type == "cuda"

        self.device = device

        module = module.to(device)

        # Enable evaluation mode
        module.eval()

        self.model = module
        self.input = input
        self.is_batch = is_batch
        self.output_key = output_key
        self.pin_memory = pin_memory
        self.pre_transform = pre_transform
        self.autocast = autocast

    def transform(self, input):
        """Run the module on ``input`` and return the result as a NumPy array.

        A batch is transformed element-wise (``pre_transform``), converted to
        tensors and stacked. A single value is transformed, converted and given
        a temporary leading batch dimension, which is removed from the output.
        """
        with ExitStack() as stack:
            stack.enter_context(torch.no_grad())

            if self.autocast:
                stack.enter_context(torch.autocast(self.device.type))  # type: ignore

            is_batch = (
                isinstance(input, Batch) if self.is_batch is None else self.is_batch
            )

            if is_batch:
                # Assemble batch
                if self.pre_transform is not None:
                    input = [self.pre_transform(inp) for inp in input]

                input = [torch.as_tensor(inp) for inp in input]
                input = _stack_pin(input) if self.pin_memory else torch.stack(input)

                # Normalize a truthy user-supplied flag to a real bool.
                is_batch = True
            else:
                if self.pre_transform is not None:
                    input = self.pre_transform(input)

                # Add batch dimension
                input = torch.as_tensor(input).unsqueeze(0)

            if self.device is not None:
                input = input.to(self.device)  # type: ignore

            output = self.model(input)

            if self.output_key is not None:
                output = output[self.output_key]

            # detach(): Tensor.numpy() refuses tensors that require grad, which
            # can still happen under no_grad if the module returns a parameter.
            output = output.detach().cpu()

            if output.dtype == torch.bfloat16:
                # NumPy cannot represent bfloat16; upcast instead of failing.
                output = output.float()

            output = output.numpy()

            if not is_batch:
                # Remove batch dimension
                output = output[0]

            return output