kaolin.rep.GaussianSplatModel

API

class kaolin.rep.GaussianSplatModel(positions, orientations, scales, opacities, sh_coeff, features=None, transform=None, sh_degree=None, strict_checks: bool = True)

Bases: PointSamples

Container for a 3D Gaussian Splat cloud of N splats. Extends PointSamples with Gaussian-specific attributes, inheriting generalized tensor and point-level utilities, summarized below.

Supported Attributes:

GaussianSplatModel supports the following attributes.

Attribute      Shape                                    Description

positions      (N, 3)                                   Splat centres
orientations   (N, 4)                                   Unit quaternions (w, x, y, z)
scales         (N, 3)                                   Per-axis scale, post-activation
opacities      (N,)                                     Opacity per splat, post-activation
sh_coeff       (N, S, 3)                                SH coefficients; S = (sh_degree + 1)^2
features       (N, ...) or dict                         Per-point features (optional)
transform      (4, 4), (1, 4, 4), (N, 4, 4), or None    Affine transform (optional); stored, not applied (see as_transformed())
sh_degree      int                                      SH degree L; inferred from sh_coeff if omitted
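The shape rules in the attribute table can be expressed as a small checker. This is a hypothetical, standard-library-only sketch: expected_shapes and shape_errors are not kaolin functions, and they work on plain shape tuples rather than tensors. The library performs its own validation through check_tensor_attribute_shape() and check_sanity().

```python
# Hypothetical helpers (not part of kaolin): expected shape of each required
# attribute for N splats at SH degree L, following the attribute table.
def expected_shapes(num_splats, sh_degree):
    num_sh = (sh_degree + 1) ** 2  # S = (L + 1)^2 coefficients per color channel
    return {
        "positions": (num_splats, 3),
        "orientations": (num_splats, 4),
        "scales": (num_splats, 3),
        "opacities": (num_splats,),
        "sh_coeff": (num_splats, num_sh, 3),
    }


def shape_errors(shapes, sh_degree):
    """Return readable mismatches between the given shapes and the table."""
    n = shapes["positions"][0]  # N is taken from the splat centres
    expected = expected_shapes(n, sh_degree)
    return [
        f"{name}: expected {exp}, got {shapes.get(name)}"
        for name, exp in expected.items()
        if shapes.get(name) != exp
    ]


# Example: 1000 splats at SH degree 3, i.e. 16 coefficients per channel.
shapes = {
    "positions": (1000, 3),
    "orientations": (1000, 4),
    "scales": (1000, 3),
    "opacities": (1000,),
    "sh_coeff": (1000, 16, 3),
}
print(shape_errors(shapes, sh_degree=3))  # []
```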

General utility methods

These are inherited from PointSamples but work on all attributes, including features, which may be a tensor or a dict of tensors.

Gaussian-specific utility methods

Gaussian-specific utility methods that work on GaussianSplatModel instances.

  • as_transformed() - apply a global affine transform, or per-Gaussian transforms, to all Gaussian attributes

    • Note: ⚠️ Works reliably only for isotropic scaling, rotation, and translation; unexpected results may occur for other transforms, such as shear or anisotropic scaling.

  • [mask] (get) - return a new instance with all per-point attributes masked by a boolean mask

  • [mask] (set) - set the masked per-point attributes in place from a GaussianSplatModel instance of matching length (len(value) == mask.sum())

    • E.g. scene[object_mask] = scene[object_mask].as_transformed(object_transform)

Gaussian-specific utility methods (class methods).

Detailed API Docs

__getitem__(mask)

Return a new instance with mask applied to all per-point attributes.

Parameters

mask (torch.Tensor) – Boolean tensor of shape (N,).

Returns

New instance of the same class containing only selected points.

Return type

PointSamples

__init__(positions, orientations, scales, opacities, sh_coeff, features=None, transform=None, sh_degree=None, strict_checks: bool = True)

Initializes the instance and optionally validates the inputs.

👉 Note: all attributes are stored post-activation, i.e. in their final range. Override the class to customize this behavior.
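If raw parameters come from a training pipeline that, like the reference 3D Gaussian Splatting code, keeps opacities as logits and scales in log space, they must be activated before being passed to the constructor. A minimal standard-library sketch under that assumption; activate_opacity and activate_scale are hypothetical helpers, and your pipeline's activations may differ.

```python
import math


# Assumption: opacities are stored as logits and scales as log-scales upstream.
def activate_opacity(logit):
    """Sigmoid: maps a logit to an opacity in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def activate_scale(log_scale_xyz):
    """Exponential: maps per-axis log-scales to positive scales."""
    return [math.exp(v) for v in log_scale_xyz]


raw_opacities = [0.0, 2.0, -2.0]
raw_scales = [[0.0, math.log(0.5), math.log(2.0)]]

opacities = [activate_opacity(o) for o in raw_opacities]
scales = [activate_scale(s) for s in raw_scales]
print([round(o, 3) for o in opacities])  # [0.5, 0.881, 0.119]
```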

Parameters
  • positions (torch.Tensor) – Splat centres, shape (N, 3).

  • orientations (torch.Tensor) – Unit quaternions, shape (N, 4), in wxyz convention (normalized internally).

  • scales (torch.Tensor) – Per-axis scale, shape (N, 3).

  • opacities (torch.Tensor) – Opacity per splat, shape (N,).

  • sh_coeff (torch.Tensor) – SH coefficients (N, S, 3) where S = (sh_degree + 1) ** 2.

  • features (torch.Tensor or dict, optional) – Arbitrary per-point features.

  • transform (torch.Tensor, optional) – Global affine transform of shapes (1, 4, 4) or (4, 4), or per-point transforms (N, 4, 4).

  • sh_degree (int, optional) – SH degree. If None, inferred from sh_coeff.shape[1].

  • strict_checks (bool) – If True, validates tensor shapes on construction and raises an error if they are invalid.

Raises

ValueError – if strict_checks is True and the inputs are invalid.
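The orientations are normalized internally, and together with scales they describe each splat's ellipsoid. The standard-library sketch below shows normalization of a wxyz quaternion, its rotation matrix, and the covariance Sigma = R diag(scale^2) R^T used in the common 3D Gaussian Splatting parameterization. All function names are hypothetical, and the covariance formula is an assumption about how the attributes are typically interpreted, not something this class computes.

```python
import math


def normalize_quaternion(q, eps=1e-12):
    """Scale a (w, x, y, z) quaternion to unit length."""
    norm = math.sqrt(sum(c * c for c in q))
    if norm < eps:
        raise ValueError("cannot normalize a zero-length quaternion")
    return tuple(c / norm for c in q)


def quaternion_to_matrix(q):
    """Row-major 3x3 rotation matrix of a unit (w, x, y, z) quaternion."""
    w, x, y, z = q
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


def covariance(q, scale):
    """Sigma = R diag(scale^2) R^T for a single splat."""
    r = quaternion_to_matrix(normalize_quaternion(q))
    return [
        [sum(r[i][k] * scale[k] ** 2 * r[j][k] for k in range(3)) for j in range(3)]
        for i in range(3)
    ]


# A non-unit quaternion for a 90-degree rotation about z; normalization fixes it.
sigma = covariance((1.0, 0.0, 0.0, 1.0), (1.0, 2.0, 3.0))
# sigma is approximately diag(4, 1, 9): the x and y extents swap.
```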

__len__()

Returns the number of points.

__setitem__(mask, value)

Assign per-point attributes from value into this instance at indices selected by mask.

Acts as the in-place inverse of __getitem__(): per-point attributes (see class_point_attributes()) of self are updated in-place using the matching attributes of value. Non-point attributes (e.g. transform, sh_degree) are not modified; the caller is responsible for ensuring they are consistent between self and value (in particular both should be in the same coordinate frame – see as_transformed() to canonicalize beforehand).

Parameters
  • mask (torch.Tensor) – Boolean tensor of shape (N,).

  • value – Instance of the same class as self with len(value) == mask.sum().

Raises
  • TypeError – If mask is not a boolean tensor or value is not an instance of the same class as self.

  • ValueError – If mask shape, value length, or per-point attribute structures are inconsistent.
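The get/set pair supports the pattern scene[object_mask] = scene[object_mask].as_transformed(object_transform). A standard-library sketch of those semantics, assuming a model is just a dict mapping per-point attribute names to lists of length N; masked_get and masked_set are hypothetical, and a simple translation stands in for as_transformed().

```python
def masked_get(model, mask):
    """New dict holding only the entries where mask is True (like __getitem__)."""
    return {name: [v for v, m in zip(vals, mask) if m] for name, vals in model.items()}


def masked_set(model, mask, value):
    """Write value's entries into model at the True positions (like __setitem__)."""
    n = len(next(iter(model.values())))
    if len(mask) != n:
        raise ValueError("mask must have shape (N,)")
    selected = [i for i, m in enumerate(mask) if m]
    if set(value) != set(model):
        raise ValueError("per-point attribute structures differ")
    if any(len(vals) != len(selected) for vals in value.values()):
        raise ValueError("len(value) must equal mask.sum()")
    # Validate everything first so a failed call leaves model untouched.
    for name, vals in value.items():
        for i, v in zip(selected, vals):
            model[name][i] = v  # in-place update


scene = {
    "positions": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]],
    "opacities": [0.9, 0.5, 0.1],
}
mask = [False, True, True]
obj = masked_get(scene, mask)
# Stand-in for as_transformed(): translate the selected object by +10 along x.
obj["positions"] = [[x + 10.0, y, z] for x, y, z in obj["positions"]]
masked_set(scene, mask, obj)
print(scene["positions"])  # [[0.0, 0.0, 0.0], [11.0, 0.0, 0.0], [12.0, 0.0, 0.0]]
```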

as_dict(only_tensors=False) → Dict[str, Any]

Return all non-None attributes as a {name: value} dict.

as_transformed(additional_transform=None)

Applies the stored transform (if set), additional_transform (if given), or, if both are set, the chained transform additional_transform @ self.transform, and returns a new instance of the class with the transform applied.

Note: ⚠️ For Gaussians, this works robustly for isotropic scaling, rotation, and translation (combined in a general transform matrix), but unexpected results may occur for shear and anisotropic scaling (for example, applying the inverse transform will not work correctly in these cases).

Parameters

additional_transform (torch.Tensor, optional) – if not set, the transform stored on the instance is used; should be an affine transform of shape (4, 4), (1, 4, 4), or (N, 4, 4), where N is the number of Gaussians.

Note

This method does not copy all attributes: only transformable ones result in new tensors. If a full copy is required, call copy.deepcopy() first. Only isotropic scale, rotation, and translation can be applied consistently.

Returns

new instance of GaussianSplatModel with the transform applied.

Return type

GaussianSplatModel
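To make the note concrete, the standard-library sketch below applies a 4x4 transform to a single Gaussian under the stated assumption that its upper-left 3x3 block is a rotation times a positive isotropic scale. Positions are transformed as points, scales are multiplied by the scale factor, and orientations are pre-multiplied by the rotation's quaternion. It is a hypothetical illustration of the math, not kaolin's implementation, and every helper name is made up. With shear or anisotropic scaling, dividing by the cube root of the determinant does not yield a rotation matrix, which is why results become unreliable in those cases.

```python
import math


def matmul4(a, b):
    """Product a @ b of two row-major 4x4 matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(4)) for j in range(4)] for i in range(4)]


def det3(m):
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))


def matrix_to_quaternion(r):
    """(w, x, y, z) unit quaternion of a proper 3x3 rotation matrix."""
    tr = r[0][0] + r[1][1] + r[2][2]
    if tr > 0:
        s = 2.0 * math.sqrt(tr + 1.0)
        return (0.25 * s, (r[2][1] - r[1][2]) / s, (r[0][2] - r[2][0]) / s, (r[1][0] - r[0][1]) / s)
    if r[0][0] > r[1][1] and r[0][0] > r[2][2]:
        s = 2.0 * math.sqrt(1.0 + r[0][0] - r[1][1] - r[2][2])
        return ((r[2][1] - r[1][2]) / s, 0.25 * s, (r[0][1] + r[1][0]) / s, (r[0][2] + r[2][0]) / s)
    if r[1][1] > r[2][2]:
        s = 2.0 * math.sqrt(1.0 + r[1][1] - r[0][0] - r[2][2])
        return ((r[0][2] - r[2][0]) / s, (r[0][1] + r[1][0]) / s, 0.25 * s, (r[1][2] + r[2][1]) / s)
    s = 2.0 * math.sqrt(1.0 + r[2][2] - r[0][0] - r[1][1])
    return ((r[1][0] - r[0][1]) / s, (r[0][2] + r[2][0]) / s, (r[1][2] + r[2][1]) / s, 0.25 * s)


def quat_mul(a, b):
    """Hamilton product a * b of (w, x, y, z) quaternions."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (aw * bw - ax * bx - ay * by - az * bz,
            aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw)


def transform_gaussian(m, position, orientation, scale):
    """Apply m, assumed to be (isotropic scale * rotation | translation)."""
    a = [row[:3] for row in m[:3]]
    d = det3(a)
    if d <= 0:
        raise ValueError("expected a proper rotation with positive isotropic scale")
    s = d ** (1.0 / 3.0)  # isotropic scale factor
    rot = [[v / s for v in row] for row in a]
    new_pos = [sum(a[i][k] * position[k] for k in range(3)) + m[i][3] for i in range(3)]
    new_q = quat_mul(matrix_to_quaternion(rot), orientation)
    return new_pos, new_q, [s * v for v in scale]


stored = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
additional = [[0.0, -2.0, 0.0, 0.0], [2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
total = matmul4(additional, stored)  # chained as additional_transform @ self.transform
pos, quat, scl = transform_gaussian(total, [1.0, 0.0, 0.0], (1.0, 0.0, 0.0, 0.0), [0.1, 0.2, 0.3])
# pos ≈ [0, 4, 0], quat ≈ (0.7071, 0, 0, 0.7071), scl ≈ [0.2, 0.4, 0.6]
```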

classmethod assert_supported(attr)

Raises an exception if the class does not support the provided attribute name.

classmethod cat(models, skip_errors=False, **kwargs)

Concatenates a list of instances along the point dimension.

Any stored transform on each model is applied before concatenation; the result always has transform=None.

Parameters
  • models (list) – Non-empty list of instances of this class.

  • skip_errors (bool) – If True, log and skip mismatched attributes instead of raising.

Returns

New instance with all point attributes concatenated.
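A standard-library sketch of concatenation along the point dimension, assuming each model is a dict mapping per-point attribute names to lists and that all models are already in a common frame (the documented cat() applies each stored transform first, so its result has transform=None). cat_point_attributes is a hypothetical helper; it mirrors the skip_errors behavior of logging and skipping mismatched attributes.

```python
import logging


def cat_point_attributes(models, skip_errors=False):
    """Concatenate per-point attribute dicts along the point dimension."""
    if not models:
        raise ValueError("models must be a non-empty list")
    result = {}
    for name in sorted(set().union(*models)):
        if any(name not in model for model in models):
            if not skip_errors:
                raise ValueError(f"attribute {name!r} is missing from some models")
            logging.warning("skipping mismatched attribute %s", name)
            continue
        result[name] = [v for model in models for v in model[name]]
    return result


a = {"positions": [[0.0, 0.0, 0.0]], "opacities": [0.5]}
b = {"positions": [[1.0, 1.0, 1.0]], "opacities": [0.7], "features": [[3.0]]}
print(cat_point_attributes([a, b], skip_errors=True))
# {'opacities': [0.5, 0.7], 'positions': [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]}
```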

check_other_attribute(attr, log_error=False)

Performs custom checks for Gaussian-specific non-tensor attributes.

Parameters
  • attr (str) – Attribute name.

  • log_error (bool) – If True, logs error messages.

Returns

True if attribute is valid, False otherwise.

Return type

bool

check_sanity(log_error=True)

Validates that all tensor attributes are correct, relying on this class's implementations of the base class's abstract check methods.

Parameters

log_error (bool) – If True, logs each failed check via logger.error.

Returns

True if all checks pass, False otherwise.

Return type

bool

check_tensor_attribute(attr, log_error=False)

Checks tensor attribute validity; returns True if valid.

check_tensor_attribute_shape(attr)

Performs custom shape checks for Gaussian-specific tensor attributes.

Parameters

attr (str) – Attribute name.

Returns

True if shape is valid, False otherwise.

Return type

bool

classmethod class_other_attributes()

Class attribute names that are not PyTorch tensors.

classmethod class_point_attributes()

Subset of class tensor attributes that contain per-point values, i.e. tensors sized (N, …), where N is the number of points.

classmethod class_tensor_attributes()

Class attribute names that are PyTorch tensors.

classmethod compute_num_sh_coeff(sh_degree)

Computes the expected total number of SH coefficients (i.e. the size of the second dimension of sh_coeff) for the given sh_degree.

Returns

Number of SH coefficients.

Return type

int

classmethod compute_sh_degree(num_sh_coeff)

Computes the SH degree from the total number of SH coefficients, i.e. the size of the second dimension of sh_coeff.

Returns

SH degree.

Return type

int

Raises

ValueError – If num_sh_coeff is not a perfect square.
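The relation S = (sh_degree + 1)^2 and its inverse can be sketched in plain Python. num_sh_coeff and sh_degree_from are hypothetical stand-ins, not the kaolin class methods; they only mirror the documented formula and the ValueError for counts that are not perfect squares.

```python
import math


def num_sh_coeff(sh_degree):
    """Total SH coefficients per color channel: S = (L + 1)^2."""
    return (sh_degree + 1) ** 2


def sh_degree_from(num_coeff):
    """Inverse of num_sh_coeff; only perfect squares >= 1 are valid."""
    if num_coeff < 1:
        raise ValueError("number of SH coefficients must be at least 1")
    root = math.isqrt(num_coeff)
    if root * root != num_coeff:
        raise ValueError(f"{num_coeff} is not a perfect square")
    return root - 1


print([num_sh_coeff(d) for d in range(4)])  # [1, 4, 9, 16]
print(sh_degree_from(16))                    # 3
```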

cpu(attributes: Optional[Sequence[str]] = None)

Calls cpu on all or selected tensor attributes; returns a shallow copy.

cuda(device: Optional[Union[int, device, str]] = None, attributes: Optional[Sequence[str]] = None)

Calls cuda on all or selected tensor attributes; returns a shallow copy.

describe_attribute(attr, print_stats=False, detailed=False)

Returns an informative string about an attribute; to_string uses the same method for every attribute.

Parameters
  • attr (str) – attribute name

  • print_stats (bool) – whether to print statistics about the values in each tensor

  • detailed (bool) – whether to include additional information about each tensor

Returns

multi-line string with attribute information

Return type

(str)

Raises

ValueError – if attr is not supported.

detach(attributes: Optional[Sequence[str]] = None)

Detaches all or selected tensor attributes; returns a shallow copy.

get_attributes(only_tensors=False)

Returns names of all attributes that are currently set to non-None value in this class instance.

Parameters

only_tensors – if True, only tensor attributes are included

Returns

list of string names

Return type

(list)

to(*args: Any, attributes: Optional[Sequence[str]] = None, **kwargs: Any)

Moves or casts tensors like torch.Tensor.to() / torch.nn.Module.to().

Parameters
  • *args – forwarded to tensor.to(*args)

  • attributes (list of str, optional) – if set, only these tensor attributes are updated

  • **kwargs – forwarded to tensor.to(**kwargs)

Returns

shallow copy with converted tensors

Return type

PointSamples
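cpu(), cuda(), detach(), and to() all return a shallow copy. The standard-library sketch below illustrates what that implies, using a hypothetical AttributeContainer with lists in place of tensors: converted attributes are new objects, while untouched attributes are shared with the original, so in-place edits to them are visible through both instances.

```python
import copy


class AttributeContainer:
    """Hypothetical stand-in for a PointSamples-like container."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)

    def map_attributes(self, fn, attributes=None):
        """Shallow copy with fn applied to all or only the selected attributes."""
        out = copy.copy(self)  # new object; attribute values are shared, not cloned
        names = list(vars(self)) if attributes is None else attributes
        for name in names:
            setattr(out, name, fn(getattr(self, name)))
        return out


points = AttributeContainer(positions=[1, 2, 3], opacities=[0.5, 0.25, 1.0])
# Loosely analogous to to(torch.float64, attributes=["positions"]).
converted = points.map_attributes(lambda vals: [float(v) for v in vals], attributes=["positions"])
print(converted.positions)                      # [1.0, 2.0, 3.0]
print(converted.opacities is points.opacities)  # True: untouched attributes are shared
```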

to_string(print_stats=False, detailed=False)

Returns information about tensor attributes currently contained in the object.

Parameters
  • print_stats (bool) – whether to print statistics about the values in each tensor

  • detailed (bool) – whether to include additional information about each tensor

Returns

multi-line string with attribute information

Return type

(str)