laplax.util.flatten
Operations for flattening PyTrees into arrays.
cumsum
Compute the cumulative sum of a sequence.
This function takes a sequence of integers and returns a list of cumulative sums.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
seq | Generator | A generator or sequence of integers. | required |
Returns:
Type | Description |
---|---|
list[int] | A list where each element is the cumulative sum up to that point in the input sequence. |
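As an illustration, the behavior described above can be sketched in plain Python. This is a standalone re-implementation for exposition, not the library's source; it assumes, as the description suggests, that no leading zero is prepended to the result.

```python
from collections.abc import Iterable


def cumsum(seq: Iterable[int]) -> list[int]:
    """Sketch: return the running totals of seq (no leading zero)."""
    total = 0
    out: list[int] = []
    for value in seq:
        total += value
        out.append(total)
    return out


# Generators work as well as lists: squares 1, 4, 9 give running totals 1, 5, 14.
print(cumsum(x * x for x in range(1, 4)))
```

Running totals of leaf sizes are a convenient way to compute the split indices that jnp.split expects when cutting a flat vector back into per-leaf pieces.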
full_flatten
Flatten a PyTree into a single 1D array.
This function takes a PyTree and concatenates all its leaves into a single array.
Returns:
Type | Description |
---|---|
Array | The flattened PyTree. |
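A minimal sketch of the same idea using jax.tree_util, offered as an illustration rather than the library's implementation. It assumes every leaf is an array and ravels each one before concatenating, so the element order follows JAX's PyTree flattening order (for dicts, keys are visited in sorted order).

```python
import jax
import jax.numpy as jnp


def full_flatten(tree):
    """Sketch: ravel every leaf of tree and concatenate them into one 1D array."""
    leaves = jax.tree_util.tree_leaves(tree)  # leaves in PyTree order
    return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat = full_flatten(params)  # "b" (3 zeros) comes first, then "w" (6 ones)
print(flat.shape)
```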
flatten_function
Wrap a function to flatten its input and output.
This function takes a function and a layout and returns a new function that accepts a flattened input and also returns a flattened output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fn | Callable | The function to wrap. | required |
layout | PyTree | The layout of the PyTree. | required |
Returns:
Type | Description |
---|---|
Callable | The wrapped function. |
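The wrapping can be sketched with jax.flatten_util.ravel_pytree, which is used here as a stand-in for the library's own flattening helpers; that substitution is an assumption made for illustration. The sketch treats layout as a template PyTree describing the structure of the flat input, and flattens whatever PyTree fn returns.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def flatten_function(fn, layout):
    """Sketch (not the library's code): flat vector in, flat vector out."""
    # Assumption: layout only serves as a template for unflattening the input.
    _, unravel = ravel_pytree(layout)

    def flat_fn(flat_input):
        tree_input = unravel(flat_input)            # 1D array -> PyTree like layout
        tree_output = fn(tree_input)                # call the original function
        flat_output, _ = ravel_pytree(tree_output)  # PyTree -> 1D array
        return flat_output

    return flat_fn


layout = {"a": jnp.zeros(2), "b": jnp.zeros((2, 2))}
double = flatten_function(lambda t: {"a": 2 * t["a"], "b": 2 * t["b"]}, layout)
result = double(jnp.arange(6.0))  # equals 2 * jnp.arange(6.0)
```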
create_pytree_flattener
create_pytree_flattener(tree: PyTree) -> tuple[Callable[[PyTree], Array], Callable[[Array], PyTree]]
Create functions to flatten and unflatten a PyTree into and from a 1D array.
The flatten function concatenates all leaves of the PyTree into a single vector. The unflatten function reconstructs the original PyTree from the flattened vector.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree from which the structure for flattening and unflattening is derived. | required |
Returns:
Type | Description |
---|---|
tuple[Callable[[PyTree], Array], Callable[[Array], PyTree]] | Tuple containing the flatten function (PyTree to 1D array) and the unflatten function (1D array to PyTree). |
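One way to build such a pair of closures is sketched below; it is an illustrative re-implementation, not the library's source, and assumes every leaf is an array. The tree definition, leaf shapes and split points are recorded once from tree, so unflatten can be applied to any vector of the right length.

```python
import itertools
import math

import jax
import jax.numpy as jnp


def create_pytree_flattener(tree):
    """Sketch: build (flatten, unflatten) closures from the structure of tree."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    shapes = [jnp.shape(leaf) for leaf in leaves]
    sizes = [math.prod(s) for s in shapes]                    # elements per leaf
    split_points = list(itertools.accumulate(sizes))[:-1]     # leaf boundaries

    def flatten(t):
        return jnp.concatenate([jnp.ravel(x) for x in jax.tree_util.tree_leaves(t)])

    def unflatten(vec):
        parts = jnp.split(vec, split_points)
        return jax.tree_util.tree_unflatten(
            treedef, [p.reshape(s) for p, s in zip(parts, shapes)]
        )

    return flatten, unflatten


tree = {"w": jnp.arange(6.0).reshape(2, 3), "b": jnp.array([10.0, 20.0])}
flatten, unflatten = create_pytree_flattener(tree)
vec = flatten(tree)        # shape (8,)
restored = unflatten(vec)  # same structure and shapes as tree
```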
create_partial_pytree_flattener
create_partial_pytree_flattener(tree: PyTree) -> tuple[Callable[[PyTree], Array], Callable[[Array], PyTree]]
Create functions to flatten and unflatten partial PyTrees into and from arrays.
This function assumes that each leaf in the PyTree is a multi-dimensional array, where the last dimension represents column indices. The flatten function combines all rows across leaves into a single 2D array. The unflatten function reconstructs the PyTree from this 2D array.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | A PyTree from which the structure for flattening and unflattening is derived. | required |
Returns:
Type | Description |
---|---|
tuple[Callable[[PyTree], Array], Callable[[Array], PyTree]] | Tuple containing the flatten function (PyTree to 2D array) and the unflatten function (2D array to PyTree). |
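The row-stacking behavior can be sketched in the same style. The sketch is an illustration, not the library's code: it treats all dimensions except the last as rows and assumes every leaf has the same size in its last dimension, since otherwise the row-wise concatenation fails.

```python
import itertools
import math

import jax
import jax.numpy as jnp


def create_partial_pytree_flattener(tree):
    """Sketch: flatten all but the last axis of each leaf, stacking rows across leaves."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    shapes = [jnp.shape(leaf) for leaf in leaves]
    row_counts = [math.prod(s[:-1]) for s in shapes]          # rows per leaf
    split_points = list(itertools.accumulate(row_counts))[:-1]

    def flatten(t):
        # Assumption: all leaves share the same last-dimension size (columns).
        return jnp.concatenate(
            [x.reshape(-1, x.shape[-1]) for x in jax.tree_util.tree_leaves(t)], axis=0
        )

    def unflatten(mat):
        parts = jnp.split(mat, split_points, axis=0)
        return jax.tree_util.tree_unflatten(
            treedef, [p.reshape(s) for p, s in zip(parts, shapes)]
        )

    return flatten, unflatten


# Each leaf carries 4 columns in its last dimension.
tree = {"w": jnp.ones((2, 3, 4)), "b": jnp.zeros((3, 4))}
flatten, unflatten = create_partial_pytree_flattener(tree)
mat = flatten(tree)        # 3 rows from "b" + 6 rows from "w" -> shape (9, 4)
restored = unflatten(mat)
```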
unravel_array_into_pytree
Unravel an array into a PyTree with a specified structure.
This function splits an array along a given axis and reshapes each piece to match the corresponding leaf of a given PyTree. The axis parameter selects which dimension of the array is split; in the result, that dimension is replaced by each leaf's shape.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pytree | PyTree | The PyTree defining the desired structure. | required |
axis | int | The axis along which to split the array. | required |
arr | Array | The array to be unraveled into the PyTree structure. | required |
Returns:
Type | Description |
---|---|
PyTree | A PyTree with the specified structure, populated with parts of the input array. |
This function follows the implementation in jax._src.api._unravel_array_into_pytree.
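The splitting logic can be sketched as follows, modeled on the private JAX helper named above; because that helper is private, this is an approximation rather than a copy of either implementation. Along axis, arr must have a length equal to the total number of elements in pytree; each chunk is reshaped so that axis is replaced by the matching leaf's shape, while the other dimensions of arr are kept. The Jacobian-shaped input below is a hypothetical example.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


def unravel_array_into_pytree(pytree, axis: int, arr):
    """Sketch modeled on jax._src.api._unravel_array_into_pytree."""
    leaves, treedef = jax.tree_util.tree_flatten(pytree)
    axis = axis % arr.ndim  # allow negative axes
    # Replace the split axis by each leaf's shape; keep all other dimensions.
    shapes = [arr.shape[:axis] + np.shape(leaf) + arr.shape[axis + 1:] for leaf in leaves]
    split_points = list(itertools.accumulate(np.size(leaf) for leaf in leaves))[:-1]
    parts = jnp.split(arr, split_points, axis=axis)
    return jax.tree_util.tree_unflatten(
        treedef, [part.reshape(shape) for part, shape in zip(parts, shapes)]
    )


params = {"w": jnp.zeros((2, 3)), "b": jnp.zeros(3)}  # 9 parameters in total
jac = jnp.ones((5, 9))                                 # e.g. 5 outputs x 9 parameters
tree_jac = unravel_array_into_pytree(params, 1, jac)
# tree_jac["w"] has shape (5, 2, 3); tree_jac["b"] has shape (5, 3)
```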
wrap_function
wrap_function(fn: Callable, input_fn: Callable | None = None, output_fn: Callable | None = None, argnums: int = 0) -> Callable
Wrap a function with input and output transformations.
This utility wraps a function fn, applying an optional transformation to the argument selected by argnums before execution and another optional transformation to its output after execution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fn | Callable | The function to be wrapped. | required |
input_fn | Callable \| None | A callable to transform the input argument at position argnums (default: identity). | None |
output_fn | Callable \| None | A callable to transform the output of the function (default: identity). | None |
argnums | int | The index of the argument to be transformed by input_fn. | 0 |
Returns:
Type | Description |
---|---|
Callable | The wrapped function with input and output transformations applied. |
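A minimal sketch of this behavior follows. It assumes input_fn is applied only to the positional argument at index argnums and that both transformations default to the identity, as the parameter table indicates; it is illustrative rather than the library's source, and the add example is hypothetical.

```python
from __future__ import annotations

from collections.abc import Callable


def wrap_function(
    fn: Callable,
    input_fn: Callable | None = None,
    output_fn: Callable | None = None,
    argnums: int = 0,
) -> Callable:
    """Sketch: transform one positional argument on the way in, the result on the way out."""
    input_fn = input_fn or (lambda x: x)    # identity by default
    output_fn = output_fn or (lambda x: x)  # identity by default

    def wrapped(*args, **kwargs):
        args = list(args)
        args[argnums] = input_fn(args[argnums])  # only the selected argument changes
        return output_fn(fn(*args, **kwargs))

    return wrapped


# Hypothetical example: parse the second argument, format the result as a string.
add = wrap_function(lambda x, y: x + y, input_fn=int, output_fn=str, argnums=1)
print(add(1, "41"))  # 42
```

In the flattening context, input_fn would typically be an unflatten function and output_fn a flatten function, which is how a PyTree-based function can be exposed through a flat-vector interface.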
wrap_factory
wrap_factory(factory: Callable, input_fn: Callable | None = None, output_fn: Callable | None = None) -> Callable
Wrap a factory function to apply input and output transformations.
This function wraps a factory, ensuring that any callable it produces is transformed with wrap_function to apply input and output transformations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
factory | Callable | The factory function that returns a callable. | required |
input_fn | Callable \| None | A callable to transform the input arguments (default: identity). | None |
output_fn | Callable \| None | A callable to transform the output of the function (default: identity). | None |
Returns:
Type | Description |
---|---|
Callable | The wrapped factory that produces transformed callables. |
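The factory wrapping can be sketched by composing a minimal stand-in for wrap_function with an outer closure. Since wrap_factory exposes no argnums parameter, the sketch assumes the produced callables are wrapped with the default argnums of 0; make_scaler is a hypothetical factory used only for illustration.

```python
from __future__ import annotations

from collections.abc import Callable


def wrap_function(fn, input_fn=None, output_fn=None, argnums=0):
    """Minimal stand-in: transform one positional argument and the result."""
    input_fn = input_fn or (lambda x: x)
    output_fn = output_fn or (lambda x: x)

    def wrapped(*args, **kwargs):
        args = list(args)
        args[argnums] = input_fn(args[argnums])
        return output_fn(fn(*args, **kwargs))

    return wrapped


def wrap_factory(factory: Callable, input_fn=None, output_fn=None) -> Callable:
    """Sketch: every callable returned by factory is passed through wrap_function."""

    def wrapped_factory(*args, **kwargs):
        # Assumption: the default argnums=0 is used for the produced callable.
        return wrap_function(factory(*args, **kwargs), input_fn, output_fn)

    return wrapped_factory


def make_scaler(factor):
    """Hypothetical factory returning a function that scales its input."""
    return lambda x: factor * x


make_list_scaler = wrap_factory(make_scaler, input_fn=sum, output_fn=str)
triple = make_list_scaler(3)
print(triple([1, 2]))  # 9
```

The factory's own arguments (here, factor) pass through untouched; only the callable it returns gets the input and output transformations.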