kaolin.rep.TensorContainerBase

API

class kaolin.rep.TensorContainerBase

Bases: ABC

Abstract base class for dealing with containers of tensors.

Subclasses describe their attributes by implementing class_tensor_attributes(), class_other_attributes(), and check_tensor_attribute(); in exchange, they inherit the utilities listed below. Each tensor attribute may hold either a single torch.Tensor or a dict of tensors.
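
A tensor attribute can be either a single tensor or a dict of tensors, so any per-tensor operation has to handle both cases. The following minimal sketch shows that dispatch. It uses a hypothetical helper, map_tensor_value (not a kaolin function), and plain Python numbers in place of torch.Tensor so it runs without PyTorch.

```python
from typing import Any, Callable


def map_tensor_value(value: Any, fn: Callable[[Any], Any]) -> Any:
    """Hypothetical helper: apply fn to the value of a tensor attribute.

    The value may be None (attribute unset), a single tensor, or a dict of tensors.
    """
    if value is None:
        return None  # unset attributes stay unset
    if isinstance(value, dict):
        # dict of tensors: build a new dict with fn applied to every entry
        return {key: fn(item) for key, item in value.items()}
    return fn(value)  # single tensor


# Plain floats stand in for tensors here
single = map_tensor_value(3.0, lambda t: t * 2)
grouped = map_tensor_value({"a": 1.0, "b": 2.0}, lambda t: t * 2)
print(single, grouped)  # 6.0 {'a': 2.0, 'b': 4.0}
```

With real tensors, fn would be something like lambda t: t.to("cpu").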

General utility methods

These are inherited and operate on any subclass without further work.

  • to() - move tensor attributes to device or dtype

  • cuda(), cpu() - move tensor attributes to cuda/CPU devices

  • detach() - detach all tensor attributes

  • to_string(print_stats=False, detailed=False) - easy inspection; also allows print(obj) to work

  • as_dict() - returns all non-None attributes as a dict, intended to be compatible with the constructor

  • get_attributes() - return all non-None attribute names

  • assert_supported() - raises an exception if an attribute name is not supported

  • check_sanity() - checks all tensor attributes for sanity (using the subclass hooks)

Inheriting from TensorContainerBase

Inherit from TensorContainerBase whenever you want to manage any number of tensor attributes (and, optionally, non-tensor attributes) and reuse the utilities above. To inherit, implement the abstract methods, as in the following example (see PointSamples for a complete subclass):

import torch
from kaolin.rep import TensorContainerBase

class MyContainer(TensorContainerBase):
    @classmethod
    def class_tensor_attributes(cls):
        return ["data", "extras"]

    @classmethod
    def class_other_attributes(cls):
        return ["label"]

    def __init__(self, data, extras=None, label="default"):
        self.data = data
        self.extras = extras
        self.label = label

    def check_tensor_attribute(self, attr, log_error=False):
        value = getattr(self, attr)
        if attr == "data":
            return torch.is_tensor(value)
        if attr == "extras":
            return value is None or (
                isinstance(value, dict) and all(torch.is_tensor(v) for v in value.values()))
        return False

    # Optional: override to validate non-tensor attributes
    def check_other_attribute(self, attr, log_error=False):
        if attr == "label":
            return isinstance(self.label, str)
        return False

instance = MyContainer(data=torch.randn(4, 3), extras={"a": torch.randn(4, 2)})
print(instance)                          # uses to_string()
moved = instance.to(dtype=torch.float64)  # shallow copy with all tensor attributes cast; instance is unchanged
assert instance.check_sanity()

as_dict(only_tensors=False) → Dict[str, Any]

Returns all non-None attributes as a {name: value} dict; if only_tensors is True, only tensor attributes are included.
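
The following sketch illustrates the documented behavior of get_attributes() and as_dict(). It also shows why a dict of non-None attributes can be passed back to a constructor. ToyContainer is hypothetical and re-implements both methods in plain Python, with lists standing in for tensors. kaolin's actual implementation may differ.

```python
class ToyContainer:
    """Hypothetical container; lists stand in for tensors."""

    @classmethod
    def class_tensor_attributes(cls):
        return ["data", "extras"]

    @classmethod
    def class_other_attributes(cls):
        return ["label"]

    def __init__(self, data, extras=None, label="default"):
        self.data = data
        self.extras = extras
        self.label = label

    def get_attributes(self, only_tensors=False):
        # names of attributes whose current value is not None
        names = list(self.class_tensor_attributes())
        if not only_tensors:
            names += list(self.class_other_attributes())
        return [name for name in names if getattr(self, name) is not None]

    def as_dict(self, only_tensors=False):
        return {name: getattr(self, name) for name in self.get_attributes(only_tensors)}


obj = ToyContainer(data=[1.0, 2.0], label="scan")
print(obj.get_attributes())                   # ['data', 'label']
print(obj.get_attributes(only_tensors=True))  # ['data']
state = obj.as_dict()                         # {'data': [1.0, 2.0], 'label': 'scan'}
clone = ToyContainer(**state)                 # keys match the constructor's argument names
```

The round trip works only because each constructor argument is named after the attribute it sets. Any attribute omitted for being None falls back to the constructor's default.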

classmethod assert_supported(attr)

Raises an exception if the class does not support the provided attribute name.
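
The following is a sketch of the check that assert_supported() performs. It makes two assumptions. First, an attribute counts as "supported" if either class_tensor_attributes() or class_other_attributes() lists it. Second, the exception is a ValueError; the text above only promises "an exception". ToyContainer is hypothetical.

```python
class ToyContainer:
    """Hypothetical container used only to illustrate the check."""

    @classmethod
    def class_tensor_attributes(cls):
        return ["data", "extras"]

    @classmethod
    def class_other_attributes(cls):
        return ["label"]

    @classmethod
    def assert_supported(cls, attr):
        supported = list(cls.class_tensor_attributes()) + list(cls.class_other_attributes())
        if attr not in supported:
            # assumption: ValueError; kaolin only documents "an exception"
            raise ValueError(f"{cls.__name__} does not support attribute {attr!r}; supported: {supported}")


ToyContainer.assert_supported("label")  # supported: returns None silently

try:
    ToyContainer.assert_supported("normals")
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)
```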

check_other_attribute(attr, log_error=False)

Checks the validity of a non-tensor attribute; returns True if it is valid.

check_sanity(log_error=True)

Validates that all tensor attributes are correct, using the check_tensor_attribute() hook that subclasses must implement.

Parameters

log_error (bool) – If True, logs each failed check via logger.error.

Returns

True if all checks pass, False otherwise.

Return type

bool
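
The following sketch shows how check_sanity() could combine the per-attribute hooks. Every check runs, so with log_error=True each failure is logged. The result is True only if all checks pass. ToyContainer is hypothetical. Skipping None attributes and also calling check_other_attribute() are assumptions for illustration, not documented kaolin behavior.

```python
import logging

logger = logging.getLogger(__name__)


class ToyContainer:
    """Hypothetical container; lists stand in for tensors."""

    @classmethod
    def class_tensor_attributes(cls):
        return ["data"]

    @classmethod
    def class_other_attributes(cls):
        return ["label"]

    def __init__(self, data=None, label=None):
        self.data = data
        self.label = label

    def check_tensor_attribute(self, attr, log_error=False):
        value = getattr(self, attr)
        ok = isinstance(value, list) and len(value) > 0  # stand-in for a shape/dtype check
        if not ok and log_error:
            logger.error("Tensor attribute %s failed its check: %r", attr, value)
        return ok

    def check_other_attribute(self, attr, log_error=False):
        value = getattr(self, attr)
        ok = isinstance(value, str)
        if not ok and log_error:
            logger.error("Attribute %s should be a str, got %r", attr, value)
        return ok

    def check_sanity(self, log_error=True):
        all_ok = True
        for attr in self.class_tensor_attributes():
            if getattr(self, attr) is not None:
                # call the hook first so an earlier failure never skips this check
                all_ok = self.check_tensor_attribute(attr, log_error) and all_ok
        for attr in self.class_other_attributes():
            if getattr(self, attr) is not None:
                all_ok = self.check_other_attribute(attr, log_error) and all_ok
        return all_ok


good = ToyContainer(data=[1.0, 2.0], label="scan")
bad = ToyContainer(data=[], label=3)
print(good.check_sanity(), bad.check_sanity())  # True False (bad also logs two errors)
```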

abstract check_tensor_attribute(attr, log_error=False)

Checks the validity of a tensor attribute; returns True if it is valid.

abstract classmethod class_other_attributes()

Returns attribute names that are not PyTorch tensors or dicts thereof.

abstract classmethod class_tensor_attributes()

Returns attribute names that are PyTorch tensors or dicts thereof.

cpu(attributes: Optional[Sequence[str]] = None)

Calls cpu on all or selected tensor attributes; returns a shallow copy.

cuda(device: Optional[Union[int, device, str]] = None, attributes: Optional[Sequence[str]] = None)

Calls cuda on all or selected tensor attributes; returns a shallow copy.

describe_attribute(attr, print_stats=False, detailed=False)

Returns an informative string about an attribute; to_string() uses this same method for every attribute.

Parameters
  • attr (str) – attribute name

  • print_stats (bool) – whether to print statistics about the values in each tensor

  • detailed (bool) – whether to include additional information about each tensor

Returns

multi-line string with attribute information

Return type

(str)

Raises

ValueError if attr is not supported

detach(attributes: Optional[Sequence[str]] = None)

Detaches all or selected tensor attributes; returns a shallow copy.

get_attributes(only_tensors=False)

Returns the names of all attributes currently set to a non-None value in this instance.

Parameters

only_tensors (bool) – if True, only tensor attributes are included

Returns

list of string names

Return type

(list)

to(*args: Any, attributes: Optional[Sequence[str]] = None, **kwargs: Any)

Moves or casts tensors like torch.Tensor.to() / torch.nn.Module.to().

Parameters
  • *args – forwarded to tensor.to(*args)

  • attributes (list of str, optional) – if set, only these tensor attributes are updated

  • **kwargs – forwarded to tensor.to(**kwargs)

Returns

shallow copy with converted tensors

Return type

TensorContainerBase (the same subclass as the original object)
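
The following sketch shows how to() can honor the attributes argument and forward *args and **kwargs to each tensor's own .to(). Everything here is hypothetical. FakeTensor stands in for torch.Tensor, whose real .to() accepts a device, a dtype, or both. ToyContainer.to is a plain-Python re-implementation, not kaolin's code.

```python
import copy


class FakeTensor:
    """Stand-in for torch.Tensor that records where .to() sent it."""

    def __init__(self, where="cpu"):
        self.where = where

    def to(self, *args, **kwargs):
        # the target may arrive positionally or as a device=/dtype= keyword
        target = args[0] if args else kwargs.get("device", kwargs.get("dtype"))
        return FakeTensor(target)


class ToyContainer:
    @classmethod
    def class_tensor_attributes(cls):
        return ["positions", "colors"]

    def __init__(self, positions=None, colors=None):
        self.positions, self.colors = positions, colors

    def to(self, *args, attributes=None, **kwargs):
        if attributes is None:  # default: every non-None tensor attribute
            attributes = [a for a in self.class_tensor_attributes() if getattr(self, a) is not None]
        result = copy.copy(self)  # shallow copy; unselected attributes stay shared
        for name in attributes:
            value = getattr(self, name)
            if isinstance(value, dict):
                setattr(result, name, {k: v.to(*args, **kwargs) for k, v in value.items()})
            elif value is not None:
                setattr(result, name, value.to(*args, **kwargs))
        return result


pts = ToyContainer(positions=FakeTensor(), colors={"rgb": FakeTensor()})
moved = pts.to("cuda", attributes=["positions"])
print(moved.positions.where, moved.colors["rgb"].where)  # cuda cpu
all_moved = pts.to(device="cuda")
print(all_moved.colors["rgb"].where)  # cuda
```

Attributes not named in attributes are shared by reference with the original object. This is what "shallow copy" means here: the returned container is new, but untouched tensors are not copied.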

to_string(print_stats=False, detailed=False)

Returns information about tensor attributes currently contained in the object.

Parameters
  • print_stats (bool) – whether to print statistics about the values in each tensor

  • detailed (bool) – whether to include additional information about each tensor

Returns

multi-line string with attribute information

Return type

(str)
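
The following sketch shows how to_string() can be assembled from describe_attribute(). It also shows how defining __str__ in terms of to_string() makes print(obj) work. The output format, the ToyContainer class, and the use of lists in place of tensors are all assumptions for illustration. kaolin's real output looks different.

```python
import statistics


class ToyContainer:
    """Hypothetical container; lists stand in for tensors."""

    @classmethod
    def class_tensor_attributes(cls):
        return ["data"]

    @classmethod
    def class_other_attributes(cls):
        return ["label"]

    def __init__(self, data=None, label=None):
        self.data, self.label = data, label

    def describe_attribute(self, attr, print_stats=False, detailed=False):
        if attr not in self.class_tensor_attributes() + self.class_other_attributes():
            raise ValueError(f"Unsupported attribute {attr!r}")
        value = getattr(self, attr)
        line = f"{attr}: {type(value).__name__}"
        if detailed and isinstance(value, list):
            line += f" (len={len(value)})"
        if print_stats and isinstance(value, list) and value:
            line += f"\n    min={min(value)}, max={max(value)}, mean={statistics.mean(value)}"
        return line

    def to_string(self, print_stats=False, detailed=False):
        lines = [type(self).__name__]
        for attr in self.class_tensor_attributes() + self.class_other_attributes():
            if getattr(self, attr) is not None:  # describe only attributes that are set
                lines.append(self.describe_attribute(attr, print_stats, detailed))
        return "\n".join(lines)

    def __str__(self):
        return self.to_string()  # this is what makes print(obj) work


obj = ToyContainer(data=[1.0, 2.0, 3.0], label="scan")
print(obj)
print(obj.to_string(print_stats=True, detailed=True))
```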