Checking Function Units#


In scientific computing, it is crucial to ensure that function parameters and return values have the correct units. brainunit provides three decorators for this. check_dims validates physical dimensions. check_units validates the units themselves. assign_units converts input arguments to plain numbers in a given unit and attaches units to return values.

First, we need to import the necessary libraries and modules.

import brainunit
from brainunit import volt, mV, meter, second, check_dims, check_units, assign_units, DimensionMismatchError, UnitMismatchError

Checking Units#

check_dims Decorator#

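The check_dims decorator validates the physical dimensions of a function's input arguments or return value. If a value's dimension (for example, voltage or time) differs from the expected one, it raises a DimensionMismatchError. Only the dimension is compared, not the scale of the unit, so values in mV and V both pass a check for volt.dim.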
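Here is a hypothetical mental model, not brainunit's internals. A dimension can be stored as a tuple of SI exponents ordered (metre, kilogram, second, ampere). Voltage is then (2, 1, -3, -1), which matches the m^2 kg s^-3 A^-1 reported in the error messages. A dimension check compares these tuples and ignores the unit's scale:

```python
from dataclasses import dataclass

# Hypothetical toy model, not brainunit's implementation.
# A dimension is the tuple of exponents of (metre, kilogram, second, ampere).
VOLTAGE = (2, 1, -3, -1)  # m^2 kg s^-3 A^-1
TIME = (0, 0, 1, 0)       # s


@dataclass(frozen=True)
class ToyQuantity:
    value: float
    scale: float  # size of the unit relative to the SI unit (mV -> 1e-3)
    dim: tuple


def has_dimension(q: ToyQuantity, expected_dim: tuple) -> bool:
    # A dimension check ignores the scale: V and mV are both voltages.
    return q.dim == expected_dim


three_mV = ToyQuantity(3, 1e-3, VOLTAGE)
five_V = ToyQuantity(5, 1.0, VOLTAGE)
five_s = ToyQuantity(5, 1.0, TIME)

print(has_dimension(three_mV, VOLTAGE))  # True
print(has_dimension(five_V, VOLTAGE))    # True
print(has_dimension(five_s, VOLTAGE))    # False
```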

We will demonstrate the usage of check_dims through several examples.

Basic Usage#

We can use the check_dims decorator to validate whether the input arguments of a function have the expected dimensions.

@check_dims(v=volt.dim)
def a_function(v, x):
    """
    v must have the dimension of voltage, and x can have any (or no) unit.
    """
    pass
Correct Dimensions#

The following calls succeed because v either has the dimension of voltage (both mV and V qualify) or is a string or None, which are let through without a check:

a_function(3 * mV, 5 * second)
a_function(5 * volt, "something")
a_function([1, 2, 3] * volt, None)
a_function([1 * volt, 2 * volt, 3 * volt], None)
a_function("a string", None)
a_function(None, None)
Incorrect Dimensions#

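The following calls raise a DimensionMismatchError because v does not have the dimension of voltage. A plain number and an arbitrary object are both treated as dimensionless, as the "(unit is 1)" in the messages shows.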

try:
    a_function(5 * second, None)
except DimensionMismatchError as e:
    print(e)
    
try:
    a_function(5, None)
except DimensionMismatchError as e:
    print(e)
    
try:
    a_function(object(), None)
except DimensionMismatchError as e:
    print(e)
Function 'a_function' expected a array with dimension metre ** 2 * kilogram * second ** -3 * amp ** -1 for argument 'v' but got '5 s' (unit is s).
Function 'a_function' expected a array with dimension metre ** 2 * kilogram * second ** -3 * amp ** -1 for argument 'v' but got '5' (unit is 1).
Function 'a_function' expected a array with dimension metre ** 2 * kilogram * second ** -3 * amp ** -1 for argument 'v' but got '<object object at 0x7fc824247440>' (unit is 1).
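The accepted and rejected calls above can be mimicked with a small decorator. This is a hypothetical sketch, not how brainunit implements check_dims. It uses inspect.signature to map positional and keyword arguments to parameter names. It lets strings and None through, and it treats anything that is not a quantity as dimensionless. ToyQuantity and ToyDimensionError are illustrative stand-ins for brainunit's Quantity and DimensionMismatchError.

```python
import functools
import inspect
from dataclasses import dataclass

# Hypothetical toy model, not brainunit's implementation.
VOLTAGE = (2, 1, -3, -1)
TIME = (0, 0, 1, 0)
DIMENSIONLESS = (0, 0, 0, 0)


@dataclass(frozen=True)
class ToyQuantity:
    value: float
    dim: tuple


class ToyDimensionError(Exception):
    """Stand-in for brainunit's DimensionMismatchError."""


def dim_of(obj):
    # Plain numbers and arbitrary objects count as dimensionless ("unit is 1").
    return obj.dim if isinstance(obj, ToyQuantity) else DIMENSIONLESS


def toy_check_dims(**expected):
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)  # map positional args to names
            for name, dim in expected.items():
                value = bound.arguments.get(name)
                if value is None or isinstance(value, str):
                    continue  # strings and None pass unchecked
                if dim_of(value) != dim:
                    raise ToyDimensionError(
                        f"{func.__name__!r} expected dimension {dim} for "
                        f"{name!r} but got {dim_of(value)}"
                    )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@toy_check_dims(v=VOLTAGE)
def a_function(v, x):
    return v


print(a_function(ToyQuantity(3, VOLTAGE), ToyQuantity(5, TIME)))
print(a_function("a string", None))
try:
    a_function(ToyQuantity(5, TIME), None)
except ToyDimensionError as e:
    print(e)
```

Binding through the signature means the check also works when v is passed by keyword, for example a_function(v=..., x=...).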

Validating Return Values#

The check_dims decorator can also be used to validate whether the return value of a function has the expected dimensions.

@check_dims(result=second.dim)
def b_function(return_second):
    """
    If return_second is True, return a value in seconds; otherwise, return a value in volts.
    """
    if return_second:
        return 5 * second
    else:
        return 3 * volt
Correct Return Value#

The following call is correct because the return value has the dimension of time.

b_function(True)
Quantity(5, "s")
Incorrect Return Value#

The following call raises a DimensionMismatchError because the return value has the dimension of voltage instead of time.

try:
    b_function(False)
except DimensionMismatchError as e:
    print(e)
The return value of function 'b_function' was expected to have dimension s but was '3 V' (unit is m^2 kg s^-3 A^-1).
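A return value can only be checked after it exists, so the function body has already run when the error is raised. The following hypothetical sketch makes that visible. The names toy_check_result, ToyQuantity and ToyDimensionError are illustrative and are not brainunit's API. The sketch records every time the body executes:

```python
import functools
from dataclasses import dataclass

# Hypothetical toy model, not brainunit's implementation.
TIME = (0, 0, 1, 0)
VOLTAGE = (2, 1, -3, -1)


@dataclass(frozen=True)
class ToyQuantity:
    value: float
    dim: tuple


class ToyDimensionError(Exception):
    """Stand-in for brainunit's DimensionMismatchError."""


calls = []  # records every time the decorated body actually runs


def toy_check_result(expected_dim):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)  # the body always runs first
            if result.dim != expected_dim:
                raise ToyDimensionError(
                    f"return value of {func.__name__!r} expected dimension "
                    f"{expected_dim} but got {result.dim}"
                )
            return result

        return wrapper

    return decorator


@toy_check_result(TIME)
def b_function(return_second):
    calls.append(return_second)
    return ToyQuantity(5, TIME) if return_second else ToyQuantity(3, VOLTAGE)


print(b_function(True))
try:
    b_function(False)
except ToyDimensionError as e:
    print(e)
print(calls)  # [True, False]: the body ran even for the failing call
```

Any side effects in a decorated function, such as writing a file or updating state, happen before a return-value check can fail.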

Validating Multiple Return Values#

The check_dims decorator can also validate multiple return values to ensure they have the expected dimensions.

@check_dims(result=(second.dim, volt.dim))
def d_function(true_result):
    """
    If true_result is True, return values in seconds and volts; otherwise, return values in volts and seconds.
    """
    if true_result:
        return 5 * second, 3 * volt
    else:
        return 3 * volt, 5 * second
Correct Return Values#

The following call is correct because the return values have the dimensions of time and voltage, respectively.

d_function(True)
(Quantity(5, "s"), Quantity(3, "V"))
Incorrect Return Values#

The following call raises a DimensionMismatchError because the values come back in the order (volts, seconds). The first one has the dimension of voltage where time is expected.

try:
    d_function(False)
except DimensionMismatchError as e:
    print(e)
The return value of function 'd_function' was expected to have dimension s but was '3 V' (unit is m^2 kg s^-3 A^-1).

Validating Dictionary Return Values#

The check_dims decorator can also validate dictionary return values to ensure they have the expected dimensions.

@check_dims(result={'u': second.dim, 'v': (volt.dim, meter.dim)})
def d_function2(true_result):
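    """
    Return a result that matches the expected structure (0), a tuple instead
    of a dictionary (1), or a dictionary whose 'v' entry has the wrong
    dimension in its second element (anything else).
    """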
    if true_result == 0:
        return {'u': 5 * second, 'v': (3 * volt, 2 * meter)}
    elif true_result == 1:
        return 3 * volt, 5 * second
    else:
        return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}
Correct Return Values#

The following call is correct because the return values match the expected dimensions.

d_function2(0)
{'u': Quantity(5, "s"), 'v': (Quantity(3, "V"), Quantity(2, "m"))}
Incorrect Return Values#

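The following calls fail for two different reasons. d_function2(1) returns a tuple instead of a dictionary, so the structure does not match and a TypeError is raised. d_function2(2) has the right structure, but the second element of 'v' is in volts instead of metres, so a DimensionMismatchError is raised.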

try:
    d_function2(1)
except TypeError as e:
    print(e)
try:
    d_function2(2)
except DimensionMismatchError as e:
    print(e)
Expected a return value of type {'u': second, 'v': (metre ** 2 * kilogram * second ** -3 * amp ** -1, metre)} but got (Quantity(3, "V"), Quantity(5, "s"))
The return value of function 'd_function2' was expected to have dimension m but was '2 V' (unit is m^2 kg s^-3 A^-1).
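The two failure modes above come from walking the specification and the returned value together. A container that does not match is a structural error (TypeError). A leaf with the wrong dimension is a dimension error. The recursive walker below is a hypothetical plain-Python sketch of that idea and assumes only dicts, tuples and quantity leaves. It is not brainunit's code. Dimensions are wrapped in a small Dim class so that a tuple in the specification always means "a tuple of results".

```python
from dataclasses import dataclass

# Hypothetical toy model, not brainunit's implementation.


@dataclass(frozen=True)
class Dim:
    exponents: tuple  # (metre, kilogram, second, ampere)


TIME = Dim((0, 0, 1, 0))
VOLTAGE = Dim((2, 1, -3, -1))
LENGTH = Dim((1, 0, 0, 0))


@dataclass(frozen=True)
class ToyQuantity:
    value: float
    dim: Dim


class ToyDimensionError(Exception):
    """Stand-in for brainunit's DimensionMismatchError."""


def check_structure(value, spec):
    """Walk `spec` and `value` together; containers must match exactly."""
    if isinstance(spec, dict):
        if not isinstance(value, dict) or value.keys() != spec.keys():
            raise TypeError(f"expected a dict with keys {sorted(spec)}, got {value!r}")
        for key in spec:
            check_structure(value[key], spec[key])
    elif isinstance(spec, tuple):
        if not isinstance(value, tuple) or len(value) != len(spec):
            raise TypeError(f"expected a tuple of length {len(spec)}, got {value!r}")
        for item, item_spec in zip(value, spec):
            check_structure(item, item_spec)
    elif not isinstance(value, ToyQuantity) or value.dim != spec:
        raise ToyDimensionError(f"expected {spec}, got {value!r}")


spec = {'u': TIME, 'v': (VOLTAGE, LENGTH)}
good = {'u': ToyQuantity(5, TIME), 'v': (ToyQuantity(3, VOLTAGE), ToyQuantity(2, LENGTH))}
bad_structure = (ToyQuantity(3, VOLTAGE), ToyQuantity(5, TIME))
bad_leaf = {'u': ToyQuantity(5, TIME), 'v': (ToyQuantity(3, VOLTAGE), ToyQuantity(2, VOLTAGE))}

check_structure(good, spec)  # passes silently
for bad in (bad_structure, bad_leaf):
    try:
        check_structure(bad, spec)
    except (TypeError, ToyDimensionError) as e:
        print(type(e).__name__, e)
```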

check_units Decorator#

The check_units decorator works like check_dims, but it validates the units of input arguments or return values rather than only their dimensions. Mismatches raise a UnitMismatchError, and the error messages report the expected unit (for example Unit("V")) instead of a dimension.
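A unit carries a scale as well as a dimension, so a unit-level check can be stricter than a dimension-level one. The hypothetical toy model below is for illustration only and is not brainunit's internals. It compares both parts, so 3 mV passes a dimension check against volt but fails a unit check. This fits the correct-call examples below, which pass values in volt rather than mV. The exact rule brainunit applies should be confirmed in its API reference.

```python
from dataclasses import dataclass

# Hypothetical toy model, not brainunit's implementation.
VOLTAGE = (2, 1, -3, -1)  # exponents of (metre, kilogram, second, ampere)


@dataclass(frozen=True)
class ToyUnit:
    scale: float  # size relative to the SI unit
    dim: tuple


@dataclass(frozen=True)
class ToyQuantity:
    value: float
    unit: ToyUnit


VOLT = ToyUnit(1.0, VOLTAGE)
MILLIVOLT = ToyUnit(1e-3, VOLTAGE)


def same_dimension(q, unit):
    # What a dimension check compares.
    return q.unit.dim == unit.dim


def same_unit(q, unit):
    # Stricter: the scale must match as well as the dimension.
    return q.unit == unit


three_mV = ToyQuantity(3, MILLIVOLT)
print(same_dimension(three_mV, VOLT))  # True
print(same_unit(three_mV, VOLT))       # False
```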

We will demonstrate the usage of check_units through several examples.

Basic Usage#

We can use the check_units decorator to validate whether the input arguments of a function have the expected units.

@check_units(v=volt)
def a_function(v, x):
    """
    v must have units of volt, and x can have any (or no) unit.
    """
    pass
Correct Units#

The following calls succeed because v either has units of volt or is a string or None, which are let through without a check:

a_function(3 * volt, 5 * second)
a_function(5 * volt, "something")
a_function([1, 2, 3] * volt, None)
# lists that can be converted should also work
a_function([1 * volt, 2 * volt, 3 * volt], None)
# Strings and None are also allowed to pass
a_function("a string", None)
a_function(None, None)
Incorrect Units#

The following calls raise a UnitMismatchError because v does not have units of volt.

try:
    a_function(5 * second, None)
except UnitMismatchError as e:
    print(e)
    
try:
    a_function(5, None)
except UnitMismatchError as e:
    print(e)
    
try:
    a_function(object(), None)
except UnitMismatchError as e:
    print(e)
Function 'a_function' expected a array with unit Unit("V") for argument 'v' but got '5 s' (unit is s).
Function 'a_function' expected a array with unit Unit("V") for argument 'v' but got '5' (unit is 1).
Function 'a_function' expected a array with unit Unit("V") for argument 'v' but got '<object object at 0x7fc824247390>' (unit is 1).

Validating Return Values#

The check_units decorator can also be used to validate whether the return value of a function has the expected units.

@check_units(result=second)
def b_function(return_second):
    """
    Return a value in seconds if return_second is True, otherwise return
    a value in volt.
    """
    if return_second:
        return 5 * second
    else:
        return 3 * volt
Correct Return Value#

The following call is correct because the return value has units of seconds.

b_function(True)
Quantity(5, "s")
Incorrect Return Value#

The following call will raise a UnitMismatchError because the return value has units of volts instead of seconds.

try:
    b_function(False)
except UnitMismatchError as e:
    print(e)
The return value of function 'b_function' was expected to have unit s but got unit V (value: Quantity(3, "V")) (units are s and V).

Validating Multiple Return Values#

The check_units decorator can also validate multiple return values to ensure they have the expected units.

@check_units(result=(second, volt))
def d_function(true_result):
    """
    If true_result is True, return values in seconds and volts; otherwise,
    return values in volts and seconds.
    """
    if true_result:
        return 5 * second, 3 * volt
    else:
        return 3 * volt, 5 * second
Correct Return Values#

The following call is correct because the return values have units of seconds and volts, respectively.

d_function(True)
(Quantity(5, "s"), Quantity(3, "V"))
Incorrect Return Values#

The following call will raise a UnitMismatchError because the return values are in volts and seconds, which do not match the expected order.

try:
    d_function(False)
except UnitMismatchError as e:
    print(e)
The return value of function 'd_function' was expected to have unit s but got unit V (value: Quantity(3, "V")) (units are s and V).

Validating Dictionary Return Values#

The check_units decorator can also validate dictionary return values to ensure they have the expected units.

@check_units(result={'u': second, 'v': (volt, meter)})
def d_function2(true_result):
    """
    Return a result that matches the expected structure (0), a tuple instead
    of a dictionary (1), or a dictionary whose 'v' entry has the wrong unit
    in its second element (anything else).
    """
    if true_result == 0:
        return {'u': 5 * second, 'v': (3 * volt, 2 * meter)}
    elif true_result == 1:
        return 3 * volt, 5 * second
    else:
        return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}
Correct Return Values#

The following call is correct because the return values match the expected units.

d_function2(0)
{'u': Quantity(5, "s"), 'v': (Quantity(3, "V"), Quantity(2, "m"))}
Incorrect Return Values#

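The following calls fail for two different reasons. d_function2(1) returns a tuple instead of a dictionary, so the structure does not match and a TypeError is raised. d_function2(2) has the right structure, but the second element of 'v' is in volts instead of metres, so a UnitMismatchError is raised.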

try:
    d_function2(1)
except TypeError as e:
    print(e)
try:
    d_function2(2)
except UnitMismatchError as e:
    print(e)
Expected a return value of type {'u': Unit("s"), 'v': (Unit("V"), Unit("m"))} but got (Quantity(3, "V"), Quantity(5, "s"))
The return value of function 'd_function2' was expected to have unit m but got unit V (value: Quantity(2, "V")) (units are m and V).

Assigning Units#

assign_units Decorator#

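The assign_units decorator handles units at a function's boundary. For input arguments, it converts each Quantity to a plain number expressed in the specified unit, so the function body works with unitless values. For return values, it attaches the specified units to the plain numbers the function returns. This keeps unit handling out of the numerical code.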
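The inbound half of this behaviour can be sketched in plain Python. In this hypothetical toy model (not brainunit's implementation), toy_to_decimal plays the role that Quantity.to_decimal plays in the examples below. It rejects non-quantities with a TypeError, and it rejects incompatible dimensions with a ValueError standing in for brainunit's UnitMismatchError. The decorator rebinds the argument before calling the body.

```python
import functools
import inspect
from dataclasses import dataclass

# Hypothetical toy model, not brainunit's implementation.
VOLTAGE = (2, 1, -3, -1)
TIME = (0, 0, 1, 0)


@dataclass(frozen=True)
class ToyUnit:
    scale: float  # size relative to the SI unit
    dim: tuple


@dataclass(frozen=True)
class ToyQuantity:
    value: float
    unit: ToyUnit


VOLT = ToyUnit(1.0, VOLTAGE)
MILLIVOLT = ToyUnit(1e-3, VOLTAGE)
SECOND = ToyUnit(1.0, TIME)


def toy_to_decimal(q, unit):
    """Return the bare magnitude of `q` expressed in `unit`."""
    if not isinstance(q, ToyQuantity):
        raise TypeError(f"expected a quantity, got {q!r}")
    if q.unit.dim != unit.dim:
        # brainunit raises UnitMismatchError in this situation.
        raise ValueError(f"cannot convert {q.unit.dim} to {unit.dim}")
    return q.value * q.unit.scale / unit.scale


def toy_assign_inputs(**units):
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for name, unit in units.items():
                if name in bound.arguments:
                    # The body sees a plain number, already in `unit`.
                    bound.arguments[name] = toy_to_decimal(bound.arguments[name], unit)
            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


@toy_assign_inputs(v=VOLT)
def a_function(v, x):
    return v


print(a_function(ToyQuantity(3, MILLIVOLT), None))  # about 0.003
print(a_function(ToyQuantity(5, VOLT), "something"))  # 5.0
```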

Basic Usage#

We can use the assign_units decorator to convert input arguments to plain numbers in a specified unit before they reach the function body.

@assign_units(v=volt)
def a_function(v, x):
    """
    v arrives as a plain number expressed in volts, and x can have any (or no) unit.
    """
    return v
Correct Units#

The following calls are correct because v is converted to its magnitude in volts. For example, 3 * mV reaches the body as (3 * mV).to_decimal(volt), the plain number 0.003.

assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt)
assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt)
assert a_function(5 * volt, "something") == (5 * volt).to_decimal(volt)
Incorrect Units#

The following calls fail. 5 * second has an incompatible dimension and raises a UnitMismatchError. The plain number 5 and an arbitrary object are not Quantity objects at all, so they raise a TypeError.

try:
    a_function(5 * second, None)
except UnitMismatchError as e:
    print(e)

try:
    a_function(5, None)
except TypeError as e:
    print(e)

try:
    a_function(object(), None)
except TypeError as e:
    print(e)
Cannot convert to the decimal number using a unit with different dimensions. (units are s and V).
Function 'a_function' expected a Quantity object for argument 'v' but got '5'
Function 'a_function' expected a Quantity object for argument 'v' but got '<object object at 0x7fc824247520>'

Assigning Units to Return Values#

The assign_units decorator can also attach units to the plain value a function returns.

@assign_units(result=second)
def b_function():
    """
    The return value will be assigned units of seconds.
    """
    return 5
Correct Return Value#

The following call is correct because the plain return value 5 is turned into 5 * second.

assert b_function() == 5 * second
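The outbound half can be sketched the same way. In this hypothetical toy model (not brainunit's implementation), a single unit wraps a single return value. A tuple of units wraps a tuple of values position by position and raises a TypeError if the lengths differ. The names toy_assign_result, ToyUnit and ToyQuantity are illustrative only.

```python
import functools
from dataclasses import dataclass

# Hypothetical toy model, not brainunit's implementation.


@dataclass(frozen=True)
class ToyUnit:
    name: str


@dataclass(frozen=True)
class ToyQuantity:
    value: float
    unit: ToyUnit


SECOND = ToyUnit("s")
VOLT = ToyUnit("V")


def toy_assign_result(spec):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            if isinstance(spec, tuple):
                # One unit per position; the shapes must agree.
                if not isinstance(result, tuple) or len(result) != len(spec):
                    raise TypeError(f"expected {len(spec)} return values, got {result!r}")
                return tuple(ToyQuantity(r, u) for r, u in zip(result, spec))
            return ToyQuantity(result, spec)

        return wrapper

    return decorator


@toy_assign_result(SECOND)
def b_function():
    return 5


@toy_assign_result((SECOND, VOLT))
def d_function():
    return 5, 3


print(b_function())  # ToyQuantity(value=5, unit=ToyUnit(name='s'))
print(d_function())
```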

Assigning Units to Multiple Return Values#

The assign_units decorator can also assign units to multiple return values.

@assign_units(result=(second, volt))
def d_function():
    """
    The return values will be assigned units of seconds and volts, respectively.
    """
    return 5, 3
Correct Return Values#

The following call is correct because the returned plain values 5 and 3 become 5 * second and 3 * volt.

# call once and unpack both return values
t, v = d_function()
assert t == 5 * second
assert v == 3 * volt

Assigning Units to Dictionary Return Values#

The assign_units decorator can also assign units to dictionary return values.

@assign_units(result={'u': second, 'v': (volt, meter)})
def d_function2(true_result):
    """
    The return values will be assigned units based on the dictionary specification.
    """
    if true_result == 0:
        return {'u': 5, 'v': (3, 2)}
    elif true_result == 1:
        return 3, 5
    else:
        return 3, 5
Correct Return Values#

The following call is correct because each plain value receives the unit at the matching position in the dictionary specification.

d_function2(0)
{'u': Quantity(5, "s"), 'v': (Quantity(3, "V"), Quantity(2, "m"))}
Incorrect Return Values#

The following call raises a TypeError because it returns a tuple, whose structure does not match the dictionary specification.

try:
    d_function2(1)
except TypeError as e:
    print(e)
Expected a return value of pytree PyTreeDef({'u': *, 'v': (*, *)}) with type {'u': Unit("s"), 'v': (Unit("V"), Unit("m"))}, but got the pytree PyTreeDef((*, *)) and the value (3, 5)
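The PyTreeDef in this message suggests that brainunit compares the returned value's tree structure with the specification's before attaching units. That is an inference from the message, not a documented guarantee. A hypothetical plain-Python sketch of that idea rebuilds the value with the specification's shape, wraps every leaf, and raises a TypeError when a container does not match:

```python
from dataclasses import dataclass

# Hypothetical toy model, not brainunit's implementation.


@dataclass(frozen=True)
class ToyUnit:
    name: str


@dataclass(frozen=True)
class ToyQuantity:
    value: float
    unit: ToyUnit


SECOND, VOLT, METRE = ToyUnit("s"), ToyUnit("V"), ToyUnit("m")


def attach_units(value, spec):
    """Rebuild `value` with the same shape as `spec`, wrapping every leaf."""
    if isinstance(spec, dict):
        if not isinstance(value, dict) or value.keys() != spec.keys():
            raise TypeError(f"expected a dict with keys {sorted(spec)}, got {value!r}")
        return {k: attach_units(value[k], spec[k]) for k in spec}
    if isinstance(spec, tuple):
        if not isinstance(value, tuple) or len(value) != len(spec):
            raise TypeError(f"expected a tuple of length {len(spec)}, got {value!r}")
        return tuple(attach_units(v, s) for v, s in zip(value, spec))
    return ToyQuantity(value, spec)  # leaf: attach the unit


spec = {'u': SECOND, 'v': (VOLT, METRE)}
print(attach_units({'u': 5, 'v': (3, 2)}, spec))
try:
    attach_units((3, 5), spec)
except TypeError as e:
    print(e)
```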

The examples above show how the assign_units decorator converts input arguments to plain numbers in a given unit and attaches units to return values. Handling units at the function boundary simplifies scientific code, keeps units consistent, and reduces the likelihood of errors.