LoRA
- class brainstate.nn.LoRA(in_features, lora_rank, out_features, *, base_module=None, kernel_init=LecunNormal(scale=1.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), param_type=<class 'brainstate.ParamState'>, in_size=None)
Low-Rank Adaptation (LoRA) layer.
Implements parameter-efficient fine-tuning using low-rank decomposition [1]. Can be used standalone or as a wrapper around an existing module.
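As a rough illustration of the decomposition, the forward pass can be sketched in plain JAX. The sketch assumes `lora_a` has shape `(in_features, lora_rank)`, `lora_b` has shape `(lora_rank, out_features)`, and that the update is computed as `x @ lora_a @ lora_b`. These layouts and the helper names `lora_forward` and `base_fn` are assumptions for illustration, not brainstate's internals.

```python
import jax
import jax.numpy as jnp


def lora_forward(x, lora_a, lora_b, base_fn=None):
    """Hypothetical LoRA forward pass (a sketch, not brainstate's implementation).

    x:      (batch, in_features)
    lora_a: (in_features, lora_rank)   -- assumed layout
    lora_b: (lora_rank, out_features)  -- assumed layout
    """
    # Project down to the low-rank space, then back up to out_features.
    out = (x @ lora_a) @ lora_b
    # When wrapping a base module, add the base output to the low-rank update.
    if base_fn is not None:
        out = out + base_fn(x)
    return out


key_a, key_b, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
in_features, lora_rank, out_features = 10, 2, 5
lora_a = jax.random.normal(key_a, (in_features, lora_rank))
lora_b = jax.random.normal(key_b, (lora_rank, out_features))
w_base = jax.random.normal(key_w, (in_features, out_features))  # stand-in base weight

x = jnp.ones((32, in_features))
y_standalone = lora_forward(x, lora_a, lora_b)
y_wrapped = lora_forward(x, lora_a, lora_b, base_fn=lambda v: v @ w_base)
print(y_standalone.shape, y_wrapped.shape)  # (32, 5) (32, 5)
```

Routing the input through the `lora_rank`-dimensional bottleneck keeps the update cheap for small ranks: two thin matrix products replace one full `in_features x out_features` product.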
- Parameters:
  in_features (int) – The number of input features.
  lora_rank (int) – The rank of the low-rank decomposition. A lower rank means fewer parameters: the two LoRA matrices hold in_features * lora_rank + lora_rank * out_features entries in total.
  out_features (int) – The number of output features.
  base_module (Module | None) – A base module to wrap. If provided, the LoRA output is added to the base module’s output. Default is None.
  kernel_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Initializer for the LoRA weight matrices. Default is LecunNormal().
  param_type (type) – Type of parameter state. Default is ParamState.
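The parameter budget follows directly from the factor sizes. A minimal sketch, assuming the two LoRA matrices hold `in_features * lora_rank` and `lora_rank * out_features` entries (consistent with the count in the examples) and comparing against a dense weight matrix without bias; all function names here are hypothetical.

```python
def lora_param_count(in_features: int, lora_rank: int, out_features: int) -> int:
    # Entries in lora_a (in_features x lora_rank) plus lora_b (lora_rank x out_features).
    return in_features * lora_rank + lora_rank * out_features


def dense_param_count(in_features: int, out_features: int) -> int:
    # Entries in a full in_features x out_features weight matrix (bias not counted).
    return in_features * out_features


def max_saving_rank(in_features: int, out_features: int) -> int:
    # Largest rank r with r * (in + out) < in * out, i.e. LoRA is strictly smaller.
    return (in_features * out_features - 1) // (in_features + out_features)


for r in range(1, 5):
    print(r, lora_param_count(10, r, 5), dense_param_count(10, 5))
print("largest rank that saves parameters:", max_saving_rank(10, 5))
```

For the 10 x 5 case only ranks 1 to 3 are smaller than the dense matrix (rank 4 needs 60 entries versus 50). The saving grows with layer size: for a 4096 x 4096 projection, rank 8 needs 65,536 entries versus 16,777,216.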
- weight
Parameter state containing ‘lora_a’ and ‘lora_b’ matrices.
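Because the trainable state is just the pair `lora_a`/`lora_b`, a fine-tuning step only needs gradients with respect to those two matrices while the base weights stay frozen. The sketch below shows this in plain JAX with `jax.value_and_grad` over a dict keyed like the `weight` state. The frozen matrix `w_base`, the synthetic target, the squared-error loss, and the plain SGD update are illustrative assumptions, not part of the brainstate API.

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
in_features, lora_rank, out_features = 10, 2, 5

w_base = jax.random.normal(k1, (in_features, out_features))  # frozen base weight (assumed)
lora = {
    "lora_a": 0.1 * jax.random.normal(k2, (in_features, lora_rank)),
    "lora_b": 0.1 * jax.random.normal(k3, (lora_rank, out_features)),
}
x = jax.random.normal(k4, (32, in_features))
y_target = x @ (w_base + 0.5)  # hypothetical target: base weight shifted by a rank-1 matrix


def loss_fn(lora_params, w_frozen, x, y):
    # Base output plus low-rank update; only lora_params (argnum 0) is differentiated.
    pred = x @ w_frozen + (x @ lora_params["lora_a"]) @ lora_params["lora_b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(lora_params, w_frozen, x, y, lr=0.05):
    loss, grads = jax.value_and_grad(loss_fn)(lora_params, w_frozen, x, y)
    # Update only the two LoRA factors; w_frozen gets no gradient and is not returned.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, lora_params, grads)
    return new_params, loss


losses = []
for _ in range(200):
    lora, loss = sgd_step(lora, w_base, x, y_target)
    losses.append(float(loss))
print(losses[0], losses[-1])
```

Only `2 * 10 + 2 * 5 = 30` numbers are updated per step here, while the 50 entries of `w_base` are shared unchanged across every step.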
- Type:
References
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Standalone LoRA layer
>>> layer = brainstate.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 5)
>>>
>>> # Wrap around existing linear layer
>>> base = brainstate.nn.Linear((10,), (5,))
>>> lora_layer = brainstate.nn.LoRA(in_features=10, lora_rank=2,
...                                 out_features=5, base_module=base)
>>> y = lora_layer(x)
>>> y.shape
(32, 5)
>>>
>>> # Check parameter count - LoRA has fewer parameters
>>> # Base layer weight: 10 * 5 = 50 parameters (bias not counted)
>>> # LoRA: 10 * 2 + 2 * 5 = 30 parameters
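When the wrapped base module is a linear map with weight `W` and bias `b`, as in the `Linear` example above, the combined output `x @ W + b + x @ A @ B` equals a single linear map with weight `W + A @ B`. The plain-JAX check below illustrates this, assuming `lora_a` has shape `(in_features, lora_rank)` and `lora_b` has shape `(lora_rank, out_features)`. Folding the update into `W` is a general property of LoRA on linear layers, not a documented brainstate feature; all arrays here are random stand-ins.

```python
import jax
import jax.numpy as jnp

k_w, k_b, k_a, k_lb, k_x = jax.random.split(jax.random.PRNGKey(7), 5)
in_features, lora_rank, out_features = 10, 2, 5

w = jax.random.normal(k_w, (in_features, out_features))  # base linear weight (assumed)
b = jax.random.normal(k_b, (out_features,))              # base linear bias (assumed)
lora_a = jax.random.normal(k_a, (in_features, lora_rank))
lora_b = jax.random.normal(k_lb, (lora_rank, out_features))
x = jax.random.normal(k_x, (32, in_features))

# Wrapped form: base output plus the low-rank update.
y_wrapped = (x @ w + b) + (x @ lora_a) @ lora_b

# Merged form: fold the update into one dense weight with the same shape as w.
w_merged = w + lora_a @ lora_b
y_merged = x @ w_merged + b
print(bool(jnp.allclose(y_wrapped, y_merged, atol=1e-4)))

# The update lora_a @ lora_b has rank at most lora_rank; count the
# singular values that are not numerically negligible.
s = jnp.linalg.svd(lora_a @ lora_b, compute_uv=False)
print(int(jnp.sum(s > 1e-4 * s[0])))
```

Because `A @ B` has at most `lora_rank` non-negligible singular values, the adapter can only shift the base weight within a `lora_rank`-dimensional family of directions, which is the constraint that keeps the parameter count low.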