ctc_loss_with_forward_probs
- class braintools.metric.ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)
Compute CTC loss and forward probabilities for sequence alignment.
Connectionist Temporal Classification (CTC) loss enables training of sequence models without requiring frame-level alignment between input and output sequences. It uses dynamic programming to sum the probabilities of all valid alignments between the two.
The CTC loss uses a special blank symbol \(\phi\) to represent variable-length output sequences and computes log-likelihoods over all possible alignments.
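To make the role of the blank symbol concrete, the following is a minimal brute-force sketch, not the library's implementation. It assumes the usual CTC collapsing rule (merge repeated symbols, then drop blanks) and per-frame class probabilities that are independent across frames. The helper names `collapse` and `brute_force_ctc_prob` are hypothetical and exist only for this illustration.

```python
import itertools
import math


def collapse(alignment, blank_id=0):
    """Map a frame-level alignment to labels: merge repeats, then drop blanks."""
    out = []
    prev = None
    for symbol in alignment:
        if symbol != prev and symbol != blank_id:
            out.append(symbol)
        prev = symbol
    return tuple(out)


def brute_force_ctc_prob(probs, labels, blank_id=0):
    """Sum the probability of every alignment that collapses to `labels`.

    probs[t][k] is the probability of class k at frame t.
    Exponential in the number of frames, so only usable for tiny inputs.
    """
    num_frames = len(probs)
    num_classes = len(probs[0])
    total = 0.0
    for alignment in itertools.product(range(num_classes), repeat=num_frames):
        if collapse(alignment, blank_id) == tuple(labels):
            total += math.prod(probs[t][k] for t, k in enumerate(alignment))
    return total


# Three frames, three classes (class 0 is the blank), uniform probabilities.
probs = [[1 / 3, 1 / 3, 1 / 3] for _ in range(3)]
p = brute_force_ctc_prob(probs, labels=[1, 2])
loss = -math.log(p)
```

Five of the 27 possible alignments collapse to `[1, 2]` here, so `p` equals 5/27. The dynamic-programming formulation exists precisely to avoid this exponential enumeration.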
Forward probabilities are computed for:
\[
\begin{aligned}
\alpha_{\mathrm{BLANK}}(t, n) &= \sum_{\pi_{1:t-1}} p(\pi_t = \phi \mid \pi_{1:t-1}, y_{1:n-1}) \\
\alpha_{\mathrm{LABEL}}(t, n) &= \sum_{\pi_{1:t-1}} p(\pi_t = y_n \mid \pi_{1:t-1}, y_{1:n-1})
\end{aligned}
\]
where \(\pi\) denotes alignment sequences with blank insertions.
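As an illustration of how forward probabilities are accumulated, here is a sketch of the textbook CTC forward recursion for a single sequence, written in NumPy. It uses the common "extended label sequence" formulation (blank, y1, blank, y2, ..., blank) with one state array, which may differ from the separate blank and label arrays this function returns. The name `ctc_forward_nll` is hypothetical, and the sketch assumes already normalized log-probabilities and a non-empty label sequence.

```python
import numpy as np


def ctc_forward_nll(log_probs, labels, blank_id=0):
    """Negative log-likelihood of `labels` under the CTC forward recursion.

    log_probs: array of shape (num_frames, num_classes), normalized log-probs.
    labels: non-empty sequence of label ids, without blanks.
    """
    log_probs = np.asarray(log_probs, dtype=np.float64)
    num_frames = log_probs.shape[0]

    # Extended sequence: blank, y1, blank, y2, ..., yN, blank.
    ext = [blank_id]
    for y in labels:
        ext += [y, blank_id]
    num_states = len(ext)

    # At the first frame only the leading blank or the first label is reachable.
    log_alpha = np.full(num_states, -np.inf)
    log_alpha[0] = log_probs[0, ext[0]]
    log_alpha[1] = log_probs[0, ext[1]]

    for t in range(1, num_frames):
        prev = log_alpha
        log_alpha = np.full(num_states, -np.inf)
        for s in range(num_states):
            candidates = [prev[s]]  # stay in the same state
            if s >= 1:
                candidates.append(prev[s - 1])  # advance by one state
            # Skipping the blank is allowed only between two different labels.
            if s >= 2 and ext[s] != blank_id and ext[s] != ext[s - 2]:
                candidates.append(prev[s - 2])
            log_alpha[s] = np.logaddexp.reduce(candidates) + log_probs[t, ext[s]]

    # A valid alignment ends in the last label or in the trailing blank.
    return -np.logaddexp(log_alpha[-1], log_alpha[-2])


# Uniform distribution over 3 classes (class 0 is the blank) for 3 frames.
uniform = np.log(np.full((3, 3), 1.0 / 3.0))
nll = ctc_forward_nll(uniform, labels=[1, 2])
```

Working in log space with `np.logaddexp` keeps the recursion numerically stable for long sequences; unreachable states are represented by `-inf` here, whereas the documented function uses the finite `log_epsilon` value for the same purpose.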
- Parameters:
logits (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Logits with shape (batch_size, max_time, num_classes) containing unnormalized log probabilities for each class, including the blank.
logit_paddings (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Padding indicators with shape (batch_size, max_time). Values of 1.0 indicate padded positions, 0.0 indicate valid positions.
labels (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Reference integer labels with shape (batch_size, max_label_length). Contains target sequences without blanks.
label_paddings (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Label padding indicators with shape (batch_size, max_label_length). Must be right-padded (zeros followed by ones).
blank_id (int) – Class index for the blank symbol in the logits.
log_epsilon (float) – Numerically stable approximation of log(0) for invalid transitions.
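The padding arrays described above can be derived from per-sequence lengths. The following is a small sketch using NumPy; the helper name `lengths_to_paddings` is hypothetical and not part of braintools. It follows the documented convention that 0.0 marks valid positions and 1.0 marks padding, which also yields the required right-padded layout.

```python
import numpy as np


def lengths_to_paddings(lengths, max_len):
    """Return a float array of shape (batch, max_len): 0.0 = valid, 1.0 = padded."""
    lengths = np.asarray(lengths)
    positions = np.arange(max_len)[None, :]
    # Positions at or beyond a sequence's length are padding.
    return (positions >= lengths[:, None]).astype(np.float32)


# Two sequences: 4 and 2 valid frames, 2 and 1 valid labels.
logit_paddings = lengths_to_paddings([4, 2], max_len=4)
label_paddings = lengths_to_paddings([2, 1], max_len=3)
```

The resulting arrays could then be passed as `logit_paddings` and `label_paddings`, assuming the function accepts NumPy arrays as its type annotations suggest.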
- Returns:
A tuple containing:
loss_values: (batch_size,) - CTC loss for each sequence
logalpha_blank: (max_time, batch_size, max_label_length+1) - Log forward probabilities for blank states
logalpha_nonblank: (max_time, batch_size, max_label_length) - Log forward probabilities for non-blank states
- Return type:
tuple[Array|ndarray|bool|number|bool|int|float|complex|Quantity,Array|ndarray|bool|number|bool|int|float|complex|Quantity,Array|ndarray|bool|number|bool|int|float|complex|Quantity]
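Note that the forward-probability arrays are time-major, while the inputs are batch-major. The sketch below shows one way to index them per sequence. It uses a zero-filled placeholder with the documented shape instead of a real function call, and `logit_lengths` is a hypothetical array of valid frame counts.

```python
import numpy as np

max_time, batch_size, max_label_length = 4, 2, 3

# Placeholder standing in for the returned logalpha_blank (time-major layout).
logalpha_blank = np.zeros((max_time, batch_size, max_label_length + 1))
logit_lengths = np.array([4, 2])  # hypothetical number of valid frames per sequence

# Move the batch axis first so it lines up with logits and labels.
batch_major = np.transpose(logalpha_blank, (1, 0, 2))

# Select, for each sequence, the forward values at its last valid frame.
last_valid = batch_major[np.arange(batch_size), logit_lengths - 1]
```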
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import braintools
>>> # Example with batch_size=1, time=4, classes=3, labels=2
>>> key = jax.random.PRNGKey(0)  # jax.numpy has no `random` module; use jax.random
>>> logits = jax.random.normal(key, (1, 4, 3))
>>> logit_pad = jnp.zeros((1, 4))  # all frames valid
>>> labels = jnp.array([[1, 2]])
>>> label_pad = jnp.zeros((1, 2))  # all labels valid
>>> loss, alpha_blank, alpha_label = braintools.metric.ctc_loss_with_forward_probs(
...     logits, logit_pad, labels, label_pad, blank_id=0
... )
>>> print(f"CTC loss: {loss[0]:.4f}")
Notes
This function requires that labels are right-padded and logit sequences are properly aligned. The forward probabilities can be used for additional analysis or for implementing more sophisticated training procedures.