gradient
- class brainunit.math.gradient(f, *varargs, axis=None, edge_order=None)
Compute the gradient of a scalar field sampled as an N-dimensional array.
The gradient is computed using second-order accurate central differences at the interior points and either first- or second-order accurate one-sided (forward or backward) differences at the boundaries. The returned gradient therefore has the same shape as the input array.
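The finite-difference scheme described above can be sketched in plain Python for 1-D samples with uniform spacing `h`. `gradient_1d` is a hypothetical helper written for illustration; it is not part of brainunit and makes no claim about how this function is implemented internally. It follows the stencils documented for NumPy's `numpy.gradient`. With the default spacing and `edge_order=1` it reproduces the values shown in the Examples section.

```python
def gradient_1d(f, h=1.0, edge_order=1):
    """Finite-difference derivative of equally spaced samples f (hypothetical helper)."""
    n = len(f)
    if n < edge_order + 1:
        raise ValueError("need at least edge_order + 1 samples")
    out = [0.0] * n
    # Interior points: second-order central difference.
    for i in range(1, n - 1):
        out[i] = (f[i + 1] - f[i - 1]) / (2.0 * h)
    if edge_order == 1:
        # First-order one-sided (forward / backward) differences at the ends.
        out[0] = (f[1] - f[0]) / h
        out[-1] = (f[-1] - f[-2]) / h
    elif edge_order == 2:
        # Second-order one-sided differences using three points at each end.
        out[0] = (-3.0 * f[0] + 4.0 * f[1] - f[2]) / (2.0 * h)
        out[-1] = (3.0 * f[-1] - 4.0 * f[-2] + f[-3]) / (2.0 * h)
    else:
        raise ValueError("edge_order must be 1 or 2")
    return out


print(gradient_1d([1., 2., 4., 7., 11.]))                # [1.0, 1.5, 2.5, 3.5, 4.0]
print(gradient_1d([1., 2., 4., 7., 11.], edge_order=2))  # [0.5, 1.5, 2.5, 3.5, 4.5]
```

Only the two boundary values change between the two edge orders. The interior central differences are identical.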
- Parameters:
f (Array | ndarray | bool | number | int | float | complex | saiunit.Quantity) – An N-dimensional array containing samples of a scalar function.
varargs (Array | ndarray | bool | number | int | float | complex | saiunit.Quantity) – Spacing between f values. Default: unitary spacing for all dimensions. Spacing can be specified using:
1. a single scalar to specify a sample distance for all dimensions;
2. N scalars to specify a constant sample distance for each dimension, e.g. dx, dy, dz, …;
3. N arrays to specify the coordinates of the values along each dimension of f (the length of each array must match the size of the corresponding dimension);
4. any combination of N scalars/arrays with the meaning of 2. and 3.
If axis is given, the number of varargs must equal the number of axes. Default: 1.
edge_order (int | None) – The gradient is calculated using edge_order-th order accurate differences at the boundaries. Default: 1.
axis (int | Sequence[int] | None) – The gradient is calculated only along the given axis or axes. The default (axis=None) is to calculate the gradient along all axes of the input array. axis may be negative, in which case it counts from the last to the first axis.
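When varargs supplies coordinate arrays (option 3), the interior points need a central-difference stencil that accounts for unequal left and right spacings. The sketch below uses the second-order formula documented for NumPy's `numpy.gradient` with first-order boundaries. `gradient_coords` is a hypothetical helper for illustration only, not a brainunit function, and handles only 1-D input.

```python
def gradient_coords(f, x):
    """Derivative of samples f taken at (possibly non-uniform) coordinates x (hypothetical helper)."""
    n = len(f)
    if n != len(x):
        raise ValueError("f and x must have the same length")
    if n < 2:
        raise ValueError("need at least two samples")
    out = [0.0] * n
    for i in range(1, n - 1):
        hd = x[i] - x[i - 1]  # spacing to the left neighbour
        hs = x[i + 1] - x[i]  # spacing to the right neighbour
        # Second-order stencil for unequal spacing; it reduces to
        # (f[i+1] - f[i-1]) / (2h) when hd == hs == h.
        num = hd**2 * f[i + 1] - hs**2 * f[i - 1] + (hs**2 - hd**2) * f[i]
        out[i] = num / (hd * hs * (hd + hs))
    # First-order one-sided differences at the boundaries.
    out[0] = (f[1] - f[0]) / (x[1] - x[0])
    out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2])
    return out


# f = x**2 sampled at uneven points; the interior stencil is exact for quadratics.
print(gradient_coords([0., 1., 9., 36.], [0., 1., 3., 6.]))  # [1.0, 2.0, 6.0, 9.0]
```

The interior values 2.0 and 6.0 match the exact derivative 2x at x = 1 and x = 3. The boundary values are only first-order accurate.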
- Returns:
gradient – A list of ndarrays (or a single ndarray if there is only one dimension) corresponding to the derivatives of f with respect to each dimension. Each derivative has the same shape as f.
- Return type:
Array | list[Array] | saiunit.Quantity | list[saiunit.Quantity]
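For a 2-D input, the result is a list with one derivative per axis, each having the same shape as f. The following pure-Python sketch illustrates that structure for uniform spacing and edge_order=1. `gradient_2d` and its spacing parameters `h0` and `h1` are hypothetical and chosen for illustration. They are not the brainunit API, and the sketch ignores units.

```python
def _diff_1d(v, h):
    # Central differences inside, first-order one-sided at the two ends.
    n = len(v)
    out = [0.0] * n
    for i in range(1, n - 1):
        out[i] = (v[i + 1] - v[i - 1]) / (2.0 * h)
    out[0] = (v[1] - v[0]) / h
    out[-1] = (v[-1] - v[-2]) / h
    return out


def gradient_2d(f, h0=1.0, h1=1.0):
    """Return [df/d(axis 0), df/d(axis 1)] for a rectangular list of lists
    with at least two rows and two columns (hypothetical helper)."""
    rows, cols = len(f), len(f[0])
    # Axis 0: differentiate each column, then transpose back to f's shape.
    col_diffs = [_diff_1d([f[r][c] for r in range(rows)], h0) for c in range(cols)]
    d_axis0 = [[col_diffs[c][r] for c in range(cols)] for r in range(rows)]
    # Axis 1: differentiate each row directly.
    d_axis1 = [_diff_1d(row, h1) for row in f]
    return [d_axis0, d_axis1]


g0, g1 = gradient_2d([[1., 2., 6.], [3., 4., 5.]])
print(g0)  # [[2.0, 2.0, -1.0], [2.0, 2.0, -1.0]]
print(g1)  # [[1.0, 2.5, 4.0], [1.0, 1.0, 1.0]]
```

With only two rows, axis 0 has no interior points, so both rows of `g0` are the same one-sided difference.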
Examples
>>> import jax.numpy as jnp
>>> import saiunit.math as sumath
>>> f = jnp.array([1., 2., 4., 7., 11.])
>>> sumath.gradient(f)
Array([1. , 1.5, 2.5, 3.5, 4. ], dtype=float32)