kaolin.rep.TensorContainerBase
API
- class kaolin.rep.TensorContainerBase
Bases: ABC
Abstract base class for dealing with containers of tensors.
Subclasses describe their attributes by implementing class_tensor_attributes(), class_other_attributes(), and check_tensor_attribute(); in exchange, they get many useful utilities out of the box (listed below). Each tensor attribute may be either a single torch.Tensor or a dict of tensors.
General utility methods
These are inherited and operate on any subclass without further work.
- to() - move tensor attributes to a device or dtype
- detach() - detach all tensor attributes
- to_string(print_stats=True) - easy inspection; also allows print(obj) to work
- as_dict() - save all attributes to a dict, which should be compatible with the constructor
- get_attributes() - return the names of all non-None attributes
- assert_supported() - raise if an attribute name is not supported
- check_sanity() - check all tensors for sanity (using subclass hooks)
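To illustrate how such utilities can be derived purely from the subclass hooks, here is a minimal stdlib-only sketch. MiniContainerBase and Pair are hypothetical names, plain Python lists stand in for tensors, and this is not kaolin's implementation; it only shows the pattern behind get_attributes() and a constructor-compatible as_dict().

```python
from abc import ABC, abstractmethod


class MiniContainerBase(ABC):
    """Hypothetical stand-in for the hook pattern; not kaolin's code."""

    @classmethod
    @abstractmethod
    def class_tensor_attributes(cls):
        ...

    @classmethod
    @abstractmethod
    def class_other_attributes(cls):
        ...

    def get_attributes(self, only_tensors=False):
        # Names declared by the subclass hooks, filtered to non-None values.
        names = list(self.class_tensor_attributes())
        if not only_tensors:
            names += list(self.class_other_attributes())
        return [n for n in names if getattr(self, n, None) is not None]

    def as_dict(self):
        # Keys match constructor arguments, so cls(**obj.as_dict()) round-trips.
        return {n: getattr(self, n) for n in self.get_attributes()}


class Pair(MiniContainerBase):
    @classmethod
    def class_tensor_attributes(cls):
        return ["data", "extras"]

    @classmethod
    def class_other_attributes(cls):
        return ["label"]

    def __init__(self, data, extras=None, label="default"):
        self.data = data
        self.extras = extras
        self.label = label


p = Pair(data=[1.0, 2.0], label="demo")
print(p.get_attributes())                   # ['data', 'label']
print(p.get_attributes(only_tensors=True))  # ['data']
rebuilt = Pair(**p.as_dict())               # round-trip through the constructor
print(rebuilt.label)                        # demo
```

Because the base class only ever asks the hooks for names, a subclass gets these utilities without writing any iteration logic itself.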
Inheriting from TensorContainerBase
Inherit from TensorContainerBase whenever you want to manage any number of tensor attributes (and, optionally, non-tensor attributes) and reuse the utilities above. To inherit, define the abstract methods (see, e.g., PointSamples):

```python
import torch
from kaolin.rep import TensorContainerBase


class MyContainer(TensorContainerBase):
    @classmethod
    def class_tensor_attributes(cls):
        return ["data", "extras"]

    @classmethod
    def class_other_attributes(cls):
        return ["label"]

    def __init__(self, data, extras=None, label="default"):
        self.data = data
        self.extras = extras
        self.label = label

    def check_tensor_attribute(self, attr, log_error=False):
        value = getattr(self, attr)
        if attr == "data":
            return torch.is_tensor(value)
        if attr == "extras":
            return value is None or (
                isinstance(value, dict)
                and all(torch.is_tensor(v) for v in value.values()))
        return False

    # Optional: override to validate non-tensor attributes
    def check_other_attribute(self, attr, log_error=False):
        if attr == "label":
            return isinstance(self.label, str)
        return False


instance = MyContainer(data=torch.randn(4, 3), extras={"a": torch.randn(4, 2)})
print(instance)  # uses to_string()
moved = instance.to(dtype=torch.float64)  # all tensor attributes converted
assert instance.check_sanity()
```
- classmethod assert_supported(attr)
Raises an exception if the class does not support the provided attribute name.
- check_other_attribute(attr, log_error=False)
Checks the validity of a non-tensor attribute; returns True if it is valid.
- check_sanity(log_error=True)
Validates all tensor attributes using the subclass implementation of the abstract check_tensor_attribute(); returns True if every check passes.
- abstract check_tensor_attribute(attr, log_error=False)
Checks the validity of a tensor attribute; returns True if it is valid.
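The page does not spell out exactly how check_sanity() combines the per-attribute hook. One plausible shape is sketched below: run check_tensor_attribute() on every set tensor attribute, pass log_error through, and return True only if all checks pass. SanityMixin, Points, the skip-if-None rule, and the use of the logging module are assumptions for illustration, and plain nested lists stand in for tensors.

```python
import logging

logger = logging.getLogger(__name__)


class SanityMixin:
    """Hypothetical sketch: a sanity check built on a per-attribute hook."""

    def check_sanity(self, log_error=True):
        ok = True
        for attr in self.class_tensor_attributes():
            if getattr(self, attr, None) is None:
                continue  # unset attributes are skipped in this sketch
            if not self.check_tensor_attribute(attr, log_error=log_error):
                ok = False  # keep going so every bad attribute gets checked
        return ok


class Points(SanityMixin):
    @classmethod
    def class_tensor_attributes(cls):
        return ["xyz", "normals"]

    def __init__(self, xyz, normals=None):
        self.xyz = xyz
        self.normals = normals

    def check_tensor_attribute(self, attr, log_error=False):
        value = getattr(self, attr)
        # Stand-in rule: every row must have exactly 3 coordinates.
        valid = all(len(row) == 3 for row in value)
        if not valid and log_error:
            logger.error("attribute %s has rows that are not 3-dimensional", attr)
        return valid


good = Points(xyz=[[0, 0, 0], [1, 2, 3]])
bad = Points(xyz=[[0, 0, 0]], normals=[[1, 0]])
print(good.check_sanity())                # True
print(bad.check_sanity(log_error=False))  # False
```

Continuing after the first failure (rather than returning early) means that, with log_error=True, every invalid attribute gets reported in one call.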
- abstract classmethod class_other_attributes()
Returns attribute names that are not PyTorch tensors or dicts thereof.
- abstract classmethod class_tensor_attributes()
Returns attribute names that are PyTorch tensors or dicts thereof.
- cpu(attributes: Optional[Sequence[str]] = None)
Calls cpu() on all or selected tensor attributes; returns a shallow copy.
- cuda(device: Optional[Union[int, device, str]] = None, attributes: Optional[Sequence[str]] = None)
Calls cuda() on all or selected tensor attributes; returns a shallow copy.
- describe_attribute(attr, print_stats=False, detailed=False)
Outputs an informative string about an attribute; to_string() uses this same method for every attribute.
- detach(attributes: Optional[Sequence[str]] = None)
Detaches all or selected tensor attributes; returns a shallow copy.
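cpu(), cuda(), and detach() share one contract: apply an operation to all or selected tensor attributes (including each value of a dict of tensors) and return a shallow copy, leaving the original object untouched. A minimal stdlib sketch of that contract follows; apply_to_tensors and Box are hypothetical names, and calling round on plain floats is only a stand-in for a real tensor operation such as detach.

```python
import copy


def apply_to_tensors(obj, fn, tensor_attrs, attributes=None):
    """Hypothetical helper: return a shallow copy of obj with fn applied
    to all (or only the selected) tensor attributes."""
    selected = tensor_attrs if attributes is None else attributes
    out = copy.copy(obj)  # shallow copy: unselected attributes are shared
    for name in selected:
        value = getattr(obj, name, None)
        if value is None:
            continue
        if isinstance(value, dict):
            # Dict of tensors: transform each value into a new dict.
            value = {k: fn(v) for k, v in value.items()}
        else:
            value = fn(value)
        setattr(out, name, value)
    return out


class Box:
    def __init__(self, a, b):
        self.a = a
        self.b = b


box = Box(a=1.26, b={"x": 2.71})
rounded = apply_to_tensors(box, round, ["a", "b"], attributes=["a"])
print(rounded.a, rounded.b is box.b)  # 1 True
print(box.a)                          # 1.26 (original unchanged)
```

Returning a shallow copy keeps the call cheap: only the selected attributes are rebuilt, and everything else is shared with the original object.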
- get_attributes(only_tensors=False)
Returns the names of all attributes currently set to a non-None value in this instance.
- Parameters
only_tensors – if True, only tensor attributes are included
- Returns
list of string names
- Return type
list
- to(*args: Any, attributes: Optional[Sequence[str]] = None, **kwargs: Any)
Moves or casts tensors like torch.Tensor.to() / torch.nn.Module.to().
- Parameters
*args – forwarded to tensor.to(*args)
attributes (list of str, optional) – if set, only these tensor attributes are updated
**kwargs – forwarded to tensor.to(**kwargs)
- Returns
shallow copy with converted tensors
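The forwarding behaviour of to() can be illustrated without PyTorch. In the sketch below, FakeTensor records every call to its .to() method, and the hypothetical Holder.to passes *args and **kwargs through verbatim to each selected attribute while keeping attributes as keyword-only. FakeTensor and Holder are assumptions for illustration, not kaolin or PyTorch classes.

```python
import copy


class FakeTensor:
    """Hypothetical stand-in that records the arguments passed to .to()."""

    def __init__(self, calls=()):
        self.calls = tuple(calls)

    def to(self, *args, **kwargs):
        # Like torch.Tensor.to(), return a new object instead of mutating.
        return FakeTensor(self.calls + ((args, kwargs),))


class Holder:
    tensor_attributes = ("pos", "feat")

    def __init__(self, pos, feat):
        self.pos = pos
        self.feat = feat

    def to(self, *args, attributes=None, **kwargs):
        out = copy.copy(self)
        names = self.tensor_attributes if attributes is None else attributes
        for name in names:
            # Forward *args / **kwargs verbatim to each tensor's .to().
            setattr(out, name, getattr(self, name).to(*args, **kwargs))
        return out


h = Holder(FakeTensor(), FakeTensor())
moved = h.to("cuda", non_blocking=True, attributes=["pos"])
print(moved.pos.calls)  # ((('cuda',), {'non_blocking': True}),)
print(h.pos.calls)      # ()
```

Making attributes keyword-only is what lets arbitrary positional arguments (a device, a dtype) pass through untouched.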
- to_string(print_stats=False, detailed=False)
Returns information about tensor attributes currently contained in the object.