ctc_loss
- class braintools.metric.ctc_loss(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)
Compute Connectionist Temporal Classification (CTC) loss.
A simplified interface to CTC loss computation that returns only the per-sequence loss values, without the forward probabilities. Use it to train sequence models with CTC when only the loss is needed.
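In the standard CTC formulation, the loss for one sequence is the negative log of the total probability of every frame-level path that collapses to the target labels. Collapsing merges consecutive repeats and then removes the blank (`blank_id`). The brute-force sketch below spells this definition out for a single unpadded sequence. It assumes the logits are normalized with a softmax over classes at each frame. It is plain Python and independent of braintools. The names `ctc_collapse`, `log_softmax` and `ctc_nll_brute_force` are hypothetical. Its cost grows as `num_classes ** max_time`, so it is only useful for checking tiny cases.

```python
import itertools
import math
from typing import List, Optional, Sequence


def ctc_collapse(path: Sequence[int], blank_id: int = 0) -> List[int]:
    """CTC collapse map: merge consecutive repeats, then drop blanks."""
    out: List[int] = []
    prev: Optional[int] = None
    for s in path:
        # Keep a symbol only if it differs from the previous frame and is not blank.
        if s != prev and s != blank_id:
            out.append(s)
        prev = s
    return out


def log_softmax(row: Sequence[float]) -> List[float]:
    """Normalize one frame of unnormalized logits into log probabilities."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def ctc_nll_brute_force(
    logits: Sequence[Sequence[float]], labels: Sequence[int], blank_id: int = 0
) -> float:
    """Hypothetical reference: -log P(labels | logits) by enumerating all paths."""
    num_frames = len(logits)
    num_classes = len(logits[0])
    logp = [log_softmax(row) for row in logits]
    total = 0.0
    for path in itertools.product(range(num_classes), repeat=num_frames):
        if ctc_collapse(path, blank_id) == list(labels):
            # Probability of one path = product of per-frame class probabilities.
            total += math.exp(sum(logp[t][k] for t, k in enumerate(path)))
    # No path collapses to the labels: the target is infeasible.
    if total == 0.0:
        return math.inf
    return -math.log(total)
```

For two frames, two classes (blank 0 and label 1) and all-zero logits, three of the four paths collapse to `[1]`, so the loss is `-log(0.75)`.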
- Parameters:
  - logits (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Logits with shape (batch_size, max_time, num_classes) containing unnormalized log probabilities for each class, including the blank class.
  - logit_paddings (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Padding indicators with shape (batch_size, max_time). A value of 1.0 marks a padded frame and 0.0 marks a valid frame.
  - labels (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Reference integer labels with shape (batch_size, max_label_length). Contains the target sequences without blanks.
  - label_paddings (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Label padding indicators with shape (batch_size, max_label_length). Must be right-padded (zeros followed by ones).
  - blank_id (int) – Class index of the blank symbol in the logits. Defaults to 0.
  - log_epsilon (float) – Numerically stable approximation of log(0), used for invalid transitions. Defaults to -100000.0.
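Both padding arguments use the same convention: 0.0 for a valid position and 1.0 for a padded one, with all padding at the end. When only per-sequence lengths are available, a helper along the lines of `lengths_to_paddings` below can build these indicators. This helper is a hypothetical sketch and not part of braintools. It returns nested Python lists, which can be converted with `jnp.asarray` before calling `ctc_loss`.

```python
from typing import List, Sequence


def lengths_to_paddings(lengths: Sequence[int], max_len: int) -> List[List[float]]:
    """Hypothetical helper: right-padded indicators, 0.0 = valid, 1.0 = padded."""
    paddings: List[List[float]] = []
    for n in lengths:
        if not 0 <= n <= max_len:
            raise ValueError(f"length {n} is outside [0, {max_len}]")
        # The first n positions are valid, the remaining ones are padding.
        paddings.append([0.0 if t < n else 1.0 for t in range(max_len)])
    return paddings


# Label lengths 3 and 2, padded to max_label_length = 3.
label_paddings = lengths_to_paddings([3, 2], max_len=3)
# Frame lengths 10 and 7, padded to max_time = 10.
logit_paddings = lengths_to_paddings([10, 7], max_len=10)
```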
- Returns:
  CTC loss values with shape (batch_size,) containing the negative log-likelihood for each sequence in the batch.
- Return type:
  Array|ndarray|bool|number|bool|int|float|complex|Quantity
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import braintools
>>> # Setup for a speech recognition task
>>> batch_size, time_steps, vocab_size = 2, 10, 30
>>> key = jax.random.PRNGKey(0)
>>> logits = jax.random.normal(key, (batch_size, time_steps, vocab_size))
>>> logit_pad = jnp.zeros((batch_size, time_steps))  # No padded frames
>>> labels = jnp.array([[1, 2, 3], [4, 5, 0]])  # Target lengths 3 and 2
>>> label_pad = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # Last label of the second sequence is padding
>>> loss = braintools.metric.ctc_loss(logits, logit_pad, labels, label_pad)
>>> print(f"Average CTC loss: {jnp.mean(loss):.4f}")
Notes
This function internally calls ctc_loss_with_forward_probs and discards the forward probability arrays. For applications that need the forward probabilities, use ctc_loss_with_forward_probs directly.
See also
ctc_loss_with_forward_probs
CTC loss computation with forward probabilities.
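In the usual CTC formulation, the forward probabilities are the log-domain "alpha" variables of a dynamic program. This program runs over the label sequence with a blank inserted before, between and after the labels. The loss is the negative log-sum of the last-frame alphas for the final label and the trailing blank. The plain-Python sketch below illustrates that recursion for one unpadded sequence. It is not braintools' implementation and makes no claim about the exact arrays that ctc_loss_with_forward_probs returns. The name `ctc_nll_forward` is hypothetical. The sketch uses a true `-inf` for log(0), whereas the library substitutes `log_epsilon`.

```python
import math
from typing import List, Sequence

NEG_INF = float("-inf")


def logsumexp(values: Sequence[float]) -> float:
    """Stable log(sum(exp(v))) that tolerates all -inf inputs."""
    m = max(values)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(v - m) for v in values))


def ctc_nll_forward(
    logits: Sequence[Sequence[float]], labels: Sequence[int], blank_id: int = 0
) -> float:
    """Hypothetical sketch: CTC negative log-likelihood via the forward recursion."""
    # Per-frame log-softmax over classes.
    logp: List[List[float]] = []
    for row in logits:
        lse = logsumexp(row)
        logp.append([x - lse for x in row])
    # Blank-augmented labels: [blank, l1, blank, l2, ..., blank].
    ext = [blank_id]
    for label in labels:
        ext += [label, blank_id]
    num_states = len(ext)
    # alpha[s]: log-probability of all partial paths ending in state s at frame t.
    alpha = [NEG_INF] * num_states
    alpha[0] = logp[0][ext[0]]
    if num_states > 1:
        alpha[1] = logp[0][ext[1]]
    for t in range(1, len(logp)):
        new_alpha = [NEG_INF] * num_states
        for s in range(num_states):
            terms = [alpha[s]]  # stay in the same state
            if s >= 1:
                terms.append(alpha[s - 1])  # advance by one state
            # Skipping a blank is allowed only between two different labels.
            if s >= 2 and ext[s] != blank_id and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])
            new_alpha[s] = logsumexp(terms) + logp[t][ext[s]]
        alpha = new_alpha
    # Valid paths end on the last label or on the trailing blank.
    final = [alpha[-1]] + ([alpha[-2]] if num_states > 1 else [])
    return -logsumexp(final)
```

The recursion runs in O(max_time * max_label_length) time instead of enumerating paths. Infeasible targets, such as a repeated label with too few frames for the separating blank, return `inf` here. With a finite `log_epsilon`, the library would presumably return a large finite value instead.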