LoRA

class brainstate.nn.LoRA(in_features, lora_rank, out_features, *, base_module=None, kernel_init=LecunNormal(scale=1.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', unit=Unit("1")), param_type=ParamState, in_size=None)

Low-Rank Adaptation (LoRA) layer.

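Implements parameter-efficient fine-tuning using a low-rank decomposition [1]: instead of a full in_features-by-out_features weight matrix, the layer learns two smaller matrices, lora_a and lora_b, whose product has rank at most lora_rank. It can be used standalone or as a wrapper around an existing module.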
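A minimal NumPy sketch of the computation, assuming the layer applies the two factors in sequence (x @ lora_a @ lora_b) and, when base_module is given, adds base_module(x) to that result. The function name lora_forward and the plain-array arguments are hypothetical; they mirror the documented lora_a/lora_b matrices rather than brainstate's internal code.

```python
import numpy as np


def lora_forward(x, lora_a, lora_b, base_module=None):
    """Hypothetical sketch: low-rank projection, optionally added to a base module's output."""
    # Assumed shapes: x (batch, in_features), lora_a (in_features, r), lora_b (r, out_features)
    out = x @ lora_a @ lora_b
    if base_module is not None:
        # Wrapped mode: the base output must have the same shape as the LoRA output.
        out = out + base_module(x)
    return out


rng = np.random.default_rng(0)
in_features, lora_rank, out_features = 10, 2, 5
lora_a = rng.normal(size=(in_features, lora_rank))
lora_b = rng.normal(size=(lora_rank, out_features))
x = np.ones((32, in_features))

# Standalone: only the low-rank path.
y = lora_forward(x, lora_a, lora_b)
print(y.shape)  # (32, 5)

# Wrapped: a fixed dense map standing in for base_module (hypothetical).
w_base = rng.normal(size=(in_features, out_features))
y_wrapped = lora_forward(x, lora_a, lora_b, base_module=lambda v: v @ w_base)
print(np.allclose(y_wrapped, y + x @ w_base))  # True

# The effective update lora_a @ lora_b has rank at most lora_rank.
print(np.linalg.matrix_rank(lora_a @ lora_b))
```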

Parameters:
  • in_features (int) – The number of input features.

  • lora_rank (int) – The rank of the low-rank decomposition. A lower rank means fewer trainable parameters: the layer stores in_features * lora_rank + lora_rank * out_features weights.

  • out_features (int) – The number of output features.

  • base_module (Module | None) – A base module to wrap. If provided, the LoRA output will be added to the base module’s output. Default is None.

  • kernel_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Initializer for the LoRA weight matrices. Default is LecunNormal().

  • param_type (type) – Type of parameter state. Default is ParamState.
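The parameter budget follows from the shapes implied by the Examples (10 * 2 + 2 * 5 = 30 at rank 2): a LoRA layer stores lora_rank * (in_features + out_features) weights, versus in_features * out_features for a full weight matrix. The helper names below are hypothetical; the sketch only does the arithmetic and finds the largest rank that is still strictly cheaper than a dense matrix.

```python
def lora_param_count(in_features: int, lora_rank: int, out_features: int) -> int:
    """Weights in lora_a (in_features x rank) plus lora_b (rank x out_features)."""
    return in_features * lora_rank + lora_rank * out_features


def dense_param_count(in_features: int, out_features: int) -> int:
    """Weights in a full in_features x out_features matrix (bias excluded)."""
    return in_features * out_features


def max_cheaper_rank(in_features: int, out_features: int) -> int:
    """Largest rank r with r * (in + out) < in * out, i.e. strictly fewer weights than dense."""
    dense = dense_param_count(in_features, out_features)
    return (dense - 1) // (in_features + out_features)


print(lora_param_count(10, 2, 5))  # 30
print(dense_param_count(10, 5))    # 50
print(max_cheaper_rank(10, 5))     # 3
print(lora_param_count(4096, 8, 4096), dense_param_count(4096, 4096))  # 65536 16777216
```

For small layers the savings vanish quickly (rank 4 already needs 60 weights for a 10-to-5 map), while for large square layers a small rank is a tiny fraction of the dense count.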

in_size

Input feature size.

Type:

int

out_size

Output feature size.

Type:

int

in_features

Number of input features.

Type:

int

out_features

Number of output features.

Type:

int

base_module

The wrapped base module, if provided.

Type:

Module or None

weight

Parameter state containing ‘lora_a’ and ‘lora_b’ matrices.

Type:

ParamState
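When the wrapped base_module is a dense layer computing x @ W + b (an assumption about the base layer), the wrapped output base_module(x) + x @ lora_a @ lora_b equals a single dense layer with weight W + lora_a @ lora_b. The NumPy sketch below checks this identity; merge_lora is a hypothetical helper, not part of the documented brainstate API, and it assumes the x @ lora_a @ lora_b orientation of the two matrices.

```python
import numpy as np


def merge_lora(w_base, lora_a, lora_b):
    """Hypothetical helper: fold the low-rank update into a dense weight matrix."""
    return w_base + lora_a @ lora_b


rng = np.random.default_rng(1)
in_features, lora_rank, out_features = 10, 2, 5
w_base = rng.normal(size=(in_features, out_features))  # frozen base weight (assumed dense)
b_base = rng.normal(size=(out_features,))              # frozen base bias
lora_a = rng.normal(size=(in_features, lora_rank))
lora_b = rng.normal(size=(lora_rank, out_features))
x = rng.normal(size=(32, in_features))

# Wrapped form: base output plus the low-rank path.
y_wrapped = (x @ w_base + b_base) + x @ lora_a @ lora_b
# Merged form: one dense layer with the combined weight; the bias is unchanged.
y_merged = x @ merge_lora(w_base, lora_a, lora_b) + b_base

print(np.allclose(y_wrapped, y_merged))  # True
```

Merging trades the ability to detach the adapter for a single matrix multiply at inference time.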

References

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Standalone LoRA layer
>>> layer = brainstate.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 5)
>>>
>>> # Wrap around existing linear layer
>>> base = brainstate.nn.Linear((10,), (5,))
>>> lora_layer = brainstate.nn.LoRA(in_features=10, lora_rank=2,
...                                 out_features=5, base_module=base)
>>> y = lora_layer(x)
>>> y.shape
(32, 5)
>>>
>>> # Compare parameter counts: at rank 2, LoRA has fewer parameters
>>> # Base layer weight matrix: 10 * 5 = 50 parameters
>>> # LoRA: 10 * 2 + 2 * 5 = 30 parameters
>>> sum(v.size for v in layer.weight.value.values())
30
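In the wrapped setup, fine-tuning is meant to update only lora_a and lora_b while the base weights stay fixed. The NumPy sketch below illustrates that idea with plain gradient descent and hand-derived gradients of a mean-squared-error loss; the synthetic data, learning rate, and step count are arbitrary illustration values, and the sketch does not use brainstate's own training utilities.

```python
import numpy as np

rng = np.random.default_rng(0)
n, in_features, lora_rank, out_features = 32, 10, 2, 5

x = rng.normal(size=(n, in_features))
w_base = rng.normal(size=(in_features, out_features))  # frozen base weight
w_base_before = w_base.copy()
# Synthetic target reachable with a rank-2 update of the base weight.
delta = 0.3 * rng.normal(size=(in_features, lora_rank)) @ rng.normal(size=(lora_rank, out_features))
target = x @ (w_base + delta)

lora_a = rng.normal(size=(in_features, lora_rank)) / np.sqrt(in_features)
lora_b = rng.normal(size=(lora_rank, out_features)) / np.sqrt(lora_rank)


def loss_and_grads(lora_a, lora_b):
    h = x @ lora_a                     # (n, rank) low-rank activations
    out = x @ w_base + h @ lora_b      # wrapped output; w_base receives no gradient
    err = out - target
    loss = np.mean(err ** 2)
    g_out = 2.0 * err / err.size       # dLoss/dOut for the mean-squared error
    grad_b = h.T @ g_out               # (rank, out_features)
    grad_a = x.T @ (g_out @ lora_b.T)  # (in_features, rank)
    return loss, grad_a, grad_b


initial_loss, _, _ = loss_and_grads(lora_a, lora_b)
lr = 0.02
for _ in range(300):
    _, grad_a, grad_b = loss_and_grads(lora_a, lora_b)
    lora_a -= lr * grad_a  # only the LoRA factors are updated
    lora_b -= lr * grad_b
final_loss, _, _ = loss_and_grads(lora_a, lora_b)

print(final_loss < initial_loss)               # True
print(np.array_equal(w_base, w_base_before))   # True: base weights untouched
```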