laplax.util.tree
Utility operations on PyTrees.
get_size
Compute the total number of elements in a PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree whose total size is to be calculated. | required |
Returns:
Type | Description |
---|---|
int | The total number of elements across all leaves in the PyTree. |
Source code in laplax/util/tree.py
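A minimal sketch of what get_size computes, assuming every leaf is an array and the result is the sum of leaf sizes. The helper name get_size_sketch is hypothetical; this is an illustration, not the laplax source.

```python
import jax
import jax.numpy as jnp


def get_size_sketch(tree):
    """Hypothetical re-implementation: total number of scalars across all leaves."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
total = get_size_sketch(params)  # 3 * 2 + 2 = 8
```

This is the length of the vector you would get by flattening all leaves into one array.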
add
Add corresponding elements of two PyTrees.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree1 | PyTree | The first PyTree. | required |
tree2 | PyTree | The second PyTree. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree where each leaf is the element-wise sum of the corresponding leaves in the two input trees. |
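A sketch of leaf-wise addition, assuming both trees share the same structure and leaf shapes. add_sketch is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp


def add_sketch(tree1, tree2):
    """Hypothetical sketch: leaf-wise sum of two PyTrees with identical structure."""
    # tree_map walks both trees in lockstep; it expects them to share one structure.
    return jax.tree_util.tree_map(lambda a, b: a + b, tree1, tree2)


t1 = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
t2 = {"w": jnp.array([10.0, 20.0]), "b": jnp.array(4.0)}
summed = add_sketch(t1, t2)  # {"b": 7.0, "w": [11.0, 22.0]}
```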
neg
Negate all elements of a PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree to negate. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree with negated elements. |
sub
Subtract corresponding elements of two PyTrees.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree1 | PyTree | The first PyTree. | required |
tree2 | PyTree | The second PyTree. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree where each leaf is the element-wise difference of the corresponding leaves in the two input trees. |
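neg and sub follow the same leaf-wise pattern as add. The sketch below assumes sub(tree1, tree2) computes tree1 minus tree2; the argument order and the helper names neg_sketch and sub_sketch are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def neg_sketch(tree):
    """Hypothetical sketch: negate every leaf."""
    return jax.tree_util.tree_map(jnp.negative, tree)


def sub_sketch(tree1, tree2):
    """Hypothetical sketch: leaf-wise tree1 - tree2 (argument order is an assumption)."""
    return jax.tree_util.tree_map(jnp.subtract, tree1, tree2)


t1 = [jnp.array([5.0, 7.0]), jnp.array(1.0)]
t2 = [jnp.array([1.0, 2.0]), jnp.array(3.0)]
diff = sub_sketch(t1, t2)  # [[4.0, 5.0], -2.0]
negated = neg_sketch(t1)   # [[-5.0, -7.0], -1.0]
```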
mul
Multiply all elements of a PyTree by a scalar.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
scalar | Float | The scalar value to multiply by. | required |
tree | PyTree | A PyTree to multiply. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree where each leaf is the corresponding input leaf multiplied element-wise by the scalar. |
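A sketch of scalar multiplication that keeps the documented argument order (scalar first, then tree). mul_sketch is a hypothetical name; the scalar simply broadcasts against each leaf.

```python
import jax
import jax.numpy as jnp


def mul_sketch(scalar, tree):
    """Hypothetical sketch: scale every leaf by the same scalar."""
    return jax.tree_util.tree_map(lambda leaf: scalar * leaf, tree)


scaled = mul_sketch(0.5, {"w": jnp.array([2.0, 4.0]), "b": jnp.array(6.0)})
# {"b": 3.0, "w": [1.0, 2.0]}
```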
sqrt
Compute the square root of each element in a PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree whose elements are to be square-rooted. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree with square-rooted elements. |
invert
Invert all elements of a PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree to invert. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree with inverted elements. |
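The sqrt and invert helpers above are element-wise maps over every leaf. The sketch assumes "invert" means the element-wise reciprocal 1/x rather than a matrix inverse; that reading, and the names sqrt_sketch and invert_sketch, are assumptions.

```python
import jax
import jax.numpy as jnp


def sqrt_sketch(tree):
    """Hypothetical sketch: element-wise square root of every leaf."""
    return jax.tree_util.tree_map(jnp.sqrt, tree)


def invert_sketch(tree):
    """Hypothetical sketch: element-wise reciprocal (assumed meaning of 'invert')."""
    return jax.tree_util.tree_map(jnp.reciprocal, tree)


tree = {"a": jnp.array([4.0, 9.0]), "b": jnp.array(0.25)}
roots = sqrt_sketch(tree)     # {"a": [2.0, 3.0], "b": 0.5}
recips = invert_sketch(tree)  # {"a": [0.25, 1/9], "b": 4.0}
```

Composing the two gives an element-wise 1/sqrt(x), which is a common way to turn a diagonal precision into a diagonal scale.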
mean
Compute the mean of each leaf in a PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree whose leaves are to be averaged. | required |
**kwargs | Kwargs | Additional keyword arguments passed to the underlying mean computation. | {} |
Returns:
Type | Description |
---|---|
PyTree | A PyTree containing the mean of each leaf. |
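A sketch of the **kwargs pattern, assuming each leaf is reduced with jax.numpy.mean and the keyword arguments (for example axis) are forwarded unchanged. How laplax actually wires the kwargs is an assumption here, and mean_sketch is a hypothetical name.

```python
from functools import partial

import jax
import jax.numpy as jnp


def mean_sketch(tree, **kwargs):
    """Hypothetical sketch: apply jnp.mean to every leaf, forwarding kwargs such as axis."""
    return jax.tree_util.tree_map(partial(jnp.mean, **kwargs), tree)


samples = {"w": jnp.array([[1.0, 2.0], [3.0, 4.0]])}
overall = mean_sketch(samples)             # {"w": 2.5}
per_column = mean_sketch(samples, axis=0)  # {"w": [2.0, 3.0]}
```

The same pattern would apply to std and var with jnp.std and jnp.var, where a keyword such as ddof could be forwarded the same way.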
std
Compute the standard deviation of each leaf in a PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree whose leaves' standard deviations are to be computed. | required |
**kwargs | Kwargs | Additional keyword arguments passed to the underlying standard deviation computation. | {} |
Returns:
Type | Description |
---|---|
PyTree | A PyTree containing the standard deviation of each leaf. |
var
Compute the variance of each leaf in a PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree whose leaves' variances are to be computed. | required |
**kwargs | Kwargs | Additional keyword arguments passed to the underlying variance computation. | {} |
Returns:
Type | Description |
---|---|
PyTree | A PyTree containing the variance of each leaf. |
cov
Compute the covariance of each leaf in a PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree whose leaves' covariances are to be computed. | required |
**kwargs | Kwargs | Additional keyword arguments passed to the underlying covariance computation. | {} |
Returns:
Type | Description |
---|---|
PyTree | A PyTree containing the covariance of each leaf. |
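Unlike mean, a per-leaf covariance is a matrix rather than a scalar. The sketch assumes each leaf is passed to jax.numpy.cov with the kwargs forwarded; choosing rowvar=False (rows are samples) is an illustrative choice, not a documented laplax default, and cov_sketch is a hypothetical name.

```python
from functools import partial

import jax
import jax.numpy as jnp


def cov_sketch(tree, **kwargs):
    """Hypothetical sketch: apply jnp.cov to every leaf, forwarding kwargs such as rowvar."""
    return jax.tree_util.tree_map(partial(jnp.cov, **kwargs), tree)


# The leaf holds 3 samples (rows) of a 2-dimensional variable (columns).
samples = {"x": jnp.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])}
covs = cov_sketch(samples, rowvar=False)  # {"x": [[1.0, 2.0], [2.0, 4.0]]}
```

jnp.cov normalizes by N - 1 by default, which is why the column variances come out as 1 and 4.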
tree_matvec
Multiply a PyTree by a vector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree to multiply. | required |
vector | Array | A vector to multiply by. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree with multiplied elements. |
tree_partialmatvec
Multiply a PyTree by a vector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree to multiply. | required |
vector | Array | A vector to multiply by. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree with multiplied elements. |
ones_like
Create a PyTree of ones with the same structure as the input tree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree whose structure and shape will be used. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree of ones with the same structure and leaf shapes as the input tree. |
zeros_like
Create a PyTree of zeros with the same structure as the input tree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree whose structure and shape will be used. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree of zeros with the same structure and leaf shapes as the input tree. |
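A sketch of ones_like and zeros_like as leaf-wise maps over jax.numpy.ones_like and jax.numpy.zeros_like, which preserve each leaf's shape and dtype. Whether laplax also preserves dtype is an assumption; the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def ones_like_sketch(tree):
    """Hypothetical sketch: ones with each leaf's shape and dtype."""
    return jax.tree_util.tree_map(jnp.ones_like, tree)


def zeros_like_sketch(tree):
    """Hypothetical sketch: zeros with each leaf's shape and dtype."""
    return jax.tree_util.tree_map(jnp.zeros_like, tree)


params = {"w": jnp.full((2, 3), 7.0), "b": jnp.arange(3)}
ones = ones_like_sketch(params)
zeros = zeros_like_sketch(params)
```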
randn_like
Generate a PyTree of random normal values with the same structure as the input.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | KeyType | A JAX PRNG key. | required |
tree | PyTree | A PyTree whose structure will be replicated. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree of random normal values. |
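A sketch of randn_like, assuming floating-point leaves and standard-normal entries. Splitting the key once per leaf keeps the leaves' samples independent; the splitting scheme and the name randn_like_sketch are assumptions, not the laplax implementation.

```python
import jax
import jax.numpy as jnp


def randn_like_sketch(key, tree):
    """Hypothetical sketch: i.i.d. standard normal samples shaped like each leaf."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # One independent subkey per leaf so samples are not reused across leaves.
    keys = jax.random.split(key, len(leaves))
    samples = [
        # Assumes floating-point leaves; jax.random.normal rejects integer dtypes.
        jax.random.normal(k, shape=leaf.shape, dtype=leaf.dtype)
        for k, leaf in zip(keys, leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, samples)


params = {"w": jnp.zeros((4, 3)), "b": jnp.zeros(3)}
noise = randn_like_sketch(jax.random.PRNGKey(0), params)
```

Reusing the same key reproduces the same sample, as usual for JAX's explicit PRNG.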
normal_like
Generate a PyTree of random normal values, scaled by scale_mv and shifted by mean.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | KeyType | A JAX PRNG key. | required |
mean | PyTree | A PyTree representing the mean of the distribution. | required |
scale_mv | Callable[[PyTree], PyTree] | A callable that scales a PyTree. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree of scaled random normal values, shifted by the distribution mean. |
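A sketch of normal_like under the assumption that a sample is formed as mean + scale_mv(z), with z a standard-normal PyTree shaped like mean. The name normal_like_sketch and the example scale functions are hypothetical.

```python
import jax
import jax.numpy as jnp


def normal_like_sketch(key, mean, scale_mv):
    """Hypothetical sketch: draw z ~ N(0, I) shaped like `mean`, return mean + scale_mv(z)."""
    leaves, treedef = jax.tree_util.tree_flatten(mean)
    keys = jax.random.split(key, len(leaves))
    z = jax.tree_util.tree_unflatten(
        treedef,
        [jax.random.normal(k, leaf.shape, leaf.dtype) for k, leaf in zip(keys, leaves)],
    )
    return jax.tree_util.tree_map(jnp.add, mean, scale_mv(z))


def scale_by_tenth(t):
    """Hypothetical diagonal scale: multiply every entry by 0.1."""
    return jax.tree_util.tree_map(lambda x: 0.1 * x, t)


def zero_scale(t):
    """Hypothetical degenerate scale: collapses the distribution onto its mean."""
    return jax.tree_util.tree_map(jnp.zeros_like, t)


mean = {"w": jnp.array([1.0, -1.0])}
sample = normal_like_sketch(jax.random.PRNGKey(42), mean, scale_by_tenth)
degenerate = normal_like_sketch(jax.random.PRNGKey(42), mean, zero_scale)
```

Passing the scale as a matrix-vector callable means the covariance square root never has to be materialized as a dense matrix.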
basis_vector_from_index
Create a basis vector from an index in a PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | int | The index of the basis vector. | required |
tree | PyTree | A PyTree whose structure will be used. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree representing the basis vector for the specified index. |
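A sketch assuming idx indexes the flattened vector produced by jax.flatten_util.ravel_pytree, so the result holds a one at that flat position and zeros elsewhere. The indexing convention and the name basis_vector_from_index_sketch are assumptions.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def basis_vector_from_index_sketch(idx, tree):
    """Hypothetical sketch: the idx-th standard basis vector, shaped like `tree`."""
    flat, unravel = ravel_pytree(tree)
    # One-hot vector in the flattened space, then restore the PyTree layout.
    e = jnp.zeros_like(flat).at[idx].set(1.0)
    return unravel(e)


params = {"a": jnp.zeros(2), "b": jnp.zeros((2, 2))}
e3 = basis_vector_from_index_sketch(3, params)
# Flat order is a[0], a[1], b[0,0], b[0,1], ..., so index 3 lands on b[0,1].
```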
eye_like_with_basis_vector
Create a PyTree where each element is a basis vector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree defining the structure. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree of basis vectors. |
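One plausible reading, stated as an assumption: stack all basis vectors along a new leading axis, so every leaf gains a leading dimension equal to the total number of elements. The sketch maps the unravel function over the rows of an identity matrix with jax.vmap; the name is hypothetical and the layout may differ from laplax's.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def eye_like_with_basis_vector_sketch(tree):
    """Hypothetical sketch: PyTree whose i-th slice along axis 0 is the i-th basis vector."""
    flat, unravel = ravel_pytree(tree)
    # Row i of the identity is the flat basis vector e_i; vmap unravels every row.
    return jax.vmap(unravel)(jnp.eye(flat.size, dtype=flat.dtype))


params = {"a": jnp.zeros(2), "b": jnp.zeros(1)}
eye_tree = eye_like_with_basis_vector_sketch(params)
# eye_tree["a"] has shape (3, 2) and eye_tree["b"] has shape (3, 1).
```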
eye_like
Create a PyTree equivalent of an identity matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree defining the structure. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree equivalent to an identity matrix. |
tree_slice
Slice each leaf of a PyTree along the first dimension.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree to slice. | required |
a | int | The start index. | required |
b | int | The end index. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree with sliced leaves. |
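A sketch of slicing every leaf along its first axis, assuming Python-style half-open bounds [a, b); the description above only says "start index" and "end index", so exclusivity of b is an assumption, as is the name tree_slice_sketch.

```python
import jax
import jax.numpy as jnp


def tree_slice_sketch(tree, a, b):
    """Hypothetical sketch: keep rows a..b-1 of every leaf (half-open range assumed)."""
    return jax.tree_util.tree_map(lambda leaf: leaf[a:b], tree)


batch = {"x": jnp.arange(10).reshape(5, 2), "y": jnp.arange(5)}
sub_batch = tree_slice_sketch(batch, 1, 3)
# {"x": [[2, 3], [4, 5]], "y": [1, 2]}
```

This is handy for taking a mini-batch out of a dataset stored as a PyTree with a shared leading axis.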
tree_vec_get
Retrieve the element at the specified index from a flattened PyTree.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree to retrieve the element from. | required |
idx | int | The index of the element. | required |
Returns:
Type | Description |
---|---|
Any | The element at the specified index. |
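A sketch assuming the flattening follows jax.flatten_util.ravel_pytree (leaves in tree-flatten order, each raveled row-major). That convention and the name tree_vec_get_sketch are assumptions; laplax may index the leaves without materializing the full vector.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_vec_get_sketch(tree, idx):
    """Hypothetical sketch: element idx of the flattened (raveled) PyTree."""
    flat, _ = ravel_pytree(tree)
    return flat[idx]


params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([[3.0, 4.0], [5.0, 6.0]])}
value = tree_vec_get_sketch(params, 3)  # flat vector is [1, 2, 3, 4, 5, 6]
```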
allclose
Check whether all elements in two PyTrees are approximately equal.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree1 | PyTree | The first PyTree. | required |
tree2 | PyTree | The second PyTree. | required |
Returns:
Type | Description |
---|---|
bool | True if all elements are approximately equal, otherwise False. |
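A sketch that compares corresponding leaves with jax.numpy.allclose (default tolerances) and reduces the per-leaf results to one Python bool. The tolerance choice and the name allclose_sketch are assumptions.

```python
import jax
import jax.numpy as jnp


def allclose_sketch(tree1, tree2):
    """Hypothetical sketch: True only if every pair of corresponding leaves is close."""
    per_leaf = jax.tree_util.tree_map(jnp.allclose, tree1, tree2)
    return all(bool(x) for x in jax.tree_util.tree_leaves(per_leaf))


t = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.0)}
close = allclose_sketch(t, {"w": jnp.array([1.0, 2.0 + 1e-9]), "b": jnp.array(0.0)})
far = allclose_sketch(t, {"w": jnp.array([1.0, 2.5]), "b": jnp.array(0.0)})
```

Returning a concrete Python bool is convenient in tests, but it cannot be used inside jax.jit; a traced variant would reduce with jnp.all instead.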