ctc_loss_with_forward_probs


class braintools.metric.ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)#

Compute CTC loss and forward probabilities for sequence alignment.

Connectionist Temporal Classification (CTC) loss enables training of sequence models without requiring a frame-level alignment between input and output sequences. It uses dynamic programming to sum the probabilities of all valid alignments without enumerating them explicitly.

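CTC uses a special blank symbol \(\phi\) so that an alignment with one symbol per input frame can represent a shorter output sequence. An alignment is mapped to a label sequence by merging repeated symbols and then removing blanks. The loss is the negative log-likelihood of the reference labels, i.e., the negative log of the total probability of all alignments that map to them.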
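The mapping from alignments to labels can be illustrated with a small pure-Python sketch. `collapse_alignment` is a hypothetical helper written for this page, not part of braintools; it applies the standard CTC collapsing rule (merge runs of repeated symbols, then drop blanks), assuming `blank_id=0` as in the default.

```python
from itertools import groupby


def collapse_alignment(path, blank_id=0):
    """Map a frame-level CTC alignment to its label sequence (hypothetical helper)."""
    merged = [symbol for symbol, _ in groupby(path)]  # merge runs of repeats
    return [symbol for symbol in merged if symbol != blank_id]  # drop blanks


# Two different 4-frame alignments that both collapse to the labels [1, 2].
print(collapse_alignment([1, 1, 0, 2]))  # [1, 2]
print(collapse_alignment([0, 1, 2, 2]))  # [1, 2]
# A blank between two equal symbols is what lets CTC emit a label twice.
print(collapse_alignment([1, 0, 1, 2]))  # [1, 1, 2]
```

Because many alignments collapse to the same labels, the loss has to sum over all of them, which is what the forward recursion does efficiently.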

The forward probabilities are defined as:

\[ \begin{aligned} \alpha_{\mathrm{BLANK}}(t, n) &= \sum_{\pi_{1:t-1}} p(\pi_t = \phi \mid \pi_{1:t-1}, y_{1:n-1}), \\ \alpha_{\mathrm{LABEL}}(t, n) &= \sum_{\pi_{1:t-1}} p(\pi_t = y_n \mid \pi_{1:t-1}, y_{1:n-1}), \end{aligned} \]

where \(\pi\) denotes an alignment (a frame-level sequence of labels and blanks), \(t\) indexes frames, and \(y_{1:n-1}\) denotes the first \(n-1\) reference labels.
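To make the dynamic program concrete, the following NumPy sketch implements the textbook CTC forward recursion for a single unpadded sequence and checks it against brute-force enumeration of every alignment. It is an illustration under stated assumptions, not braintools' implementation: it takes one sequence of shape (T, K) rather than a padded batch, tracks one state per position of the extended sequence \([\phi, y_1, \phi, \dots, y_U, \phi]\) instead of separate blank and label arrays, and the function names are hypothetical. The U + 1 blank positions and U label positions of that extended sequence line up with the last dimensions of `logalpha_blank` and `logalpha_nonblank`, although the library's exact normalization of its forward variables may differ.

```python
import itertools

import numpy as np


def log_softmax(logits):
    # Normalize scores over the class axis in log space.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def ctc_nll_forward(logits, labels, blank_id=0):
    """CTC negative log-likelihood of one unpadded sequence (hypothetical helper).

    logits: (T, K) unnormalized scores; labels: list of U ints without blanks.
    """
    log_probs = log_softmax(np.asarray(logits, dtype=np.float64))
    num_frames = log_probs.shape[0]
    # Extended sequence [blank, y1, blank, y2, ..., yU, blank].
    ext = [blank_id]
    for y in labels:
        ext += [y, blank_id]
    num_states = len(ext)

    log_alpha = np.full(num_states, -np.inf)
    log_alpha[0] = log_probs[0, ext[0]]  # start with a blank ...
    if num_states > 1:
        log_alpha[1] = log_probs[0, ext[1]]  # ... or with the first label
    for t in range(1, num_frames):
        prev = log_alpha
        log_alpha = np.full(num_states, -np.inf)
        for s in range(num_states):
            cands = [prev[s]]  # stay in the same state
            if s >= 1:
                cands.append(prev[s - 1])  # advance by one state
            # Skip the blank in between, unless it separates two equal labels.
            if s >= 2 and ext[s] != blank_id and ext[s] != ext[s - 2]:
                cands.append(prev[s - 2])
            log_alpha[s] = np.logaddexp.reduce(cands) + log_probs[t, ext[s]]
    # Valid alignments end on the last label or on the trailing blank.
    return -np.logaddexp.reduce(log_alpha[-2:] if num_states > 1 else log_alpha)


def ctc_nll_brute_force(logits, labels, blank_id=0):
    # Reference check: sum p(path) over every path that collapses to `labels`.
    log_probs = log_softmax(np.asarray(logits, dtype=np.float64))
    num_frames, num_classes = log_probs.shape
    total = -np.inf
    for path in itertools.product(range(num_classes), repeat=num_frames):
        merged = [k for k, _ in itertools.groupby(path)]
        if [k for k in merged if k != blank_id] == list(labels):
            path_logp = sum(log_probs[t, k] for t, k in enumerate(path))
            total = np.logaddexp(total, path_logp)
    return -total


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3))  # T=4 frames, K=3 classes, blank_id=0
print(ctc_nll_forward(logits, [1, 2]))
print(ctc_nll_brute_force(logits, [1, 2]))  # agrees up to rounding
```

The recursion costs O(T·U) per sequence, whereas the brute-force check grows as K^T and is only usable on toy inputs like this one.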

Parameters:
  • logits (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Logits with shape (batch_size, max_time, num_classes) containing unnormalized log probabilities for each class including blanks.

  • logit_paddings (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Padding indicators with shape (batch_size, max_time). A value of 1.0 marks a padded frame and 0.0 marks a valid frame.

  • labels (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Reference integer labels with shape (batch_size, max_label_length). Contains target sequences without blanks.

  • label_paddings (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Label padding indicators with shape (batch_size, max_label_length), using the same convention (1.0 for padding, 0.0 for valid). Must be right-padded: each row is a run of zeros followed by a run of ones.

  • blank_id (int) – Class index for the blank symbol in the logits.

  • log_epsilon (float) – Numerically stable approximation of log(0) for invalid transitions.
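In practice the padding arrays are usually derived from per-sequence lengths. The sketch below assumes NumPy inputs (the same expressions work with `jax.numpy`); `lengths_to_paddings` is a hypothetical helper, not part of braintools. It produces right-padded indicators with 1.0 for padding and 0.0 for valid positions, matching the convention described above.

```python
import numpy as np


def lengths_to_paddings(lengths, max_len):
    """Right-padded indicators: 0.0 for valid positions, 1.0 for padding."""
    positions = np.arange(max_len)[None, :]  # (1, max_len)
    lengths = np.asarray(lengths)[:, None]  # (batch, 1)
    return (positions >= lengths).astype(np.float32)  # (batch, max_len)


# Two sequences with 4 and 2 valid frames, padded to max_time=5.
logit_paddings = lengths_to_paddings([4, 2], max_len=5)
# [[0. 0. 0. 0. 1.]
#  [0. 0. 1. 1. 1.]]

# Matching right-padded labels; the trailing 0 is only a placeholder
# at a position that label_paddings marks as padding.
labels = np.array([[1, 2], [2, 0]])
label_paddings = lengths_to_paddings([2, 1], max_len=2)
# [[0. 0.]
#  [0. 1.]]
```

Comparing positions against lengths guarantees each row is zeros followed by ones, so the right-padding requirement on `label_paddings` holds by construction.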

Returns:

A tuple containing:

  • loss_values : (batch_size,) - CTC loss for each sequence, i.e., the negative log-likelihood of its reference labels

  • logalpha_blank : (max_time, batch_size, max_label_length+1) - Log forward probabilities for blank states

  • logalpha_nonblank : (max_time, batch_size, max_label_length) - Log forward probabilities for non-blank states

Return type:

tuple[Array | ndarray | bool | number | bool | int | float | complex | Quantity, Array | ndarray | bool | number | bool | int | float | complex | Quantity, Array | ndarray | bool | number | bool | int | float | complex | Quantity]

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import braintools
>>> # Example with batch_size=1, max_time=4, num_classes=3, label length 2
>>> key = jax.random.PRNGKey(0)
>>> logits = jax.random.normal(key, (1, 4, 3))
>>> logit_pad = jnp.zeros((1, 4))
>>> labels = jnp.array([[1, 2]])
>>> label_pad = jnp.zeros((1, 2))
>>> loss, alpha_blank, alpha_label = braintools.metric.ctc_loss_with_forward_probs(
...     logits, logit_pad, labels, label_pad, blank_id=0
... )
>>> print(f"CTC loss: {loss[0]:.4f}")

Notes

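This function requires right-padded labels. Each logit sequence must also have enough valid frames to emit its labels: a label sequence of length \(U\) containing \(R\) pairs of identical adjacent labels needs at least \(U + R\) valid frames, because a blank must separate each repeated pair. Otherwise no valid alignment exists. The forward probabilities can be used for further analysis of the alignment lattice or as a building block for more elaborate training procedures.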
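The length requirement can be checked before computing the loss. This NumPy sketch uses a hypothetical helper, `min_frames_required`, that counts valid labels plus adjacent repeats for each right-padded row and compares the result with the number of valid frames taken from `logit_paddings` (1.0 for padding, 0.0 for valid).

```python
import numpy as np


def min_frames_required(labels, label_paddings):
    """Minimum valid frames CTC needs to emit each right-padded label row.

    Every label takes one frame, and every pair of identical adjacent labels
    needs one extra blank frame between them.
    """
    labels = np.asarray(labels)
    valid = np.asarray(label_paddings) == 0
    num_labels = valid.sum(axis=1)
    # A repeat is an adjacent pair where both positions are valid and equal.
    same = labels[:, 1:] == labels[:, :-1]
    repeats = (same & valid[:, 1:] & valid[:, :-1]).sum(axis=1)
    return num_labels + repeats


labels = np.array([[1, 2, 2], [3, 3, 0]])
label_paddings = np.array([[0, 0, 0], [0, 0, 1]], dtype=np.float32)
logit_paddings = np.array([[0, 0, 0, 0, 0], [0, 0, 1, 1, 1]], dtype=np.float32)

needed = min_frames_required(labels, label_paddings)  # [4, 3]
available = (logit_paddings == 0).sum(axis=1)  # [5, 2]
print(needed <= available)  # [ True False]
```

Here the second sequence has only two valid frames but needs three (labels 3, blank, 3), so it has no valid alignment.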

References