GRUCell
- class braintrace.nn.GRUCell(in_size, out_size, w_init=Orthogonal(scale=1.0), b_init=ZeroInit(unit=1), state_init=ZeroInit(unit=1), activation='tanh', name=None)
Gated Recurrent Unit (GRU) cell, as described in Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation (Cho et al., 2014).
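The description above names the gating scheme but not its equations. As a rough guide, the NumPy sketch below implements one update step of the Cho et al. (2014) formulation with bias terms added: a reset gate r, an update gate z, a candidate state h̃ = tanh(W x + U(r ⊙ h) + b), and the new state h' = z ⊙ h + (1 − z) ⊙ h̃. It is not braintrace's implementation. The parameter names (`W_r`, `U_r`, ...), the helpers `gru_step` and `init_params`, and the assumption that `activation` plays the role of the tanh in the candidate state are all illustrative.

```python
import numpy as np


def sigmoid(x):
    """Logistic sigmoid, applied elementwise."""
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x, h, params):
    """One GRU update in the Cho et al. (2014) style, with biases.

    x: (batch, in_size) input; h: (batch, out_size) previous hidden state.
    params: dict with hypothetical weight names, not braintrace internals.
    """
    W_r, U_r, b_r = params["W_r"], params["U_r"], params["b_r"]
    W_z, U_z, b_z = params["W_z"], params["U_z"], params["b_z"]
    W_h, U_h, b_h = params["W_h"], params["U_h"], params["b_h"]
    r = sigmoid(x @ W_r + h @ U_r + b_r)  # reset gate
    z = sigmoid(x @ W_z + h @ U_z + b_z)  # update gate
    # Candidate state; tanh stands in for what we assume `activation` controls.
    h_tilde = np.tanh(x @ W_h + (r * h) @ U_h + b_h)
    return z * h + (1.0 - z) * h_tilde  # blend old state and candidate


def init_params(in_size, out_size, rng):
    """Small random weights and zero biases, for illustration only."""
    params = {}
    for g in ("r", "z", "h"):
        params[f"W_{g}"] = rng.normal(0.0, 0.1, (in_size, out_size))
        params[f"U_{g}"] = rng.normal(0.0, 0.1, (out_size, out_size))
        params[f"b_{g}"] = np.zeros(out_size)
    return params


rng = np.random.default_rng(0)
params = init_params(128, 256, rng)
x = rng.normal(size=(16, 128))
h0 = np.zeros((16, 256))  # zero initial state, analogous to state_init=ZeroInit()
h1 = gru_step(x, h0, params)
print(h1.shape)  # (16, 256)
```

Because h' is a convex combination of the previous state and a tanh output, the hidden state stays inside (−1, 1) when it starts at zero.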
- Parameters:
  - in_size (int | Sequence[int] | integer | Sequence[integer]) – The number of input units.
  - out_size (int | Sequence[int] | integer | Sequence[integer]) – The number of hidden units.
  - w_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – The input weight initializer. Default is Orthogonal(scale=1.0).
  - b_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – The bias initializer. Default is ZeroInit().
  - state_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – The hidden-state initializer. Default is ZeroInit().
  - activation (str | Callable) – The activation function, given either as a string name or as a callable. Default is 'tanh'.
  - name (str) – The name of the module. Default is None.
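The default `w_init=Orthogonal(scale=1.0)` draws weights whose rows or columns are orthonormal, a choice often made for recurrent weights because a square orthogonal matrix preserves vector norms. The NumPy sketch below shows the usual QR-based construction. The function names `orthogonal` and `zeros`, the assumption that `scale` simply multiplies the result, and the weight shapes in the usage lines are hypothetical and not taken from braintrace.

```python
import numpy as np


def orthogonal(shape, scale=1.0, rng=None):
    """Hypothetical NumPy stand-in for an orthogonal initializer.

    Draws a Gaussian matrix, takes its QR decomposition, and fixes the
    signs so the result is uniformly distributed over orthogonal matrices.
    """
    rng = np.random.default_rng() if rng is None else rng
    rows, cols = shape
    a = rng.normal(size=(max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)         # q has orthonormal columns
    q = q * np.sign(np.diag(r))    # sign correction per column
    if rows < cols:
        q = q.T                    # wide matrix: orthonormal rows instead
    return scale * q


def zeros(shape):
    """Stand-in for ZeroInit: an all-zero array of the given shape."""
    return np.zeros(shape)


rng = np.random.default_rng(0)
W = orthogonal((128, 256), rng=rng)  # input-to-hidden (hypothetical shape)
U = orthogonal((256, 256), rng=rng)  # hidden-to-hidden (hypothetical shape)
b = zeros(256)
print(W.shape, np.allclose(W @ W.T, np.eye(128)))  # (128, 256) True
print(np.allclose(U.T @ U, np.eye(256)))           # True
```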
Examples
>>> import braintrace
>>> import brainstate
>>>
>>> # Create a GRU cell
>>> gru_cell = braintrace.nn.GRUCell(in_size=128, out_size=256)
>>> gru_cell.init_state(batch_size=16)
>>>
>>> # Process one time step for a batch of 16 inputs
>>> x = brainstate.random.randn(16, 128)
>>> h = gru_cell(x)
>>> print(h.shape)
(16, 256)
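The example above feeds a single time step. A recurrent cell is normally driven one step at a time across a sequence, carrying the hidden state forward. The NumPy sketch below shows that unrolling for an input of shape (T, batch, in_size). It assumes the same Cho et al. (2014) gating, with the reset, update, and candidate weights stored side by side in one matrix; this fused layout is a common convention that may differ from braintrace's internals, and `gru_sequence` and the weight names are hypothetical.

```python
import numpy as np


def sigmoid(x):
    """Logistic sigmoid, applied elementwise."""
    return 1.0 / (1.0 + np.exp(-x))


def gru_sequence(xs, h0, Wx, Wh, b):
    """Run a GRU over a sequence and return the hidden state at every step.

    xs: (T, batch, in_size); h0: (batch, out_size).
    Wx: (in_size, 3*out_size), Wh: (out_size, 3*out_size), b: (3*out_size,)
    hold the reset, update, and candidate weights side by side
    (a hypothetical fused layout).
    """
    n = h0.shape[-1]
    h = h0
    hs = []
    for x in xs:                     # iterate over the time axis
        gx = x @ Wx + b              # input contribution for all three gates
        r = sigmoid(gx[:, :n] + h @ Wh[:, :n])              # reset gate
        z = sigmoid(gx[:, n:2 * n] + h @ Wh[:, n:2 * n])    # update gate
        h_tilde = np.tanh(gx[:, 2 * n:] + (r * h) @ Wh[:, 2 * n:])
        h = z * h + (1.0 - z) * h_tilde                     # carry state forward
        hs.append(h)
    return np.stack(hs)              # (T, batch, out_size)


rng = np.random.default_rng(0)
T, batch, in_size, out_size = 10, 16, 128, 256
xs = rng.normal(size=(T, batch, in_size))
Wx = rng.normal(0.0, 0.05, (in_size, 3 * out_size))
Wh = rng.normal(0.0, 0.05, (out_size, 3 * out_size))
b = np.zeros(3 * out_size)
hs = gru_sequence(xs, np.zeros((batch, out_size)), Wx, Wh, b)
print(hs.shape)  # (10, 16, 256)
```

Fusing the gates lets the input projection for all three gates run as a single matrix product per step.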