LoRA

class braintrace.nn.LoRA(in_features, lora_rank, out_features, *, base_module=None, kernel_init=LecunNormal(scale=1.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([4928 4982]), unit=Unit("1")), param_type=<class 'brainstate.ParamState'>, in_size=None)

A standalone LoRA layer.

LoRA (Low-Rank Adaptation) adds two trainable low-rank factors to a layer so that a large pre-trained model can be fine-tuned while training far fewer parameters. This subclass keeps the upstream brainstate.nn.LoRA constructor unchanged and overrides only the forward pass, routing the low-rank matrix multiplication through braintrace.lora_matmul() so that it participates in eligibility-trace computation.
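For a sense of scale, the number of trainable entries follows directly from the factor shapes: a full update of an in_features × out_features weight has in_features · out_features entries, while lora_a and lora_b together have lora_rank · (in_features + out_features). The plain-Python sketch below uses illustrative sizes; the helper names are hypothetical and not part of braintrace.

```python
def dense_param_count(in_features: int, out_features: int) -> int:
    """Entries in a full (in_features x out_features) weight update (hypothetical helper)."""
    return in_features * out_features


def lora_param_count(in_features: int, out_features: int, lora_rank: int) -> int:
    """Entries in lora_a (in x rank) plus lora_b (rank x out) (hypothetical helper)."""
    return in_features * lora_rank + lora_rank * out_features


# Illustrative sizes only, not taken from any particular model.
full = dense_param_count(1024, 1024)      # 1024 * 1024
lora = lora_param_count(1024, 1024, 8)    # 8 * (1024 + 1024)
print(full, lora, full // lora)           # 1048576 16384 64
```

For a square layer of width n the ratio is n / (2 · lora_rank), so the saving grows with layer width.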

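The layer adds a low-rank component to the base transformation. Inputs are treated as row vectors, so with the base weight \(W_{\text{orig}}\) of shape (in_features, out_features), \(A\) = lora_a of shape (in_features, \(r\)), \(B\) = lora_b of shape (\(r\), out_features), and \(r\) = lora_rank, the effective weight is

\[W_{\mathrm{LoRA}} = W_{\text{orig}} + \frac{1}{r} A B\]

The original LoRA paper uses column vectors and writes the update as \(BA\); there the factors are the transposes of lora_b and lora_a. The forward pass therefore computes base_module(x) + \(\frac{1}{r} x A B\), where the first term is omitted when base_module is None.

The scaling factor is fixed at 1/lora_rank; the constructor exposes no separate α hyperparameter.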
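A minimal NumPy sketch of this computation, assuming the row-vector convention implied by the factor shapes and treating the base layer as a plain matrix product. It is not braintrace's implementation (which routes the product through braintrace.lora_matmul()); it only checks that adding the scaled branch lora_a @ lora_b / lora_rank to the base output equals applying the merged weight. The names W_orig and lora_forward are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, lora_rank, out_features = 3, 2, 4

# Hypothetical stand-ins for the base weight and the two LoRA factors.
W_orig = rng.normal(size=(in_features, out_features))  # base weight, (in, out)
lora_a = rng.normal(size=(in_features, lora_rank))     # A, (in, r)
lora_b = rng.normal(size=(lora_rank, out_features))    # B, (r, out)


def lora_forward(x, base=None):
    """LoRA branch x @ A @ B / r, plus base(x) when a base layer is given."""
    y = (x @ lora_a @ lora_b) / lora_rank
    if base is not None:
        y = y + base(x)
    return y


x = rng.normal(size=(16, in_features))
y = lora_forward(x, base=lambda v: v @ W_orig)

# Same result via the merged weight W_orig + A @ B / r.
W_lora = W_orig + (lora_a @ lora_b) / lora_rank
print(y.shape, np.allclose(y, x @ W_lora))  # (16, 4) True
```

Keeping the two factors separate, rather than materializing the merged weight, is what keeps the trainable state at lora_rank · (in_features + out_features) entries.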

Parameters:
  • in_features (int) – Number of input features.

  • lora_rank (int) – Rank of the LoRA decomposition.

  • out_features (int) – Number of output features.

  • base_module (Module | None) – Optional base layer; its output on x is added to the LoRA branch. Default None.

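  • kernel_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Initializer applied to both lora_a (in×rank) and lora_b (rank×out). Default is LecunNormal(). Because the same initializer is used for both factors, init.ZeroInit() zeroes lora_a and lora_b together; the gradient of each factor is linear in the other, so both gradients stay zero and the LoRA branch cannot learn. The classic LoRA initialisation (random A, zero B) cannot be expressed through kernel_init alone; it needs lora_b to be zeroed separately.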

  • param_type (type) – ParamState subclass used to wrap the weights. Default is brainstate.ParamState.

  • in_size (int | Sequence[int] | integer | Sequence[integer]) – Optional explicit input size override. Default None.
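Because a single kernel_init initializes both factors, it is worth checking what an all-zero initialization does to training. The NumPy sketch below uses hand-derived gradients of a loss with respect to the factors of y = x @ lora_a @ lora_b / r, given the upstream gradient dL/dy. It is an illustration of the underlying math, not braintrace's eligibility-trace machinery, and the helper lora_grads is hypothetical.

```python
import numpy as np


def lora_grads(x, lora_a, lora_b, grad_y):
    """Gradients w.r.t. lora_a and lora_b for y = x @ A @ B / r,
    given grad_y = dL/dy of shape (batch, out)."""
    r = lora_a.shape[1]
    grad_a = x.T @ grad_y @ lora_b.T / r   # (in, r): linear in B
    grad_b = (x @ lora_a).T @ grad_y / r   # (r, out): linear in A
    return grad_a, grad_b


rng = np.random.default_rng(0)
in_f, r, out_f = 3, 2, 4
x = rng.normal(size=(16, in_f))
grad_y = rng.normal(size=(16, out_f))

# Same zero initializer for both factors: both gradients vanish.
ga, gb = lora_grads(x, np.zeros((in_f, r)), np.zeros((r, out_f)), grad_y)
print(np.abs(ga).max(), np.abs(gb).max())  # 0.0 0.0

# Classic LoRA init: random A, zero B. The branch starts with zero output,
# A's gradient is zero at this step, but B receives a nonzero gradient.
ga, gb = lora_grads(x, rng.normal(size=(in_f, r)), np.zeros((r, out_f)), grad_y)
print(np.abs(ga).max(), np.abs(gb).max() > 0)  # 0.0 True
```

Once B has moved away from zero, A's gradient becomes nonzero as well, which is why zeroing only one factor is enough to start from the base model's output without stalling training.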

in_features

Number of input features.

Type:

int

out_features

Number of output features.

Type:

int

base_module

The optional base layer whose output is added to the LoRA branch.

Type:

brainstate.nn.Module or None

weight

ParamState whose value is a dict with two keys: 'lora_a' of shape (in_features, lora_rank) and 'lora_b' of shape (lora_rank, out_features).

Type:

ParamState
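These shapes bound the rank of the update: the product lora_a @ lora_b has rank at most lora_rank, however large in_features and out_features are. A small NumPy check, using a plain dict laid out like weight.value with random illustrative values (not produced by braintrace):

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, lora_rank, out_features = 8, 2, 6

# Hypothetical dict mirroring the documented layout of weight.value.
weight = {
    "lora_a": rng.normal(size=(in_features, lora_rank)),
    "lora_b": rng.normal(size=(lora_rank, out_features)),
}

# Scaled low-rank update, same shape as a dense (in, out) weight.
delta = weight["lora_a"] @ weight["lora_b"] / lora_rank
print(delta.shape, np.linalg.matrix_rank(delta))  # (8, 6) 2
```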

Examples

>>> import brainstate
>>> import braintrace
>>>
>>> # Create a standalone LoRA layer
>>> brainstate.environ.set(precision=64)
>>> layer = braintrace.nn.LoRA(in_features=3, lora_rank=2, out_features=4)
>>> x = brainstate.random.randn(16, 3)
>>> y = layer(x)
>>> print(y.shape)
(16, 4)
>>>
>>> # Wrap around an existing linear layer
>>> linear = brainstate.nn.Linear(3, 4)
>>> wrapper = braintrace.nn.LoRA(3, 2, 4, base_module=linear)
>>> assert wrapper.base_module is linear
>>> y = wrapper(x)
>>> print(y.shape)
(16, 4)