kaolin.rep.PointSamples

API

class kaolin.rep.PointSamples(positions, features=None, transform=None, strict_checks: bool = True)

Bases: TensorContainerBase

Base container for point-based 3D representations built over PyTorch tensors.

Stores positions of shape (N, 3) and optional features. Features can be a single tensor of shape (N, ...) or a dict of tensors of shape (N, ...). Tensors in the dict may have different dimensionalities, provided the first dimension of each is N. The class is written generically, so subclasses that add extra attributes still inherit its point-based utilities.

General utility methods

These are inherited from TensorContainerBase and also work on attributes added by subclasses.

  • to() - move or cast tensor attributes to a device or dtype

  • cuda(), cpu() - move tensor attributes to CUDA/CPU devices

  • detach() - detach all tensor attributes

  • to_string(print_stats=True) - easy inspection; also allows print(obj) to work

  • as_dict() - return all non-None attributes as a dict, compatible with the constructor: PointSamples(**dict_output)

  • get_attributes() - return the names of all non-None attributes

  • assert_supported() - raise an exception if an attribute name is not supported

  • check_sanity() - check all tensor attributes for validity

Point-specific utility methods

These are point-specific utilities, and will be inherited by subclasses.

  • cat() - concatenate instances along the point dimension, including features (override _custom_attr_cat() to customize)

  • as_transformed() - apply an affine transform to positions (override to customize)

  • obj[mask] (get) - return a new instance with per-point attributes selected by a boolean mask

  • obj[mask] = other (set) - assign the masked subset of per-point attributes from an instance of the same class with matching length

  • len(obj) - return the number of points

Inheriting from PointSamples

Inherit from PointSamples when you want to manage any number of attributes associated with points and define custom behavior for some of them. To inherit, define the following methods (see GaussianSplatModel for an example):

from kaolin.rep import PointSamples


class MyAugmentedPoints(PointSamples):
    @classmethod
    def class_tensor_attributes(cls):
        # Append custom tensor attribute names to this list
        return ["positions", "transform", "features"]

    @classmethod
    def class_other_attributes(cls):
        # Append custom non-tensor attribute names to this list
        return []

    @classmethod
    def class_point_attributes(cls):
        # Append custom per-point attributes (a subset of the tensor attributes)
        return ["positions", "features"]

    def check_tensor_attribute_shape(self, attr):
        # Return True if getattr(self, attr) has the expected shape; handle
        # custom attributes here and defer built-in ones to the base class
        return super().check_tensor_attribute_shape(attr)

    # Optional ----------------------------------
    @classmethod
    def _custom_attr_cat(cls, models, attr, skip_errors=False, **kwargs):
        # Define custom concatenation behavior for any attr; return None for default
        return None

    def as_transformed(self, additional_transform=None):
        # Define custom transform behavior; this default defers to the base class
        return super().as_transformed(additional_transform)

__getitem__(mask)

Return a new instance with mask applied to all per-point attributes.

Parameters

mask (torch.Tensor) – Boolean tensor of shape (N,).

Returns

New instance of the same class containing only selected points.

Return type

PointSamples
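The masking semantics can be illustrated in plain PyTorch. This is a minimal sketch, not kaolin's implementation. The helper mask_point_attributes is hypothetical. It assumes per-point attributes are tensors, or dicts of tensors, whose first dimension is N, and it copies non-point attributes (here a global transform) over unchanged.

```python
import torch


def mask_point_attributes(attrs, point_attrs, mask):
    """Hypothetical helper: return a new dict with `mask` applied to per-point entries.

    attrs: attribute name -> tensor, dict of tensors, or None
    point_attrs: names of per-point attributes (first dimension N)
    mask: boolean tensor of shape (N,)
    """
    if mask.dtype != torch.bool or mask.ndim != 1:
        raise TypeError("mask must be a 1-D boolean tensor")
    out = {}
    for name, value in attrs.items():
        if name not in point_attrs or value is None:
            out[name] = value  # non-point attributes are carried over unchanged
        elif isinstance(value, dict):
            out[name] = {k: v[mask] for k, v in value.items()}  # mask every feature
        else:
            out[name] = value[mask]
    return out


positions = torch.arange(15, dtype=torch.float32).reshape(5, 3)
features = {"rgb": torch.rand(5, 3), "sh": torch.rand(5, 16, 3)}
attrs = {"positions": positions, "features": features, "transform": torch.eye(4)}
mask = positions[:, 0] > 5  # x values are 0, 3, 6, 9, 12, so 3 points are kept
subset = mask_point_attributes(attrs, ["positions", "features"], mask)
print(subset["positions"].shape)  # torch.Size([3, 3])
```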

__init__(positions, features=None, transform=None, strict_checks: bool = True)

Initializes the class and optionally validates the inputs.

Parameters
  • positions (torch.Tensor) – Point positions, shape (N, 3).

  • features (torch.Tensor or dict, optional) – Per-point features. Either a single tensor of shape (N, ...) or a dict of tensors whose first dimension is N and whose remaining dimensions may differ, e.g. {"a": (N, F_0), "b": (N, F_1, F_2)}.

  • transform (torch.Tensor, optional) – Global affine transform of shape (1, 4, 4) or (4, 4), or per-point transforms of shape (N, 4, 4).

  • strict_checks (bool) – If True, validate tensor shapes on construction and raise an error if they are invalid.

Raises

ValueError – if strict_checks is True and the inputs are invalid.
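The shapes listed above can be checked as in the following sketch, which uses plain tensors. The helper validate_point_inputs is hypothetical and encodes only the documented shapes. The exact checks kaolin performs under strict_checks may differ.

```python
import torch


def validate_point_inputs(positions, features=None, transform=None):
    """Hypothetical check mirroring the documented constructor shapes; raises ValueError."""
    if positions.ndim != 2 or positions.shape[1] != 3:
        raise ValueError(f"positions must be (N, 3), got {tuple(positions.shape)}")
    n = positions.shape[0]
    # features: one tensor (N, ...) or a dict of tensors, each (N, ...)
    if features is None:
        feats = []
    elif isinstance(features, dict):
        feats = list(features.values())
    else:
        feats = [features]
    for f in feats:
        if f.ndim < 1 or f.shape[0] != n:
            raise ValueError(f"features must have first dimension {n}, got {tuple(f.shape)}")
    # transform: (4, 4), (1, 4, 4) or per-point (N, 4, 4)
    if transform is not None:
        ok = transform.shape == (4, 4) or (
            transform.ndim == 3
            and transform.shape[1:] == (4, 4)
            and transform.shape[0] in (1, n)
        )
        if not ok:
            raise ValueError(f"transform must be (4, 4), (1, 4, 4) or ({n}, 4, 4), "
                             f"got {tuple(transform.shape)}")


pts = torch.zeros(10, 3)
validate_point_inputs(pts, {"a": torch.zeros(10, 8), "b": torch.zeros(10, 2, 5)}, torch.eye(4))
try:
    validate_point_inputs(pts, features=torch.zeros(9, 4))
except ValueError as e:
    print(e)  # features must have first dimension 10, got (9, 4)
```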

__len__()

Returns the number of points.

__setitem__(mask, value)

Assign per-point attributes from value into this instance at indices selected by mask.

Acts as the in-place inverse of __getitem__(): per-point attributes (see class_point_attributes()) of self are updated in-place using the matching attributes of value. Non-point attributes (e.g. transform, sh_degree) are not modified; the caller is responsible for ensuring they are consistent between self and value (in particular both should be in the same coordinate frame – see as_transformed() to canonicalize beforehand).

Parameters
  • mask (torch.Tensor) – Boolean tensor of shape (N,).

  • value – Instance of the same class as self with len(value) == mask.sum().

Raises
  • TypeError – If mask is not a boolean tensor or value is not an instance of the same class as self.

  • ValueError – If mask shape, value length, or per-point attribute structures are inconsistent.
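The in-place assignment can be illustrated the same way with plain tensors. The helper assign_masked is hypothetical and is not kaolin's code. For brevity it uses flat dicts of per-point tensors rather than class instances with nested feature dicts. It checks the mask type, the mask shape, and row counts, then writes value's tensors into the masked rows of the target. It does not touch anything else.

```python
import torch


def assign_masked(target, value, mask):
    """Hypothetical in-place inverse of masking for dicts of per-point tensors.

    target: per-point attribute name -> tensor of shape (N, ...)
    value: same keys, tensors of shape (M, ...) with M == mask.sum()
    mask: boolean tensor of shape (N,)
    """
    if not isinstance(mask, torch.Tensor) or mask.dtype != torch.bool:
        raise TypeError("mask must be a boolean tensor")
    n = next(iter(target.values())).shape[0]
    if mask.shape != (n,):
        raise ValueError(f"mask must have shape ({n},)")
    if set(target) != set(value):
        raise ValueError("per-point attributes differ between target and value")
    m = int(mask.sum())
    for name, dst in target.items():
        src = value[name]
        if src.shape[0] != m or src.shape[1:] != dst.shape[1:]:
            raise ValueError(f"{name}: expected {m} rows with trailing shape {tuple(dst.shape[1:])}")
        dst[mask] = src.to(dtype=dst.dtype, device=dst.device)  # writes into target's tensor


target = {"positions": torch.zeros(4, 3), "opacity": torch.zeros(4, 1)}
value = {"positions": torch.ones(2, 3), "opacity": torch.full((2, 1), 0.5)}
assign_masked(target, value, torch.tensor([True, False, True, False]))
print(target["opacity"].squeeze(1).tolist())  # [0.5, 0.0, 0.5, 0.0]
```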

as_dict(only_tensors=False) → Dict[str, Any]

Return all non-None attributes as a {name: value} dict.

as_transformed(additional_transform=None)

Applies the stored transform (if set), additional_transform (if given), or the chained additional_transform @ self.transform when both are set. Returns a new instance of the class with the transform applied. Accepts any affine transform.

Parameters

additional_transform (torch.Tensor, optional) – if not set, the transform stored on the instance is used; should be an affine transform of shape (4, 4), (1, 4, 4) or (N, 4, 4), where N is the number of points.

Note

This method does not copy all attributes. Only transformable attributes are replaced by new tensors, and the rest are shared with the original. If a full copy is required, call copy.deepcopy() first. Only isotropic scale, rotation and translation can be applied consistently.

Returns

new instance of PointSamples with the transform applied.

Return type

PointSamples
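For positions, applying the transform amounts to multiplying homogeneous coordinates by a 4×4 matrix. The sketch below assumes the column-vector convention p' = A[:3, :3] p + A[:3, 3] and positions stored as (N, 3) rows. Both are assumptions, so confirm kaolin's convention before relying on them. The usage lines illustrate the chaining rule described above: when both transforms are set, the stored transform acts first under this convention, because the product is additional_transform @ transform. The function name transform_positions is hypothetical.

```python
import torch


def transform_positions(positions, transform):
    """Hypothetical sketch: apply a (4, 4), (1, 4, 4) or (N, 4, 4) affine to (N, 3) points.

    Assumes the column-vector convention: p' = A[:3, :3] @ p + A[:3, 3].
    """
    if transform.ndim == 2:
        transform = transform.unsqueeze(0)  # (1, 4, 4), broadcasts over all points
    linear = transform[:, :3, :3]  # (1 or N, 3, 3)
    translation = transform[:, :3, 3]  # (1 or N, 3)
    return (linear @ positions.unsqueeze(-1)).squeeze(-1) + translation


stored = torch.eye(4)
stored[:3, 3] = torch.tensor([1.0, 0.0, 0.0])  # translate by +1 along x
additional = torch.eye(4)
additional[:3, :3] *= 2.0  # isotropic scale by 2
points = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
# Both set: chain additional @ stored, i.e. translate first, then scale
out = transform_positions(points, additional @ stored)
print(out.tolist())  # [[2.0, 0.0, 0.0], [4.0, 2.0, 2.0]]
```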

classmethod assert_supported(attr)

Raises an exception if the class does not support the provided attribute name.

classmethod cat(models, skip_errors=False, **kwargs)

Concatenates a list of instances along the point dimension.

Any stored transform on each model is applied before concatenation; the result always has transform=None.

Parameters
  • models (list) – Non-empty list of instances of this class.

  • skip_errors (bool) – If True, log and skip mismatched attributes instead of raising.

Returns

New instance with all point attributes concatenated.
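Concatenation along the point dimension comes down to torch.cat on each per-point attribute, key by key for dict features. The sketch below uses a hypothetical helper, cat_point_dicts, and is not kaolin's code. It assumes each input's transform has already been applied to its positions, which matches the documented transform=None result. With skip_errors=False it raises on a mismatched attribute. With skip_errors=True it logs and drops that attribute.

```python
import logging

import torch

logger = logging.getLogger(__name__)


def cat_point_dicts(items, skip_errors=False):
    """Hypothetical sketch: concatenate dicts of per-point attributes along dim 0.

    Each item maps name -> tensor (N_i, ...) or dict of such tensors. Transforms are
    assumed to be applied already, so the output transform is None.
    """
    if not items:
        raise ValueError("need at least one item")
    out = {}
    for name in items[0]:
        values = [item.get(name) for item in items]
        try:
            if isinstance(values[0], dict):
                out[name] = {k: torch.cat([v[k] for v in values]) for k in values[0]}
            else:
                out[name] = torch.cat(values)
        except (KeyError, TypeError, RuntimeError) as err:
            if not skip_errors:
                raise ValueError(f"cannot concatenate attribute '{name}': {err}") from err
            logger.error("skipping attribute %s: %s", name, err)
    out["transform"] = None
    return out


a = {"positions": torch.zeros(2, 3), "features": {"rgb": torch.zeros(2, 3)}}
b = {"positions": torch.ones(3, 3), "features": {"rgb": torch.ones(3, 3)}}
merged = cat_point_dicts([a, b])
print(merged["positions"].shape)  # torch.Size([5, 3])
```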

check_other_attribute(attr, log_error=False)

Checks the validity of a non-tensor attribute; returns True if valid.

check_sanity(log_error=True)

Validates that all tensor attributes are correct. Subclasses must implement the abstract checking methods this relies on.

Parameters

log_error (bool) – If True, logs each failed check via logger.error.

Returns

True if all checks pass, False otherwise.

Return type

bool
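The aggregate behavior described here (run every check, log each failure, return one bool) can be sketched as follows. The helper check_all and its shape_checks mapping are hypothetical stand-ins for the per-attribute hooks. The checks kaolin actually runs may differ and may cover more than shapes.

```python
import logging

import torch

logger = logging.getLogger(__name__)


def check_all(attrs, shape_checks, log_error=True):
    """Hypothetical sketch: run every per-attribute check, log failures, return a bool.

    attrs: name -> tensor or None; shape_checks: name -> callable(tensor) -> bool
    """
    all_ok = True
    for name, check in shape_checks.items():
        value = attrs.get(name)
        if value is None:
            continue  # unset attributes are not checked in this sketch
        if not check(value):
            all_ok = False  # keep going so that every failure is reported
            if log_error:
                logger.error("attribute %s failed check, shape %s", name, tuple(value.shape))
    return all_ok


n = 6
checks = {
    "positions": lambda t: t.shape == (n, 3),
    "features": lambda t: t.shape[0] == n,
}
print(check_all({"positions": torch.zeros(n, 3), "features": torch.zeros(n, 8)}, checks))  # True
print(check_all({"positions": torch.zeros(n, 2), "features": None}, checks, log_error=False))  # False
```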

check_tensor_attribute(attr, log_error=False)

Checks tensor attribute validity; returns True if valid.

check_tensor_attribute_shape(attr)

Checks that the tensor stored in attr has an expected shape.

Per-attribute shape validation hook used by _check_tensor_attribute(). Override in subclasses to add support for custom tensor attributes.

Parameters

attr (str) – attribute name (must be in class_tensor_attributes()).

Returns

True if the tensor at attr has the expected shape, False otherwise.

Return type

bool

Raises

ValueError – If attr is not a tensor attribute supported by this class.

classmethod class_other_attributes()

Class attribute names that are not PyTorch tensors.

classmethod class_point_attributes()

Subset of the class tensor attributes that hold per-point values, i.e. tensors of shape (N, ...), where N is the number of points.

classmethod class_tensor_attributes()

Class attribute names that are PyTorch tensors.

cpu(attributes: Optional[Sequence[str]] = None)

Calls cpu on all or selected tensor attributes; returns a shallow copy.

cuda(device: Optional[Union[int, device, str]] = None, attributes: Optional[Sequence[str]] = None)

Calls cuda on all or selected tensor attributes; returns a shallow copy.

describe_attribute(attr, print_stats=False, detailed=False)

Returns a descriptive string for a single attribute; to_string() uses the same method for every attribute.

Parameters
  • attr (str) – attribute name

  • print_stats (bool) – whether to print statistics about values in each tensor

  • detailed (bool) – whether to include additional information about each tensor

Returns

multi-line string with attribute information

Return type

(str)

Raises

ValueError – if attr is not supported

detach(attributes: Optional[Sequence[str]] = None)

Detaches all or selected tensor attributes; returns a shallow copy.

get_attributes(only_tensors=False)

Returns the names of all attributes currently set to a non-None value on this instance.

Parameters

only_tensors (bool) – if True, only tensor attributes are included

Returns

list of string names

Return type

(list)

to(*args: Any, attributes: Optional[Sequence[str]] = None, **kwargs: Any)

Moves or casts tensors like torch.Tensor.to() / torch.nn.Module.to().

Parameters
  • *args – forwarded to tensor.to(*args)

  • attributes (list of str, optional) – if set, only these tensor attributes are updated

  • **kwargs – forwarded to tensor.to(**kwargs)

Returns

shallow copy with converted tensors

Return type

PointSamples
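The phrase "shallow copy" means the returned object is new, but any attribute that was not converted still refers to the same tensor as the original. The sketch below shows this behavior with copy.copy on a plain object. The class Points and the helper to_selected are hypothetical and are not kaolin's implementation.

```python
import copy

import torch


class Points:
    """Hypothetical minimal container with a tensor and a dict-of-tensors attribute."""

    def __init__(self, positions, features):
        self.positions = positions
        self.features = features


def to_selected(obj, *args, attributes=None, **kwargs):
    """Hypothetical sketch: shallow-copy obj, then call .to() on selected tensor attributes."""
    out = copy.copy(obj)  # new object; attribute references are shared with obj
    names = attributes if attributes is not None else list(vars(obj))
    for name in names:
        value = getattr(out, name)
        if isinstance(value, torch.Tensor):
            setattr(out, name, value.to(*args, **kwargs))
        elif isinstance(value, dict):
            setattr(out, name, {k: v.to(*args, **kwargs) for k, v in value.items()})
    return out


pts = Points(torch.zeros(4, 3), {"rgb": torch.zeros(4, 3)})
half = to_selected(pts, torch.float16, attributes=["features"])
print(half.features["rgb"].dtype, half.positions.dtype)  # torch.float16 torch.float32
print(half.positions is pts.positions)  # True: unconverted tensors are shared
```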

to_string(print_stats=False, detailed=False)

Returns information about tensor attributes currently contained in the object.

Parameters
  • print_stats (bool) – whether to print statistics about values in each tensor

  • detailed (bool) – whether to include additional information about each tensor

Returns

multi-line string with attribute information

Return type

(str)