HORN Models#

HORNs (Harmonic Oscillator Recurrent Networks) are recurrent neural networks built from coupled harmonic oscillators. They are designed for learning and generating temporal sequences with complex dynamics, particularly periodic and quasi-periodic ones.

Overview#

HORN models use harmonic oscillators as computational units, where the oscillatory dynamics naturally encode temporal patterns. Unlike traditional RNNs with sigmoid or tanh activations, HORN leverages the rich dynamics of coupled oscillators for sequence learning.

Key Features:

  • Oscillator-based recurrent dynamics

  • Natural handling of periodic and quasi-periodic sequences

  • Interpretable dynamical systems approach

  • Compatible with gradient-based optimization

Architecture#

HORN models consist of three main components:

  1. HORNStep: Single time step update of coupled oscillators

  2. HORNSeqLayer: Layer that processes sequential inputs

  3. HORNSeqNetwork: Full network with multiple HORN layers

API Reference#

HORNStep

Harmonic oscillator recurrent networks (HORNs) with one-step dynamics update.

HORNSeqLayer

Sequential layer wrapper for HORN dynamics with input and recurrent connections.

HORNSeqNetwork

Multi-layer HORN network for sequential processing tasks.

HORNStep#

HORNStep implements a single dynamics step for a population of coupled harmonic oscillators:

\[\ddot{x} + 2\zeta\omega\dot{x} + \omega^2 x = f(x, \text{input})\]

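where \(x\) is the oscillator state (position), \(\zeta\) is the damping ratio, \(\omega\) is the natural (angular) frequency, and \(f\) is a coupling function of the oscillator states and the external input.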
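
To make the equation concrete, it can be rewritten as two first-order equations, \(\dot{x} = y\) and \(\dot{y} = f - 2\zeta\omega y - \omega^2 x\), and stepped forward in time. The sketch below does this for a single uncoupled oscillator using a semi-implicit Euler update. The function name `oscillator_step`, the step size `dt`, and the choice of integrator are assumptions made for illustration; they are not necessarily what `HORNStep` does internally.

```python
def oscillator_step(x, y, drive, omega=1.0, zeta=0.1, dt=0.01):
    """Advance one damped, driven oscillator by a single time step.

    x is the position, y is the velocity, and drive is the right-hand side f.
    Semi-implicit Euler: update the velocity first, then move the position
    with the updated velocity.
    """
    accel = drive - 2.0 * zeta * omega * y - omega ** 2 * x
    y_new = y + dt * accel
    x_new = x + dt * y_new
    return x_new, y_new

# Free decay: start displaced at x = 1 with zero velocity and no drive
x, y = 1.0, 0.0
trajectory = []
for _ in range(5000):  # 5000 steps of 0.01 = 50 time units
    x, y = oscillator_step(x, y, drive=0.0)
    trajectory.append(x)
```

With \(\zeta = 0.1\) and \(\omega = 1\), the trajectory swings through zero repeatedly while its envelope shrinks roughly as \(e^{-\zeta\omega t}\), which is the underdamped behaviour the network relies on.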

Example:

import brainmass
import jax
import jax.numpy as jnp

# Create HORN step with 10 oscillators
horn_step = brainmass.HORNStep(
    in_size=5,    # input dimension
    num_osc=10,   # number of oscillators
    omega=1.0,    # natural frequency
    zeta=0.1,     # damping coefficient
)
horn_step.init_all_states()

# Single step update with a random input vector of length in_size
x_input = jax.random.normal(jax.random.PRNGKey(0), (5,))
x_out = horn_step.update(x_input)

HORNSeqLayer#

HORNSeqLayer wraps HORNStep to process sequences:

import brainstate
import jax

horn_layer = brainmass.HORNSeqLayer(
    in_size=5,
    num_osc=10,
    omega=1.0,
    zeta=0.1,
)
horn_layer.init_all_states()

# Process a random sequence of shape (time_steps, in_size)
sequence = jax.random.normal(jax.random.PRNGKey(0), (100, 5))

# Apply one update per time step and stack the outputs along the time axis
outputs = brainstate.transform.for_loop(
    lambda t: horn_layer.update(sequence[t]),
    jnp.arange(100)
)

HORNSeqNetwork#

HORNSeqNetwork stacks multiple HORN layers to create a deep recurrent network:

import jax

horn_net = brainmass.HORNSeqNetwork(
    in_size=5,
    hidden_sizes=[20, 20, 10],  # 3 HORN layers
    out_size=3,                  # output dimension
    omega=1.0,
    zeta=0.1,
)
horn_net.init_all_states()

# Forward pass through the network on a random (time_steps, in_size) sequence
sequence = jax.random.normal(jax.random.PRNGKey(0), (100, 5))

def forward_step(t):
    return horn_net.update(sequence[t])

predictions = brainstate.transform.for_loop(forward_step, jnp.arange(100))

Use Cases#

Sequence Generation#

HORN networks can learn to generate temporal sequences:

import brainmass
import jax
import jax.numpy as jnp
import brainstate

# Create generator network
generator = brainmass.HORNSeqNetwork(
    in_size=1,           # seed input
    hidden_sizes=[50, 50],
    out_size=10,         # sequence dimension
    omega=2.0,
    zeta=0.05,
)
generator.init_all_states()

# Generate sequence
seed = jnp.array([1.0])

generated_sequence = []
for t in range(500):
    output = generator.update(seed)
    generated_sequence.append(output)
    seed = output[:1]  # feedback

generated_sequence = jnp.stack(generated_sequence)

Time Series Prediction#

Predict future values of a time series:

# Training data
time_series = ...  # shape (T, D)

predictor = brainmass.HORNSeqNetwork(
    in_size=time_series.shape[1],
    hidden_sizes=[100, 50],
    out_size=time_series.shape[1],
    omega=1.5,
    zeta=0.1,
)
predictor.init_all_states()

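# Training loop (simplified): params, optimizer, dataloader, and num_epochs
# are placeholders that your own training setup must provide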
def loss_fn(params, inputs, targets):
    # Forward pass with params
    predictions = ...
    return jnp.mean((predictions - targets) ** 2)

# Optimize with JAX
optimizer = ...
for epoch in range(num_epochs):
    for batch_inputs, batch_targets in dataloader:
        grads = jax.grad(loss_fn)(params, batch_inputs, batch_targets)
        params = optimizer.update(grads, params)
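
The loop above leaves the forward pass and the optimizer open. As one way to see the pieces fit together, here is a self-contained sketch in plain JAX that trains a toy oscillator network for one-step-ahead prediction of a sine wave. It does not use brainmass: `init_params`, `forward`, `train_step`, the weight layout, the `tanh` drive, the step size `DT`, and plain gradient descent are all assumptions chosen for brevity.

```python
import jax
import jax.numpy as jnp

OMEGA, ZETA, DT = 3.0, 0.1, 0.1   # oscillator settings (assumed for this toy)
LEARNING_RATE = 0.05

def init_params(key, in_size, num_osc, out_size, scale=0.1):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w_in": scale * jax.random.normal(k1, (in_size, num_osc)),
        "w_rec": scale * jax.random.normal(k2, (num_osc, num_osc)),
        "w_out": scale * jax.random.normal(k3, (num_osc, out_size)),
    }

def forward(params, inputs):
    """Run the toy network over inputs of shape (T, in_size)."""
    num_osc = params["w_rec"].shape[0]

    def step(carry, u):
        x, y = carry  # positions and velocities of all oscillators
        drive = jnp.tanh(u @ params["w_in"] + y @ params["w_rec"])
        y = y + DT * (drive - 2.0 * ZETA * OMEGA * y - OMEGA ** 2 * x)
        x = x + DT * y
        return (x, y), x @ params["w_out"]  # linear readout of positions

    init = (jnp.zeros(num_osc), jnp.zeros(num_osc))
    _, outputs = jax.lax.scan(step, init, inputs)
    return outputs  # (T, out_size)

def loss_fn(params, inputs, targets):
    return jnp.mean((forward(params, inputs) - targets) ** 2)

@jax.jit
def train_step(params, inputs, targets):
    loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    # Plain gradient descent applied to every leaf of the parameter pytree
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return params, loss

# Sine wave advancing 0.3 rad per step; with DT = 0.1 that is 3.0 rad per
# unit of model time, which is why OMEGA is set to 3.0
series = jnp.sin(0.3 * jnp.arange(201))[:, None]  # (T, D) with D = 1
inputs, targets = series[:-1], series[1:]

params = init_params(jax.random.PRNGKey(0), in_size=1, num_osc=16, out_size=1)
losses = []
for epoch in range(200):
    params, loss = train_step(params, inputs, targets)
    losses.append(float(loss))
```

The recurrence is expressed with `jax.lax.scan` so that the whole unrolled sequence is differentiated in one call to `jax.value_and_grad`. In practice an optimizer library and mini-batching would replace the hand-written update.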

Oscillatory Pattern Recognition#

Classify temporal patterns with oscillatory structure:

# HORN classifier
classifier = brainmass.HORNSeqNetwork(
    in_size=64,          # e.g., sensor channels
    hidden_sizes=[100],
    out_size=5,          # number of classes
    omega=3.0,           # match expected oscillation frequency
    zeta=0.2,
)
classifier.init_all_states()

# Classification
input_signal = ...  # shape (time_steps, 64)

logits = []
for t in range(input_signal.shape[0]):
    output = classifier.update(input_signal[t])
    logits.append(output)

# Final classification from last output or pooling
final_logits = logits[-1]
predicted_class = jnp.argmax(final_logits)
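
The two readout options mentioned in the final comment, last output versus pooling, can give different answers when the network's output changes over the sequence. The NumPy sketch below contrasts them; the helper name `classify` and the `mode` argument are hypothetical and only serve the illustration.

```python
import numpy as np

def classify(logits_over_time, mode="last"):
    """Pick a class from per-step logits of shape (time_steps, num_classes)."""
    logits_over_time = np.asarray(logits_over_time, dtype=float)
    if mode == "last":
        final = logits_over_time[-1]           # use only the final time step
    elif mode == "mean":
        final = logits_over_time.mean(axis=0)  # average logits over time
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return int(np.argmax(final))

# Class 0 dominates the first two steps, class 1 wins only at the last step
toy_logits = [[2.0, 0.0, 0.0],
              [2.0, 0.0, 0.0],
              [0.0, 1.0, 0.0]]
last_choice = classify(toy_logits, mode="last")
mean_choice = classify(toy_logits, mode="mean")
```

Here the last-step readout selects class 1, while mean pooling selects class 0 because it weighs the whole sequence. Pooling tends to be less sensitive to the oscillation phase at the final step.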

Parameter Selection#

Natural Frequency (ω):

  • Should match the characteristic frequency of the data

  • For data with dominant frequency \(f\), set \(\omega \approx 2\pi f\)

  • Multiple oscillators can have different frequencies to capture multi-scale dynamics
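
One way to apply the \(\omega \approx 2\pi f\) rule is to read the dominant frequency off a Fourier spectrum of the training signal. The sketch below does this with NumPy; the helper name `dominant_omega` is hypothetical. It assumes a uniformly sampled 1-D signal and assumes that the sample rate is expressed in the same time units the model uses.

```python
import numpy as np

def dominant_omega(signal, sample_rate):
    """Return 2*pi*f for the strongest non-DC frequency of a 1-D signal."""
    signal = np.asarray(signal, dtype=float)
    spectrum = np.abs(np.fft.rfft(signal - signal.mean()))
    freqs = np.fft.rfftfreq(signal.size, d=1.0 / sample_rate)
    peak = np.argmax(spectrum[1:]) + 1  # skip the zero-frequency bin
    return 2.0 * np.pi * freqs[peak]

# Synthetic check: a 3 Hz sine sampled at 100 Hz for 10 seconds
t = np.arange(1000) / 100.0
omega_est = dominant_omega(np.sin(2.0 * np.pi * 3.0 * t), sample_rate=100.0)
```

For this signal the estimate equals \(2\pi \cdot 3 \approx 18.85\) rad/s. For multi-scale data, the same spectrum can suggest several peaks to spread across oscillators.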

Damping (ζ):

  • Controls oscillation decay

  • \(\zeta < 1\): Underdamped (oscillatory)

  • \(\zeta = 1\): Critically damped

  • \(\zeta > 1\): Overdamped (no oscillations)

  • Typical values: 0.05 - 0.5 for learning temporal patterns
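
The three regimes follow from the roots of the characteristic equation \(s^2 + 2\zeta\omega s + \omega^2 = 0\), namely \(s = -\zeta\omega \pm \omega\sqrt{\zeta^2 - 1}\). The sketch below turns that into numbers for a free (undriven) oscillator: the oscillation frequency \(\omega\sqrt{1 - \zeta^2}\) and the time constant of the slowest-decaying mode. The helper name `damping_summary` is hypothetical.

```python
import math

def damping_summary(omega, zeta):
    """Return (regime, damped angular frequency, slowest decay time constant)."""
    if zeta < 1.0:
        omega_d = omega * math.sqrt(1.0 - zeta ** 2)
        # Envelope decays as exp(-zeta * omega * t)
        tau = 1.0 / (zeta * omega) if zeta > 0 else math.inf
        return "underdamped", omega_d, tau
    if zeta == 1.0:
        # Repeated root at s = -omega, no oscillation
        return "critically damped", 0.0, 1.0 / omega
    # Two real roots; the one closer to zero decays most slowly
    slow_rate = omega * (zeta - math.sqrt(zeta ** 2 - 1.0))
    return "overdamped", 0.0, 1.0 / slow_rate

regime, omega_d, tau = damping_summary(omega=2.0, zeta=0.05)
```

With \(\omega = 2.0\) and \(\zeta = 0.05\), the values used in the sequence-generation example, the envelope time constant is \(1 / (0.05 \times 2.0) = 10\) time units, so an oscillator retains a trace of its input for several periods.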

Number of Oscillators:

  • More oscillators increase capacity but also parameters

  • Start with 50-100 oscillators per layer

  • Scale based on sequence complexity
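
To get a feel for how the parameter count grows with the number of oscillators, the sketch below counts weights for one layer under a stated assumption: a dense input matrix, a dense recurrent matrix, and one bias per oscillator. This layout and the helper name `layer_param_count` are hypothetical and may not match the actual brainmass layers.

```python
def layer_param_count(in_size, num_osc):
    """Count weights for an assumed dense input + dense recurrent + bias layout."""
    input_weights = in_size * num_osc
    recurrent_weights = num_osc * num_osc  # grows quadratically
    biases = num_osc
    return input_weights + recurrent_weights + biases

count_50 = layer_param_count(in_size=5, num_osc=50)
count_100 = layer_param_count(in_size=5, num_osc=100)
```

Under this assumption, going from 50 to 100 oscillators raises the count from 2,800 to 10,600, because the recurrent matrix dominates.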

Training Considerations#

Initialization:

Proper initialization is important for oscillator-based networks:

import braintools.init

# Custom initialization of position and velocity states
horn_net = brainmass.HORNSeqNetwork(
    in_size=5,
    hidden_sizes=[20, 20, 10],
    out_size=3,
    omega=1.0,
    zeta=0.1,
    x_init=braintools.init.Normal(scale=0.1),  # position state initializer
    y_init=braintools.init.Normal(scale=0.01),  # velocity state initializer
)
horn_net.init_all_states(batch_size=32)  # for batched training

Gradient Clipping:

Gradients propagated through oscillator dynamics can grow large over long sequences, so clip them before the optimizer step:

grads = jax.grad(loss_fn)(params, inputs, targets)

# Clip every gradient element to the range [-1, 1]
clipped_grads = jax.tree_util.tree_map(
    lambda g: jnp.clip(g, -1.0, 1.0),
    grads
)
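
Element-wise clipping changes the direction of the gradient whenever only some entries are clipped. A common alternative is to rescale the whole gradient pytree by its global L2 norm, which keeps the direction intact. The sketch below implements that in plain JAX. The function `clip_by_global_norm` here is written out for illustration, and the parameter names in the toy pytree are hypothetical; Optax offers a transformation of the same name for use inside an optimizer chain.

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a gradient pytree so its overall L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    # Scale is 1 when the norm is already small enough
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

# Toy gradient pytree with hypothetical parameter names
grads = {"w_in": jnp.full((2, 2), 3.0), "w_rec": jnp.full((2,), 4.0)}
clipped = clip_by_global_norm(grads, max_norm=1.0)
```

After clipping, the pytree has an overall norm of about 1 and the ratio between any two entries is unchanged.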

Learning Rate:

Start with a small learning rate (1e-4 to 1e-3). Oscillatory dynamics can make the loss sensitive to large parameter updates, so increase the rate only once training is stable.

Advantages and Limitations#

Advantages:

  • Natural temporal dynamics without gating mechanisms

  • Interpretable oscillator-based representation

  • Effective for periodic and quasi-periodic patterns

  • Differentiable and trainable with standard optimizers

Limitations:

  • More parameters than vanilla RNNs for the same hidden size

  • Requires tuning of oscillator parameters (ω, ζ)

  • May not outperform LSTMs/GRUs on all sequence tasks

  • Less established than traditional RNN architectures

Comparison with Traditional RNNs#

| Aspect | HORN | LSTM/GRU |
|---|---|---|
| Dynamics | Oscillatory (physics-based) | Gated activations |
| Interpretability | High (oscillator states) | Low (hidden states) |
| Best for | Periodic/oscillatory patterns | General sequences |
| Parameters | More (oscillator equations) | Fewer (compact gates) |
| Training | Gradient-based (may need clipping) | Gradient-based (stable) |

See Also#