Recurrent Neural Networks

Recurrent Neural Networks (RNNs) are designed to process sequential data by maintaining hidden states that capture temporal dependencies.

In this tutorial, you will learn:

  • 🔄 RNN Basics - Understanding recurrence and hidden states

  • 🧠 LSTM - Long Short-Term Memory for long-term dependencies

  • GRU - Gated Recurrent Unit as an efficient alternative to LSTM

  • 📊 Sequence Processing - Handling variable-length sequences

  • 💡 Practical Examples - Sentiment analysis, time series

Why RNNs?

RNNs excel at:

  • 📝 Natural language processing

  • 🎵 Speech and audio

  • 📈 Time series forecasting

  • 🎬 Video analysis

import brainstate
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

1. RNN Basics

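An RNN processes a sequence one element at a time. At each step it combines the current input x_t with the previous hidden state h_{t-1} to produce the new hidden state:

h_t = tanh(W_xh @ x_t + W_hh @ h_{t-1} + b)

This is the textbook form. As the printout in the next section shows, ValinaRNNCell uses relu as its default activation and applies a single Linear layer to the concatenation of x_t and h_{t-1}, which is equivalent to using separate W_xh and W_hh matrices.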
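The recurrence above can be sketched in plain JAX, independently of brainstate. This is a minimal illustration, not the library's implementation: the function names `rnn_step` and `run_rnn`, the weight scales, and the sizes are all assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp

def rnn_step(params, h_prev, x_t):
    """One step of h_t = tanh(W_xh @ x_t + W_hh @ h_{t-1} + b)."""
    W_xh, W_hh, b = params
    return jnp.tanh(W_xh @ x_t + W_hh @ h_prev + b)

def run_rnn(params, xs, h0):
    """Apply rnn_step along the leading (time) axis of xs."""
    def body(h, x_t):
        h_new = rnn_step(params, h, x_t)
        return h_new, h_new  # (carry for the next step, per-step output)
    _, hs = jax.lax.scan(body, h0, xs)
    return hs

# Illustrative sizes and random weights (hypothetical, not library defaults)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
num_in, num_hidden, num_steps = 10, 20, 7
params = (
    0.1 * jax.random.normal(k1, (num_hidden, num_in)),      # W_xh
    0.1 * jax.random.normal(k2, (num_hidden, num_hidden)),  # W_hh
    jnp.zeros(num_hidden),                                  # b
)
xs = jax.random.normal(k3, (num_steps, num_in))
hs = run_rnn(params, xs, jnp.zeros(num_hidden))
print(hs.shape)  # (7, 20)
```

The hidden state is the only thing carried from one step to the next, which is why `jax.lax.scan` with the hidden state as the carry is a natural fit.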

Simple RNN Cell

# RNNCell: Basic recurrent unit
brainstate.random.seed(42)
rnn_cell = brainstate.nn.ValinaRNNCell(num_in=10, num_out=20)
brainstate.nn.init_all_states(rnn_cell)

print("RNNCell:")
print(rnn_cell)
print(f"Input size: {rnn_cell.in_size}")
print(f"Hidden size: {rnn_cell.out_size}")

# Process a single timestep
x_t = brainstate.random.randn(10)
hidden_before = rnn_cell.h.value
hidden_after = rnn_cell(x_t)

print(f"Input shape: {x_t.shape}")
print(f"Previous hidden shape: {hidden_before.shape}")
print(f"New hidden shape: {hidden_after.shape}")
RNNCell:
ValinaRNNCell(
  in_size=(10,),
  out_size=(20,),
  num_out=20,
  num_in=10,
  state_initializer=ZeroInit(
    unit=Unit(10.0^0)
  ),
  activation=<function relu at 0x00000135A6BE8360>,
  W=Linear(
    in_size=(30,),
    out_size=(20,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[20]),
        'weight': ShapedArray(float32[30,20])
      }
    )
  ),
  h=HiddenState(
    value=ShapedArray(float32[20])
  )
)
Input size: (10,)
Hidden size: (20,)
Input shape: (10,)
Previous hidden shape: (20,)
New hidden shape: (20,)

Processing Sequences

Let’s process a complete sequence:

class SimpleRNN(brainstate.nn.Module):
    """Simple RNN for sequence processing."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn_cell = brainstate.nn.ValinaRNNCell(num_in=input_size, num_out=hidden_size)
        self.output_layer = brainstate.nn.Linear((hidden_size,), (output_size,))
        self.hidden_size = hidden_size

    def update(self, sequence):
        # Reset recurrent state for each new sequence
        brainstate.nn.init_all_states(self)
        rnn_out = brainstate.transform.for_loop(self.rnn_cell, sequence)
        return self.output_layer(rnn_out)

# Create RNN
brainstate.random.seed(0)
rnn = SimpleRNN(input_size=5, hidden_size=10, output_size=3)

# Test sequence
sequence = brainstate.random.randn(7, 5)  # 7 timesteps, 5 features
outputs = rnn(sequence)

print(f"Input sequence shape: {sequence.shape}")
print(f"Output sequence shape: {outputs.shape}")
print(f"Outputs: {outputs}")
Input sequence shape: (7, 5)
Output sequence shape: (7, 3)
Outputs: [[ 0.02491026 -1.2344698  -0.7660187 ]
 [ 1.020083   -0.94504666  0.5592617 ]
 [ 0.5109156  -0.7276021  -0.29766443]
 [ 0.8082595   0.6000599   0.46565723]
 [ 1.1170439   0.4770828  -0.37073386]
 [ 1.4667886  -1.4823256   0.30516532]
 [ 0.7935723   0.4356203   0.21530463]]

Visualizing RNN Processing

# Generate sine wave sequence
t = jnp.linspace(0, 4 * jnp.pi, 50)
sine_wave = jnp.sin(t)

# Prepare as sequence (add feature dimension)
sequence = sine_wave[:, None]

# Create RNN
brainstate.random.seed(42)
rnn = SimpleRNN(input_size=1, hidden_size=16, output_size=1)

# Process sequence
outputs = rnn(sequence).flatten()

# Plot
plt.figure(figsize=(12, 4))
plt.plot(t, sine_wave, linewidth=2, label='Input (Sine)', alpha=0.7)
plt.plot(t, outputs, linewidth=2, label='RNN Output', alpha=0.7)
plt.xlabel('Time')
plt.ylabel('Value')
plt.title('RNN Processing Temporal Sequence', fontweight='bold')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print("Untrained RNN output for a sine-wave input (weights are random)")
[Figure: "RNN Processing Temporal Sequence", input sine wave and RNN output plotted against time]
Untrained RNN output for a sine-wave input (weights are random)

2. LSTM - Long Short-Term Memory

LSTM mitigates the vanishing gradient problem by routing information through a cell state that is controlled by three gates:

  • Forget gate: What to forget from cell state

  • Input gate: What new information to store

  • Output gate: What to output from cell state
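The three gates listed above can be written out as a single LSTM step in plain JAX. This is a sketch under assumptions: the function name `lstm_step`, the gate ordering inside the fused weight matrix, and the weight scale are chosen for illustration and may differ from what brainstate does internally.

```python
import jax
import jax.numpy as jnp

def lstm_step(W, b, h_prev, c_prev, x_t):
    """One LSTM step. W has shape (num_in + num_hidden, 4 * num_hidden)."""
    z = jnp.concatenate([x_t, h_prev]) @ W + b
    i, g, f, o = jnp.split(z, 4)   # gate order is a convention assumed here
    i = jax.nn.sigmoid(i)          # input gate: what new information to store
    f = jax.nn.sigmoid(f)          # forget gate: what to drop from the cell state
    o = jax.nn.sigmoid(o)          # output gate: what to expose from the cell state
    g = jnp.tanh(g)                # candidate values for the cell state
    c = f * c_prev + i * g         # updated cell state (long-term memory)
    h = o * jnp.tanh(c)            # updated hidden state
    return h, c

# Illustrative sizes and random weights (hypothetical values)
k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
num_in, num_hidden = 10, 20
W = 0.1 * jax.random.normal(k_w, (num_in + num_hidden, 4 * num_hidden))
b = jnp.zeros(4 * num_hidden)
x_t = jax.random.normal(k_x, (num_in,))
h, c = lstm_step(W, b, jnp.zeros(num_hidden), jnp.zeros(num_hidden), x_t)
print(h.shape, c.shape)  # (20,) (20,)
```

Because the cell state is updated additively (`f * c_prev + i * g`), gradients can pass through many steps without being repeatedly squashed by an activation function.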

LSTM Cell

# LSTMCell: Advanced recurrent unit
brainstate.random.seed(42)
lstm_cell = brainstate.nn.LSTMCell(num_in=10, num_out=20)
brainstate.nn.init_all_states(lstm_cell)

print("LSTMCell:")
print(lstm_cell)

# LSTM has two states: hidden (h) and cell (c)
hidden_before = lstm_cell.h.value
cell_before = lstm_cell.c.value

# Process one timestep
x_t = brainstate.random.randn(10)
hidden_after = lstm_cell(x_t)
cell_after = lstm_cell.c.value

print(f"Input: {x_t.shape}")
print(f"Hidden state shape: {hidden_after.shape}")
print(f"Cell state shape: {cell_after.shape}")
print("✅ LSTM maintains both hidden and cell states")
print("✅ Cell state provides long-term memory")
LSTMCell:
LSTMCell(
  in_size=(10,),
  out_size=(20,),
  num_out=20,
  num_in=10,
  state_initializer=ZeroInit(
    unit=Unit(10.0^0)
  ),
  activation=<function tanh at 0x00000135A6BDFD80>,
  W=Linear(
    in_size=(30,),
    out_size=(80,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[80]),
        'weight': ShapedArray(float32[30,80])
      }
    )
  ),
  c=HiddenState(
    value=ShapedArray(float32[20])
  ),
  h=HiddenState(
    value=ShapedArray(float32[20])
  )
)
Input: (10,)
Hidden state shape: (20,)
Cell state shape: (20,)
✅ LSTM maintains both hidden and cell states
✅ Cell state provides long-term memory

LSTM Network

class LSTMNet(brainstate.nn.Module):
    """LSTM network for sequence processing."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm_cell = brainstate.nn.LSTMCell(num_in=input_size, num_out=hidden_size)
        self.output_layer = brainstate.nn.Linear((hidden_size,), (output_size,))
        self.hidden_size = hidden_size

    def update(self, sequence):
        brainstate.nn.init_all_states(self)
        rnn_out = brainstate.transform.for_loop(self.lstm_cell, sequence)
        return self.output_layer(rnn_out)

# Create LSTM
brainstate.random.seed(0)
lstm = LSTMNet(input_size=5, hidden_size=20, output_size=3)

# Test
sequence = brainstate.random.randn(10, 5)
outputs = lstm(sequence)

print(f"LSTM Network:")
print(lstm)
print(f"Sequence: {sequence.shape} → Output: {outputs.shape}")
LSTM Network:
LSTMNet(
  lstm_cell=LSTMCell(
    in_size=(5,),
    out_size=(20,),
    num_out=20,
    num_in=5,
    state_initializer=ZeroInit(
      unit=Unit(10.0^0)
    ),
    activation=<function tanh at 0x00000135A6BDFD80>,
    W=Linear(
      in_size=(25,),
      out_size=(80,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[80]),
          'weight': ShapedArray(float32[25,80])
        }
      )
    ),
    c=HiddenState(
      value=ShapedArray(float32[20])
    ),
    h=HiddenState(
      value=ShapedArray(float32[20])
    )
  ),
  output_layer=Linear(
    in_size=(20,),
    out_size=(3,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[3]),
        'weight': ShapedArray(float32[20,3])
      }
    )
  ),
  hidden_size=20
)
Sequence: (10, 5) → Output: (10, 3)

3. GRU - Gated Recurrent Unit

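GRU simplifies LSTM by merging the cell and hidden states and using two gates instead of three:

  • Reset gate: How much of the previous hidden state to use when computing the candidate state

  • Update gate: How much of the hidden state to replace with the candidate state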
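The two gates can be sketched as one GRU step in plain JAX. The name `gru_step`, the split into a fused reset/update matrix plus a candidate matrix, and the update convention `(1 - z) * h_prev + z * candidate` are assumptions for this example. Some formulations swap the roles of `z` and `1 - z`.

```python
import jax
import jax.numpy as jnp

def gru_step(params, h_prev, x_t):
    """One GRU step with a fused reset/update projection."""
    W_rz, b_rz, W_h, b_h = params
    xh = jnp.concatenate([x_t, h_prev])
    # Reset gate r and update gate z, both in (0, 1)
    r, z = jnp.split(jax.nn.sigmoid(xh @ W_rz + b_rz), 2)
    # Candidate state sees the previous hidden state scaled by the reset gate
    candidate = jnp.tanh(jnp.concatenate([x_t, r * h_prev]) @ W_h + b_h)
    # Update gate interpolates between the old state and the candidate
    return (1.0 - z) * h_prev + z * candidate

# Illustrative sizes and random weights (hypothetical values)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
num_in, num_hidden = 10, 20
params = (
    0.1 * jax.random.normal(k1, (num_in + num_hidden, 2 * num_hidden)),  # W_rz
    jnp.zeros(2 * num_hidden),                                           # b_rz
    0.1 * jax.random.normal(k2, (num_in + num_hidden, num_hidden)),      # W_h
    jnp.zeros(num_hidden),                                               # b_h
)
x_t = jax.random.normal(k3, (num_in,))
h_new = gru_step(params, jnp.zeros(num_hidden), x_t)
print(h_new.shape)  # (20,)
```

When the update gate is close to zero the hidden state is copied forward almost unchanged, which is how a GRU retains information without a separate cell state.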

GRU Cell

# GRUCell: Efficient alternative to LSTM
brainstate.random.seed(42)
gru_cell = brainstate.nn.GRUCell(num_in=10, num_out=20)
brainstate.nn.init_all_states(gru_cell)

print("GRUCell:")
print(gru_cell)

# GRU only has hidden state (no separate cell state)
hidden_before = gru_cell.h.value
x_t = brainstate.random.randn(10)
hidden_after = gru_cell(x_t)

print(f"Input: {x_t.shape}")
print(f"Hidden state shape: {hidden_after.shape}")
print("✅ Simpler than LSTM (no cell state)")
print("✅ Faster training, fewer parameters")
print("✅ Often performs similarly to LSTM")
GRUCell:
GRUCell(
  in_size=(10,),
  out_size=(20,),
  state_initializer=ZeroInit(
    unit=Unit(10.0^0)
  ),
  num_out=20,
  num_in=10,
  activation=<function tanh at 0x00000135A6BDFD80>,
  Wrz=Linear(
    in_size=(30,),
    out_size=(40,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[40]),
        'weight': ShapedArray(float32[30,40])
      }
    )
  ),
  Wh=Linear(
    in_size=(30,),
    out_size=(20,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[20]),
        'weight': ShapedArray(float32[30,20])
      }
    )
  ),
  h=HiddenState(
    value=ShapedArray(float32[20])
  )
)
Input: (10,)
Hidden state shape: (20,)
✅ Simpler than LSTM (no cell state)
✅ Faster training, fewer parameters
✅ Often performs similarly to LSTM
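The "fewer parameters" claim can be checked against the weight shapes in the three printouts above. The helper names below are hypothetical, and the formulas simply mirror the printed Linear layers (one for the RNN, one with four times the output width for the LSTM, and two for the GRU).

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a dense layer."""
    return n_in * n_out + n_out

def rnn_params(num_in, num_out):
    # W: (num_in + num_out) -> num_out
    return linear_params(num_in + num_out, num_out)

def lstm_params(num_in, num_out):
    # W: (num_in + num_out) -> 4 * num_out (three gates plus the candidate)
    return linear_params(num_in + num_out, 4 * num_out)

def gru_params(num_in, num_out):
    # Wrz: (num_in + num_out) -> 2 * num_out, Wh: (num_in + num_out) -> num_out
    return (linear_params(num_in + num_out, 2 * num_out)
            + linear_params(num_in + num_out, num_out))

for name, count in [("RNN", rnn_params), ("LSTM", lstm_params), ("GRU", gru_params)]:
    print(f"{name:4s}: {count(10, 20)} parameters")
# RNN : 620 parameters
# LSTM: 2480 parameters
# GRU : 1860 parameters
```

For the same input and hidden sizes, the GRU has three quarters of the LSTM's parameters and three times those of the simple RNN.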

Comparing RNN, LSTM, and GRU

# Create all three types
brainstate.random.seed(0)
input_size, hidden_size = 5, 10

rnn_cell = brainstate.nn.ValinaRNNCell(num_in=input_size, num_out=hidden_size)
lstm_cell = brainstate.nn.LSTMCell(num_in=input_size, num_out=hidden_size)
gru_cell = brainstate.nn.GRUCell(num_in=input_size, num_out=hidden_size)

# Test on same sequence
sequence = brainstate.random.randn(20, input_size)

# Process with RNN
brainstate.nn.init_all_states(rnn_cell)
rnn_states = []
for x_t in sequence:
    rnn_states.append(rnn_cell(x_t))
rnn_states = jnp.stack(rnn_states)

# Process with LSTM
brainstate.nn.init_all_states(lstm_cell)
lstm_states = []
for x_t in sequence:
    lstm_states.append(lstm_cell(x_t))
lstm_states = jnp.stack(lstm_states)

# Process with GRU
brainstate.nn.init_all_states(gru_cell)
gru_states = []
for x_t in sequence:
    gru_states.append(gru_cell(x_t))
gru_states = jnp.stack(gru_states)

# Visualize hidden states
plt.figure(figsize=(15, 4))

plt.subplot(1, 3, 1)
plt.imshow(np.array(rnn_states.T), aspect='auto', cmap='viridis')
plt.colorbar()
plt.title('RNN Hidden States', fontweight='bold')
plt.xlabel('Time')
plt.ylabel('Hidden Unit')

plt.subplot(1, 3, 2)
plt.imshow(np.array(lstm_states.T), aspect='auto', cmap='viridis')
plt.colorbar()
plt.title('LSTM Hidden States', fontweight='bold')
plt.xlabel('Time')
plt.ylabel('Hidden Unit')

plt.subplot(1, 3, 3)
plt.imshow(np.array(gru_states.T), aspect='auto', cmap='viridis')
plt.colorbar()
plt.title('GRU Hidden States', fontweight='bold')
plt.xlabel('Time')
plt.ylabel('Hidden Unit')

plt.tight_layout()
plt.show()

print("Different activation patterns across time")
[Figure: heatmaps of RNN, LSTM, and GRU hidden states, hidden unit vs. time]
Different activation patterns across time

4. Practical Example: Sequence Classification

Let’s build a complete sequence classifier:

class SequenceClassifier(brainstate.nn.Module):
    """Classify sequences using LSTM."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()

        # LSTM layer
        self.lstm = brainstate.nn.LSTMCell(num_in=input_size, num_out=hidden_size)

        # Classifier
        self.fc = brainstate.nn.Linear((hidden_size,), (num_classes,))
        self.hidden_size = hidden_size

    def update(self, sequence):
        """Classify a sequence and return logits."""
        brainstate.nn.init_all_states(self)
        rnn_out = brainstate.transform.for_loop(self.lstm, sequence)
        return self.fc(rnn_out)

# Create classifier
brainstate.random.seed(42)
classifier = SequenceClassifier(
    input_size=8,
    hidden_size=32,
    num_classes=3
)

print("Sequence Classifier:")
print(classifier)

# Test with batch of sequences
sequences = [
    brainstate.random.randn(10, 8),  # Short sequence
    brainstate.random.randn(15, 8),  # Medium sequence
    brainstate.random.randn(20, 8),  # Long sequence
]

print("Classifying sequences of different lengths:")
for i, seq in enumerate(sequences):
    # classifier(seq) returns per-timestep logits of shape (T, num_classes);
    # take the last timestep as the sequence-level prediction
    logits = classifier(seq)[-1]
    pred = jnp.argmax(logits)
    print(f"  Sequence {i + 1} (length={seq.shape[0]:2d}): logits={logits}, predicted class={pred}")
Sequence Classifier:
SequenceClassifier(
  lstm=LSTMCell(
    in_size=(8,),
    out_size=(32,),
    num_out=32,
    num_in=8,
    state_initializer=ZeroInit(
      unit=Unit(10.0^0)
    ),
    activation=<function tanh at 0x00000135A6BDFD80>,
    W=Linear(
      in_size=(40,),
      out_size=(128,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[128]),
          'weight': ShapedArray(float32[40,128])
        }
      )
    )
  ),
  fc=Linear(
    in_size=(32,),
    out_size=(3,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[3]),
        'weight': ShapedArray(float32[32,3])
      }
    )
  ),
  hidden_size=32
)
Classifying sequences of different lengths:
  Sequence 1 (length=10): logits=[-0.02841321  0.22150865  0.21882784], predicted class=1
  Sequence 2 (length=15): logits=[-0.07599682  0.10057165 -0.13780987], predicted class=1
  Sequence 3 (length=20): logits=[-0.00277933  0.20082155  0.1791797 ], predicted class=1
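The loop above handles sequences of different lengths one at a time. To process them as a single batch, a common approach is to pad every sequence to the same length and keep a mask or length vector so that padded steps are ignored. The sketch below shows that idea in plain JAX. The helper names `pad_sequences` and `last_valid` are hypothetical and not part of brainstate.

```python
import jax
import jax.numpy as jnp

def pad_sequences(seqs):
    """Zero-pad arrays of shape (T_i, F) to one (B, T_max, F) batch."""
    lengths = jnp.array([s.shape[0] for s in seqs])
    t_max = int(lengths.max())
    padded = jnp.stack(
        [jnp.pad(s, ((0, t_max - s.shape[0]), (0, 0))) for s in seqs]
    )
    return padded, lengths

def last_valid(outputs, lengths):
    """Select outputs[b, lengths[b] - 1] for each batch element b."""
    batch_idx = jnp.arange(outputs.shape[0])
    return outputs[batch_idx, lengths - 1]

# Three sequences with 8 features and lengths 10, 15, 20 (random stand-in data)
keys = jax.random.split(jax.random.PRNGKey(0), 3)
seqs = [jax.random.normal(k, (n, 8)) for k, n in zip(keys, (10, 15, 20))]

padded, lengths = pad_sequences(seqs)
# mask[b, t] is True where step t is real data rather than padding
mask = jnp.arange(padded.shape[1])[None, :] < lengths[:, None]
# Stand-in for per-step RNN outputs: here simply the padded inputs themselves
final = last_valid(padded, lengths)
print(padded.shape, mask.shape, final.shape)  # (3, 20, 8) (3, 20) (3, 8)
```

In a real model, `last_valid` would be applied to the per-step hidden states so that each sequence is classified from its own final step rather than from a padded one.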

5. Time Series Prediction

# Generate synthetic time series
def generate_time_series(n_steps=100):
    t = jnp.linspace(0, 10, n_steps)
    # Combination of sine waves
    series = jnp.sin(t) + 0.5 * jnp.sin(3 * t) + 0.1 * brainstate.random.randn(n_steps)
    return series

# Create sequences for prediction
def create_sequences(data, seq_length=10):
    X, y = [], []
    for i in range(len(data) - seq_length):
        X.append(data[i:i + seq_length])
        y.append(data[i + seq_length])
    return jnp.array(X), jnp.array(y)

# Generate data
brainstate.random.seed(0)
time_series = generate_time_series(200)
X, y = create_sequences(time_series, seq_length=15)

# Add feature dimension
X = X[:, :, None]

print(f"Time series data: {time_series.shape}")
print(f"Sequences (X): {X.shape}")
print(f"Targets (y): {y.shape}")

# Create predictor
class TimeSeriesPredictor(brainstate.nn.Module):
    def __init__(self, hidden_size=32):
        super().__init__()
        self.gru = brainstate.nn.GRUCell(num_in=1, num_out=hidden_size)
        self.fc = brainstate.nn.Linear((hidden_size,), (1,))
        self.hidden_size = hidden_size

    def update(self, sequence):
        brainstate.nn.init_all_states(self)
        brainstate.transform.for_loop(self.gru, sequence)
        prediction = self.fc(self.gru.h.value)
        return prediction.squeeze()

# Create and test predictor
brainstate.random.seed(42)
predictor = TimeSeriesPredictor(hidden_size=64)

# Make predictions on test data
predictions = []
for i in range(min(50, len(X))):
    pred = predictor(X[i])
    predictions.append(pred)

predictions = jnp.array(predictions)
targets = y[:len(predictions)]

# Plot
plt.figure(figsize=(12, 5))
plt.plot(targets, label='True Values', linewidth=2, alpha=0.7)
plt.plot(predictions, label='Predictions (Untrained)', linewidth=2, alpha=0.7)
plt.xlabel('Time Step')
plt.ylabel('Value')
plt.title('Time Series Prediction with GRU', fontweight='bold')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

mse = jnp.mean((predictions - targets) ** 2)
print(f"MSE (untrained): {mse:.4f}")
print("💡 With training, GRU can learn to predict future values")
Time series data: (200,)
Sequences (X): (185, 15, 1)
Targets (y): (185,)
[Figure: "Time Series Prediction with GRU", true values and untrained predictions plotted against time step]
MSE (untrained): 0.1451
💡 With training, GRU can learn to predict future values
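The predictor above is untrained. As a rough sketch of what training involves, the example below fits a small tanh RNN to a noise-free version of the same series using plain JAX and full-batch gradient descent. It deliberately does not use GRUCell or brainstate's training utilities, which this section does not show. The function names, learning rate, hidden size, and step count are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def init_params(key, num_hidden):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "W_xh": 0.5 * jax.random.normal(k1, (num_hidden, 1)),
        "W_hh": 0.1 * jax.random.normal(k2, (num_hidden, num_hidden)),
        "b": jnp.zeros(num_hidden),
        "w_out": 0.1 * jax.random.normal(k3, (num_hidden,)),
        "b_out": jnp.zeros(()),
    }

def predict(params, window):
    """Run a tanh RNN over one window of shape (T, 1); return a scalar forecast."""
    def body(h, x_t):
        h = jnp.tanh(params["W_xh"] @ x_t + params["W_hh"] @ h + params["b"])
        return h, None
    h_last, _ = jax.lax.scan(body, jnp.zeros(params["b"].shape), window)
    return params["w_out"] @ h_last + params["b_out"]

def mse_loss(params, X, y):
    preds = jax.vmap(predict, in_axes=(None, 0))(params, X)
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, X, y, lr):
    loss, grads = jax.value_and_grad(mse_loss)(params, X, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

# Noise-free version of the series, windowed the same way as above
t = jnp.linspace(0, 10, 200)
series = jnp.sin(t) + 0.5 * jnp.sin(3 * t)
seq_length = 15
X = jnp.stack(
    [series[i:i + seq_length] for i in range(len(series) - seq_length)]
)[:, :, None]
y = series[seq_length:]

params = init_params(jax.random.PRNGKey(0), num_hidden=16)
initial_loss = float(mse_loss(params, X, y))
for step in range(300):
    params, loss = train_step(params, X, y, 0.02)
final_loss = float(mse_loss(params, X, y))
print(f"MSE before: {initial_loss:.4f}, after: {final_loss:.4f}")
```

A practical setup would also hold out the most recent part of the series for evaluation, since this sketch measures error only on the data it was fitted to.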

Summary

In this tutorial, you learned:

RNN Basics

  • Recurrence and hidden states

  • Processing sequences step-by-step

  • Building simple RNN networks

LSTM

  • Gating mechanisms (forget, input, output)

  • Cell state for long-term memory

  • Handling long-term dependencies

GRU

  • Simplified gating (reset, update)

  • Fewer parameters than LSTM

  • Efficient alternative

Practical Applications

  • Sequence classification

  • Time series prediction

  • Variable-length sequences

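Quick Comparison

| Model | States   | Gates | Best For                              |
|-------|----------|-------|---------------------------------------|
| RNN   | 1 (h)    | 0     | Short sequences, simple patterns      |
| LSTM  | 2 (h, c) | 3     | Long sequences, complex dependencies  |
| GRU   | 1 (h)    | 2     | Balance of complexity and performance |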

When to Use Each

  • 🎯 Start with GRU - Good default choice

  • 📚 Use LSTM - When you need maximum capacity for long-term memory

  • Use RNN - For simple patterns or as a baseline

Best Practices

  1. 🔄 Initialize hidden states to zero

  2. 📊 Normalize input sequences

  3. 🎯 Use gradient clipping to prevent exploding gradients

  4. 💾 Carry hidden states over between chunks when running inference on long sequences

  5. 🔍 Try bidirectional RNNs for offline sequence processing
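Gradient clipping, mentioned in the list above, is often done by rescaling all gradients together when their combined norm exceeds a threshold. The sketch below implements that in plain JAX. The function names and the threshold of 1.0 are illustrative assumptions, and the gradient values are made up to show the effect.

```python
import jax
import jax.numpy as jnp

def global_norm(tree):
    """L2 norm over every array in a pytree of gradients."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one factor so their global norm is at most max_norm."""
    norm = global_norm(grads)
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

# Artificially large gradients standing in for an exploding-gradient step
grads = {"W": jnp.full((3, 3), 10.0), "b": jnp.full((3,), 10.0)}
clipped = clip_by_global_norm(grads, max_norm=1.0)
print(float(global_norm(grads)), float(global_norm(clipped)))
```

Using a single scale factor for every parameter keeps the direction of the update unchanged and only limits its size.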

Next Steps

Continue with:

  • Dynamics Systems - Brain-inspired temporal models

  • Attention Mechanisms - Beyond RNNs (Transformers)

  • Training - Optimize RNNs effectively