Source code for braintools.cogtask.tasks.reasoning

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Reasoning cognitive task classes."""

from typing import Sequence, Tuple

import brainunit as u
import jax.numpy as jnp

from ..context import Context
from ..encoder import one_hot
from ..feature import Feature
from ..phase import concat, Phase, Repeat, Fixation, Stimulus, Delay, Response, Cue, DeclarativePhase
from ..task import Task
from .._typing import Duration

# Public task classes exported by this module (used by ``from ... import *``).
__all__ = ['HierarchicalReasoning', 'ProbabilisticReasoning']


class HierarchicalReasoning(Task):
    """
    Hierarchical Reasoning task.

    Two flashes (left/right) are presented, separated by a variable delay.
    The correct response depends on the location of the second flash, on
    whether the delay was short or long, and on a context rule:

    - Rule 0: delay < ``delay_threshold`` -> respond toward flash2;
      delay >= ``delay_threshold`` -> respond away from flash2.
    - Rule 1: the opposite mapping (short -> away, long -> toward).

    The rule alternates in blocks of ``rule_block_size`` consecutive trial
    indices. With ``show_rule_cue=True`` (the default) the active rule is
    given as a one-hot ``rule`` input during fixation. With
    ``show_rule_cue=False`` no cue is shown and the rule must be inferred.

    Structure: Fixation >> Flash1 >> Delay >> Flash2 >> Response

    Parameters
    ----------
    t_fixation : Duration
        Fixation duration (default: 500ms).
    t_flash1 : Duration
        First flash duration (default: 100ms).
    t_delay : tuple
        (min, max) for the delay duration (default: (200ms, 800ms)). Plain
        numbers are interpreted as milliseconds when sampling the delay.
    t_flash2 : Duration
        Second flash duration (default: 100ms).
    t_response : Duration
        Response duration (default: 500ms).
    delay_threshold : float
        Threshold separating short from long delays. Plain numbers are
        interpreted as milliseconds; a time Quantity is also accepted
        (default: 500.0).
    show_rule_cue : bool
        Whether the rule is presented as an input cue during fixation
        (default: True). Without the cue the rule is only implied by the
        trial index.
    rule_block_size : int
        Number of consecutive trial indices sharing the same rule
        (default: 100).
    seed : int, optional
        Random seed.

    Raises
    ------
    ValueError
        If ``rule_block_size`` is smaller than 1.

    Examples
    --------
    >>> task = HierarchicalReasoning()
    >>> task = HierarchicalReasoning(delay_threshold=400.0)
    >>> X, Y, info = task.sample_trial(0)
    """

    def __init__(
        self,
        t_fixation: Duration = 500.0 * u.ms,
        t_flash1: Duration = 100.0 * u.ms,
        t_delay: tuple = (200.0 * u.ms, 800.0 * u.ms),
        t_flash2: Duration = 100.0 * u.ms,
        t_response: Duration = 500.0 * u.ms,
        delay_threshold: float = 500.0,
        show_rule_cue: bool = True,
        rule_block_size: int = 100,
        **kwargs
    ):
        if int(rule_block_size) < 1:
            raise ValueError(
                f'rule_block_size must be >= 1, got {rule_block_size!r}.'
            )
        self.t_fixation = t_fixation
        self.t_flash1 = t_flash1
        self.t_delay = t_delay
        self.t_flash2 = t_flash2
        self.t_response = t_response
        self.delay_threshold = delay_threshold
        # If False, the rule is implicit (it only changes with the trial
        # index) and the agent must infer it. If True, the rule is provided
        # as an input cue during fixation, which makes the task learnable
        # without feedback.
        self.show_rule_cue = show_rule_cue
        self.rule_block_size = int(rule_block_size)
        # Attributes are set before ``super().__init__`` in case the base
        # class builds features/phases during construction.
        super().__init__(**kwargs)

    @staticmethod
    def _as_ms(value) -> float:
        """Return ``value`` as a float number of milliseconds.

        Quantities are converted to ms. Plain numbers are assumed to already
        be in ms.
        """
        if isinstance(value, u.Quantity):
            return float(value.to(u.ms).mantissa)
        return float(value)

    def define_features(self) -> Tuple:
        """Build the (input, output) feature layouts.

        Inputs are fixation (1) + flash (2: left/right), plus rule (2) when
        ``show_rule_cue`` is set. Outputs are fixation (1) + response (2).
        """
        fix_feat = Feature(1, 'fixation')
        flash_feat = Feature(2, 'flash')  # left/right
        if self.show_rule_cue:
            rule_feat = Feature(2, 'rule')
            input_features = fix_feat + flash_feat + rule_feat
        else:
            input_features = fix_feat + flash_feat
        resp_feat = Feature(2, 'response')
        output_features = fix_feat + resp_feat
        return input_features, output_features

    def define_phases(self) -> Phase:
        """Compose Fixation >> Flash1 >> variable Delay >> Flash2 >> Response."""
        from ..phase import VariableDuration

        # The delay length is read from ctx['delay_duration'], which is
        # sampled in ``trial_init``.
        variable_delay = VariableDuration(
            min_duration=self.t_delay[0],
            max_duration=self.t_delay[1],
            ctx_key='delay_duration',
            inputs={'fixation': 1.0},
            outputs={'label': 0},
            name='delay',
        )

        fixation_inputs = {'fixation': 1.0}
        if self.show_rule_cue:
            fixation_inputs['rule'] = one_hot('rule', num_classes=2)

        return concat([
            Fixation(
                duration=self.t_fixation,
                name='fixation',
                inputs=fixation_inputs,
                outputs={'label': 0},
            ),
            Stimulus(
                duration=self.t_flash1,
                name='flash1',
                inputs={'fixation': 1.0, 'flash': one_hot('flash1_loc')},
                outputs={'label': 0},
            ),
            variable_delay,
            Stimulus(
                duration=self.t_flash2,
                name='flash2',
                inputs={'fixation': 1.0, 'flash': one_hot('flash2_loc')},
                outputs={'label': 0},
            ),
            Response(
                duration=self.t_response,
                name='response',
                inputs={'fixation': 0.0},
                # Label 0 is reserved for fixation; responses are 1 and 2.
                outputs={'label': lambda ctx, f: ctx['ground_truth'] + 1},
            ),
        ])

    def trial_init(self, ctx: Context) -> None:
        """Sample the rule, delay and flash locations, and set the ground truth.

        Writes ``rule``, ``delay_duration`` (float ms), ``flash1_loc``,
        ``flash2_loc`` and ``ground_truth`` (0 = left, 1 = right) into ``ctx``.
        """
        delay_min = self._as_ms(self.t_delay[0])
        delay_max = self._as_ms(self.t_delay[1])
        threshold = self._as_ms(self.delay_threshold)

        # The trial index determines the rule block. Under
        # ``batch_sample``/vmap the index is a traced int32 scalar, so use
        # jnp arithmetic rather than Python int() coercion.
        trial_idx = jnp.asarray(ctx.get('trial_index', 0), dtype=jnp.int32)
        ctx['rule'] = (trial_idx // self.rule_block_size) % 2  # 0 or 1

        # Delay length in ms.
        ctx['delay_duration'] = ctx.rng.uniform(delay_min, delay_max)

        # Flash locations (left=0, right=1).
        ctx['flash1_loc'] = ctx.rng.choice(2)
        ctx['flash2_loc'] = ctx.rng.choice(2)

        # Both rule and delay condition can be JAX values under vmap, so the
        # decision uses jnp.where instead of Python branching.
        short_delay = ctx['delay_duration'] < threshold
        flash = ctx['flash2_loc']
        # Rule 0: short -> toward, long -> away. Rule 1: the opposite.
        rule_a = jnp.where(short_delay, flash, 1 - flash)
        rule_b = jnp.where(short_delay, 1 - flash, flash)
        ctx['ground_truth'] = jnp.where(ctx['rule'] == 0, rule_a, rule_b)
class ProbabilisticReasoning(Task):
    """
    Probabilistic Reasoning task.

    The agent accumulates log-likelihood evidence from a sequence of cues.
    Each cue is drawn uniformly from ``cue_evidence`` and carries a
    log-likelihood ratio (positive favours choice 1, negative favours
    choice 0). The correct answer is the sign of the summed evidence. When
    the evidence cancels out (the sum is zero up to float32 round-off), no
    choice is favoured and the label is drawn uniformly at random instead of
    defaulting to choice 0.

    Structure: Fixation >> (Cue >> Delay) * N >> Response

    Parameters
    ----------
    t_fixation : Duration
        Fixation duration (default: 500ms).
    t_cue : Duration
        Duration of each cue (default: 100ms).
    t_delay : Duration
        Delay between cues (default: 100ms).
    num_cues : int
        Number of evidence cues (default: 8).
    t_response : Duration
        Response duration (default: 500ms).
    num_choices : int
        Width of the ``response`` output feature (default: 2). The ground
        truth is always binary (0 or 1), so values other than 2 only change
        the feature width.
    cue_evidence : sequence of float
        Possible log-likelihood ratios (positive = choice 1) (default:
        (-0.08, -0.04, -0.02, -0.01, 0.01, 0.02, 0.04, 0.08)). Stored as a
        tuple copy so later mutation of the caller's sequence has no effect.
    seed : int, optional
        Random seed.

    Raises
    ------
    ValueError
        If ``cue_evidence`` is empty.

    Examples
    --------
    >>> task = ProbabilisticReasoning()
    >>> task = ProbabilisticReasoning(num_cues=12, t_cue=150*u.ms)
    >>> X, Y, info = task.sample_trial(0)
    """

    def __init__(
        self,
        t_fixation: Duration = 500.0 * u.ms,
        t_cue: Duration = 100.0 * u.ms,
        t_delay: Duration = 100.0 * u.ms,
        num_cues: int = 8,
        t_response: Duration = 500.0 * u.ms,
        num_choices: int = 2,
        cue_evidence: Sequence[float] = (-0.08, -0.04, -0.02, -0.01, 0.01, 0.02, 0.04, 0.08),
        **kwargs
    ):
        if len(cue_evidence) == 0:
            raise ValueError('cue_evidence must contain at least one value.')
        self.t_fixation = t_fixation
        self.t_cue = t_cue
        self.t_delay = t_delay
        self.num_cues = num_cues
        self.t_response = t_response
        self.num_choices = num_choices
        # Copy to an immutable tuple: the cue feature width (define_features)
        # and the sampled values (trial_init) must stay in sync.
        self.cue_evidence = tuple(float(v) for v in cue_evidence)
        super().__init__(**kwargs)

    def define_features(self) -> Tuple:
        """Build (input, output) features: fixation + one-hot cue in, fixation + response out."""
        fix_feat = Feature(1, 'fixation')
        cue_feat = Feature(len(self.cue_evidence), 'cue')
        input_features = fix_feat + cue_feat
        resp_feat = Feature(self.num_choices, 'response')
        output_features = fix_feat + resp_feat
        return input_features, output_features

    def _cue_encoder(self, ctx: Context, feature) -> jnp.ndarray:
        """One-hot encoding of the i-th sampled cue. JIT/vmap safe."""
        # ``repeat_index`` is set by Repeat; fall back to legacy cue_index.
        idx = ctx.get('repeat_index', ctx.get('cue_index', 0))
        cue_id = ctx['cue_indices'][idx]  # cue_indices has shape (num_cues,)
        return jnp.zeros(feature.num).at[cue_id].set(1.0)

    def define_phases(self) -> Phase:
        """Compose Fixation >> (Cue >> Delay) repeated ``num_cues`` times >> Response."""
        cue_block = concat([
            Cue(
                duration=self.t_cue,
                name='cue',
                inputs={'fixation': 1.0, 'cue': self._cue_encoder},
                outputs={'label': 0},
            ),
            Delay(
                duration=self.t_delay,
                name='cue_delay',
                inputs={'fixation': 1.0},
                outputs={'label': 0},
            ),
        ])
        return concat([
            Fixation(
                duration=self.t_fixation,
                name='fixation',
                inputs={'fixation': 1.0},
                outputs={'label': 0},
            ),
            Repeat(cue_block, self.num_cues),
            Response(
                duration=self.t_response,
                name='response',
                inputs={'fixation': 0.0},
                # Label 0 is reserved for fixation; responses start at 1.
                outputs={'label': lambda ctx, f: ctx['ground_truth'] + 1},
            ),
        ])

    def trial_init(self, ctx: Context) -> None:
        """Sample the cue sequence and derive the evidence-based ground truth.

        Writes ``cue_indices``, ``cues``, ``total_evidence`` and
        ``ground_truth`` (0 or 1) into ``ctx``.
        """
        cue_evidence = jnp.asarray(self.cue_evidence, dtype=jnp.float32)

        # Sample cues for this trial.
        ctx['cue_indices'] = ctx.rng.choice(len(cue_evidence), size=self.num_cues)
        ctx['cues'] = cue_evidence[ctx['cue_indices']]

        # Accumulate the log-likelihood ratio.
        total_llr = jnp.sum(ctx['cues'])
        ctx['total_evidence'] = total_llr

        # Evidence that cancels exactly can leave a float32 round-off residue
        # whose sign is meaningless. Treat |sum| within that bound as a tie.
        # The tolerance is computed in pure Python so it stays concrete
        # under jit/vmap.
        max_abs = max(abs(v) for v in self.cue_evidence)
        tol = 4.0 * self.num_cues * float(jnp.finfo(jnp.float32).eps) * max_abs
        is_tie = jnp.abs(total_llr) <= tol

        # Ties carry no information, so break them at random instead of
        # always labelling them choice 0 (which biased the label
        # distribution). The coin is drawn unconditionally so the RNG stream
        # is the same for every trial shape under vmap.
        coin = ctx.rng.choice(2)
        by_sign = jnp.where(total_llr > 0, 1, 0)
        ctx['ground_truth'] = jnp.where(is_tie, coin, by_sign)