ctc_loss

braintools.metric.ctc_loss(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)

Compute Connectionist Temporal Classification (CTC) loss.

A simplified interface to CTC loss computation that returns only the per-sequence loss values, without the forward probabilities. It is the function to reach for when training sequence models with CTC and the forward probabilities are not needed.

Parameters:
  • logits (Array | ndarray | bool | number | int | float | complex | Quantity) – Unnormalized log probabilities with shape (batch_size, max_time, num_classes), one score per class including the blank.

  • logit_paddings (Array | ndarray | bool | number | int | float | complex | Quantity) – Padding indicators with shape (batch_size, max_time). A value of 1.0 marks a padded position; 0.0 marks a valid position.

  • labels (Array | ndarray | bool | number | int | float | complex | Quantity) – Reference integer labels with shape (batch_size, max_label_length), holding the target sequences without blanks.

  • label_paddings (Array | ndarray | bool | number | int | float | complex | Quantity) – Label padding indicators with shape (batch_size, max_label_length). Must be right-padded (zeros followed by ones).

  • blank_id (int) – Class index for the blank symbol in the logits. Defaults to 0.

  • log_epsilon (float) – Numerically stable approximation of log(0) for invalid transitions. Defaults to -100000.0.

Returns:

CTC loss values with shape (batch_size,) containing the negative log-likelihood for each sequence in the batch.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
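To make "negative log-likelihood for each sequence" concrete, the sketch below evaluates the standard CTC objective by brute force for one tiny sequence: it enumerates every alignment, keeps those that collapse to the target label, and sums their probabilities. This is an illustration only and is not how braintools computes the loss; the helper names collapse and brute_force_ctc_nll are hypothetical, and the approach is exponential in the number of time steps.

```python
import itertools

import numpy as np


def collapse(path, blank_id=0):
    """Apply the CTC collapsing rule: merge repeats, then drop blanks."""
    merged = [s for i, s in enumerate(path) if i == 0 or s != path[i - 1]]
    return [s for s in merged if s != blank_id]


def brute_force_ctc_nll(logits, label, blank_id=0):
    """Hypothetical helper: -log of the summed probability of all valid alignments."""
    logits = np.asarray(logits, dtype=np.float64)
    # Log-softmax over the class axis turns logits into log probabilities.
    log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
    num_steps, num_classes = log_probs.shape
    total = 0.0
    for path in itertools.product(range(num_classes), repeat=num_steps):
        if collapse(path, blank_id) == list(label):
            total += np.exp(sum(log_probs[t, s] for t, s in enumerate(path)))
    return -np.log(total)


# Uniform logits over 3 time steps and 2 classes (0 = blank, 1 = the only label).
# Six of the eight alignments collapse to [1], so the result is -log(6/8).
nll = brute_force_ctc_nll(np.zeros((3, 2)), label=[1])
print(f"{nll:.4f}")  # 0.2877
```

For a single unpadded sequence, this is the quantity that one entry of the returned (batch_size,) array is meant to represent.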

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import braintools
>>> # Setup for speech recognition task
>>> batch_size, time_steps, vocab_size = 2, 10, 30
>>> key = jax.random.PRNGKey(0)  # jax.numpy has no random module; use jax.random
>>> logits = jax.random.normal(key, (batch_size, time_steps, vocab_size))
>>> logit_pad = jnp.zeros((batch_size, time_steps))  # No padded time steps
>>> labels = jnp.array([[1, 2, 3], [4, 5, 0]])  # Second sequence has only two labels
>>> label_pad = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # 1.0 marks the padded label slot
>>> loss = braintools.metric.ctc_loss(logits, logit_pad, labels, label_pad)
>>> print(f"Average CTC loss: {jnp.mean(loss):.4f}")
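Padding arrays such as logit_pad and label_pad are often derived from per-sequence lengths rather than written by hand. The following is a minimal sketch of that conversion using NumPy (ndarray is among the accepted input types listed above); the helper name lengths_to_paddings is hypothetical and not part of braintools. Because valid positions always come first, the result satisfies the right-padding requirement.

```python
import numpy as np


def lengths_to_paddings(lengths, max_len):
    """Hypothetical helper: (batch, max_len) mask, 0.0 = valid, 1.0 = padded."""
    lengths = np.asarray(lengths)
    positions = np.arange(max_len)[None, :]  # shape (1, max_len)
    # A position is padding once its index reaches the sequence length.
    return (positions >= lengths[:, None]).astype(np.float32)


# Two utterances with 10 and 7 valid frames, and 3 and 2 target labels.
logit_paddings = lengths_to_paddings([10, 7], max_len=10)
label_paddings = lengths_to_paddings([3, 2], max_len=3)
# label_paddings is [[0, 0, 0], [0, 0, 1]], matching label_pad in the example above.
```

The same expression should work with jax.numpy in place of numpy if the mask needs to be built inside jitted code.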

Notes

This function internally calls ctc_loss_with_forward_probs and discards the forward probability arrays. For applications that need the forward probabilities, use ctc_loss_with_forward_probs directly.

See also

ctc_loss_with_forward_probs

CTC loss computation with forward probabilities

References