LoRA
- class braintrace.nn.LoRA(in_features, lora_rank, out_features, *, base_module=None, kernel_init=LecunNormal(scale=1.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([4928 4982]), unit=Unit("1")), param_type=<class 'brainstate.ParamState'>, in_size=None)
A standalone LoRA layer.
LoRA (Low-Rank Adaptation) injects two low-rank factors into a layer so that a large pre-trained model can be fine-tuned with far fewer trainable parameters. This subclass keeps the upstream brainstate.nn.LoRA constructor and replaces only the forward pass, so that the multiplication is routed through braintrace.lora_matmul() and therefore participates in eligibility-trace computation.
The layer adds a low-rank component \(\frac{1}{r} B A\) to the base weight, where \(B\) and \(A\) are learnable factors of rank \(r\):
\[W_{\mathrm{LoRA}} = W_{\text{orig}} + \frac{1}{r} B A\]
The scaling factor is fixed to 1 / lora_rank, i.e. \(r\) = lora_rank.
- Parameters:
  - in_features (int) – Number of input features.
  - lora_rank (int) – Rank of the LoRA decomposition.
  - out_features (int) – Number of output features.
  - base_module (Module | None) – Optional base layer that is called on x; its output is added to the LoRA branch. Default None.
  - kernel_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Initializer used for both lora_a (in × rank) and lora_b (rank × out). Default is LecunNormal(). Passing init.ZeroInit() gives a "LoRA-zero" start in which the LoRA branch initially contributes nothing; because it zeroes both factors, however, both then receive zero gradient and gradient-based training cannot move them (the original LoRA recipe zeroes only the output-side factor).
  - param_type (type) – ParamState subclass used to wrap the weights. Default is brainstate.ParamState.
  - in_size (int | Sequence[int] | integer | Sequence[integer]) – Optional explicit input size override. Default None.
- base_module
The optional base layer added to the LoRA branch.
- Type:
brainstate.nn.Module or None
- weight
ParamState whose value is a dict with two keys: 'lora_a' of shape (in_features, lora_rank) and 'lora_b' of shape (lora_rank, out_features).
- Type:
ParamState
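The following is a minimal pure-Python sketch of how the formula relates to the two weight entries. It assumes (this is not braintrace's source) that the forward pass computes base_module(x) + (x @ lora_a @ lora_b) / lora_rank for row-vector inputs, and it models base_module as a bias-free linear map with a hypothetical weight w_orig; the real layer routes the product through braintrace.lora_matmul(), which is not reproduced here. The helpers matmul and lora_forward are hypothetical. With these row-major shapes, lora_a @ lora_b / r is the (in_features × out_features) counterpart of \(\frac{1}{r} B A\), taking \(A\) = lora_aᵀ and \(B\) = lora_bᵀ, so folding it into a linear base weight yields the same output.

```python
# Hypothetical pure-Python sketch, not braintrace code. It assumes the
# wrapper computes base_module(x) + (x @ lora_a @ lora_b) / lora_rank and
# models base_module as a bias-free linear map x @ w_orig.

def matmul(p, q):
    """Matrix product of p (n x k) and q (k x m), both given as lists of rows."""
    return [[sum(p[i][t] * q[t][j] for t in range(len(q))) for j in range(len(q[0]))]
            for i in range(len(p))]


def lora_forward(x, weight, base=None):
    """y = base(x) + (x @ lora_a @ lora_b) / rank for a batch x (list of rows)."""
    a, b = weight["lora_a"], weight["lora_b"]
    rank = len(b)  # lora_b has shape (lora_rank, out_features)
    y = [[v / rank for v in row] for row in matmul(matmul(x, a), b)]
    if base is not None:
        y = [[u + v for u, v in zip(ry, rb)] for ry, rb in zip(y, base(x))]
    return y


# in_features = 3, lora_rank = 2, out_features = 4, batch = 2
x = [[1.0, 2.0, 3.0], [0.0, -1.0, 2.0]]
weight = {
    "lora_a": [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],           # (3, 2)
    "lora_b": [[1.0, 2.0, 0.0, -1.0], [0.0, 1.0, 1.0, 1.0]],  # (2, 4)
}
w_orig = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]  # (3, 4)

y = lora_forward(x, weight, base=lambda inp: matmul(inp, w_orig))

# Merge the factors into the base weight: W = W_orig + (lora_a @ lora_b) / r.
r = len(weight["lora_b"])
delta = matmul(weight["lora_a"], weight["lora_b"])
w_merged = [[w + d / r for w, d in zip(rw, rd)] for rw, rd in zip(w_orig, delta)]

print(y)                         # [[3.0, 8.5, 5.5, 0.5], [1.0, 1.5, 2.5, -0.5]]
print(y == matmul(x, w_merged))  # True
```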
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> # Create a standalone LoRA layer
>>> brainstate.environ.set(precision=64)
>>> layer = braintrace.nn.LoRA(in_features=3, lora_rank=2, out_features=4)
>>> x = brainstate.random.randn(16, 3)
>>> y = layer(x)
>>> print(y.shape)
(16, 4)
>>>
>>> # Wrap around an existing linear layer
>>> linear = brainstate.nn.Linear(3, 4)
>>> wrapper = braintrace.nn.LoRA(3, 2, 4, base_module=linear)
>>> assert wrapper.base_module is linear
>>> y = wrapper(x)
>>> print(y.shape)
(16, 4)
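The parameter saving mentioned in the description comes from the rank: the two factors in weight hold lora_rank × (in_features + out_features) values, versus in_features × out_features for a dense weight. The hypothetical helper param_counts below does this arithmetic (bias terms ignored). For the toy sizes used in the examples the factors are actually slightly larger than a dense 3 × 4 matrix; the saving appears only when lora_rank is much smaller than both feature dimensions, as in the illustrative 4096 × 4096, rank-8 case.

```python
def param_counts(in_features, lora_rank, out_features):
    """Return (dense, lora) parameter counts for one weight matrix (hypothetical helper).

    dense: a full in_features x out_features matrix (bias ignored).
    lora:  lora_a (in_features x lora_rank) plus lora_b (lora_rank x out_features).
    """
    dense = in_features * out_features
    lora = in_features * lora_rank + lora_rank * out_features
    return dense, lora


print(param_counts(3, 2, 4))        # (12, 14): sizes from the examples
print(param_counts(4096, 8, 4096))  # (16777216, 65536): 256 times fewer
```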