
Wrappers

block

ScheduleAttr

ScheduleAttr(env, schedule, attr_list)

Bases: TrialWrapper

Schedule attributes.

Parameters:

    env (TrialEnv object): the task environment to wrap. Required.
    schedule: scheduler object used to select attributes. Required.
    attr_list: attributes to schedule across trials. Required.
Source code in neurogym/wrappers/block.py
def __init__(self, env, schedule, attr_list) -> None:
    super().__init__(env)
    self.schedule = schedule
    self.attr_list = attr_list
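
Usage sketch (illustrative, not from the library docs). It assumes attr_list is a list of keyword dictionaries that the schedule selects from on each trial and forwards to the task's new_trial(); the excerpt above shows only the constructor, so the task name, scheduler class, and attribute keys are assumptions.

import neurogym as ngym
from neurogym.utils.scheduler import RandomSchedule
from neurogym.wrappers.block import ScheduleAttr

env = ngym.make("PerceptualDecisionMaking-v0")
# Hypothetical per-trial attribute sets; keys must be accepted by the task's new_trial().
attr_list = [{"ground_truth": 1}, {"ground_truth": 2}]
schedule = RandomSchedule(len(attr_list))
env = ScheduleAttr(env, schedule, attr_list=attr_list)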

MultiEnvs

MultiEnvs(envs, env_input=False)

Bases: TrialWrapper

Wrap multiple environments.

Parameters:

    envs: list of env objects. Required.
    env_input (bool): if True, add scalar inputs indicating the current environment. Default: False.
Source code in neurogym/wrappers/block.py
def __init__(self, envs, env_input=False) -> None:
    super().__init__(envs[0])
    for env in envs:
        env.unwrapped.set_top(self)
    self.envs = envs
    self.i_env = 0

    self.env_input = env_input
    if env_input:
        env_shape = envs[0].observation_space.shape
        if len(env_shape) > 1:
            msg = f"Env must have 1-D Box shape but got {env_shape}."
            raise ValueError(msg)
        _have_equal_shape(envs)
        self.observation_space: spaces.Box = spaces.Box(
            -np.inf,
            np.inf,
            shape=(env_shape[0] + len(self.envs),),
            dtype=self.envs[0].observation_space.dtype,
        )

set_i

set_i(i) -> None

Set the i-th environment.

Source code in neurogym/wrappers/block.py
def set_i(self, i) -> None:
    """Set the i-th environment."""
    self.i_env = i
    self.env = self.envs[self.i_env]
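
Usage sketch (illustrative, not from the library docs): wrap several tasks and switch between them with set_i. The task names are assumptions, and env_input=True additionally requires all environments to share the same 1-D observation shape.

import neurogym as ngym
from neurogym.wrappers.block import MultiEnvs

envs = [ngym.make("PerceptualDecisionMaking-v0"), ngym.make("GoNogo-v0")]
env = MultiEnvs(envs, env_input=False)
env.set_i(1)             # make the second task the active environment
obs, info = env.reset()  # resets the currently active environment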

ScheduleEnvs

ScheduleEnvs(envs, schedule, env_input=False)

Bases: TrialWrapper

Schedule environments.

Parameters:

    envs: list of env objects. Required.
    schedule: utils.scheduler.BaseSchedule object. Required.
    env_input (bool): if True, add scalar inputs indicating the current environment. Default: False.
Source code in neurogym/wrappers/block.py
def __init__(self, envs, schedule, env_input=False) -> None:
    super().__init__(envs[0])
    for env in envs:
        env.unwrapped.set_top(self)
    self.envs = envs
    self.schedule = schedule
    self.i_env = self.next_i_env = 0

    self.env_input = env_input
    if env_input:
        env_shape = envs[0].observation_space.shape
        if len(env_shape) > 1:
            msg = f"Env must have 1-D Box shape but got {env_shape}."
            raise ValueError(msg)
        _have_equal_shape(envs)
        self.observation_space: spaces.Box = spaces.Box(
            -np.inf,
            np.inf,
            shape=(env_shape[0] + len(self.envs),),
            dtype=np.float32,
        )

reset

reset(**kwargs)

Resets environments.

Reset each environment in self.envs and use the scheduler to select the environment whose initial observation is returned. That environment also becomes the current environment self.env.

Source code in neurogym/wrappers/block.py
def reset(self, **kwargs):
    # TODO: kwargs to specify the condition for new_trial
    """Resets environments.

    Reset each environment in self.envs and use the scheduler to select the environment returning
    the initial observation. This environment is also used to set the current environment self.env.
    """
    self.schedule.reset()
    return_i_env = self.schedule()

    # first reset all the envs except return_i_env
    for i, env in enumerate(self.envs):
        if i == return_i_env:
            continue

        # change the current env so that calling _top.new_trial() in env.reset() will generate a trial for the env
        # being currently reset (and not an env that is not yet reset)
        self.set_i(i)
        # same env used here and in the first call to new_trial()
        self.next_i_env = self.i_env

        env.reset(**kwargs)

    # then reset return_i_env and return the result
    self.set_i(return_i_env)
    self.next_i_env = self.i_env
    return self.env.reset()

set_i

set_i(i) -> None

Set the current environment to the i-th environment in the list envs.

Source code in neurogym/wrappers/block.py
def set_i(self, i) -> None:
    """Set the current environment to the i-th environment in the list envs."""
    self.i_env = i
    self.env = self.envs[self.i_env]
    self.schedule.i = i
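
Usage sketch for multitask training (illustrative, not from the library docs). The task names are assumptions, and RandomSchedule is assumed to be available from neurogym.utils.scheduler; any BaseSchedule subclass should work.

import neurogym as ngym
from neurogym.utils.scheduler import RandomSchedule
from neurogym.wrappers.block import ScheduleEnvs

envs = [ngym.make("PerceptualDecisionMaking-v0"), ngym.make("GoNogo-v0")]
schedule = RandomSchedule(len(envs))
env = ScheduleEnvs(envs, schedule=schedule, env_input=False)
obs, info = env.reset()  # resets every sub-environment, then the scheduled one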

TrialHistoryV2

TrialHistoryV2(env, probs=None)

Bases: TrialWrapper

Change ground truth probability based on previous outcome.

Parameters:

    env: task environment to wrap; it must expose a choices attribute. Required.
    probs: matrix of probabilities of the current choice conditioned on the previous one, with shape (num-choices, num-choices). Default: None (uniform probabilities).
Source code in neurogym/wrappers/block.py
def __init__(self, env, probs=None) -> None:
    super().__init__(env)
    try:
        self.n_ch = len(self.choices)  # max num of choices
    except AttributeError as e:
        msg = "TrialHistory requires task to have attribute choices."
        raise AttributeError(msg) from e
    if probs is None:
        probs = np.ones((self.n_ch, self.n_ch)) / self.n_ch  # uniform
    self.probs = probs
    if self.probs.shape != (self.n_ch, self.n_ch):
        msg = f"{self.probs.shape=} should be {self.n_ch, self.n_ch=}."
        raise ValueError(msg)
    self.prev_trial = self.rng.choice(self.n_ch)  # random initialization
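
Usage sketch for a two-choice task (illustrative, not from the library docs). The task name is an assumption, and the convention that rows index the previous choice is also an assumption.

import numpy as np
import neurogym as ngym
from neurogym.wrappers.block import TrialHistoryV2

env = ngym.make("PerceptualDecisionMaking-v0")
probs = np.array([[0.8, 0.2],    # after choice 0: repeat with probability 0.8
                  [0.2, 0.8]])   # after choice 1: repeat with probability 0.8
env = TrialHistoryV2(env, probs=probs)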

monitor

Monitor

Monitor(env: TrialEnv, config: Config | str | Path | None = None, name: str | None = None, trigger: str = 'trial', interval: int = 1000, plot_create: bool = False, plot_steps: int = 1000, ext: str = 'png', step_fn: Callable | None = None, verbose: bool = True, level: str = 'INFO', log_trigger: str = 'trial', log_interval: int = 1000)

Bases: Wrapper

Monitor class to log, visualize, and evaluate NeuroGym environment behavior.

Wraps a NeuroGym TrialEnv to track actions, rewards, and performance metrics, save them to disk, and optionally generate trial visualizations. Supports logging at trial or step level, with configurable frequency and verbosity.

Parameters:

    env (TrialEnv): The NeuroGym environment to wrap. Required.
    config (Config | str | Path | None): Optional configuration source (Config object, TOML file path, or dictionary). Default: None.
    name (str | None): Optional monitor name; defaults to the environment class name. Default: None.
    trigger (str): When to save data ("trial" or "step"). Default: "trial".
    interval (int): How often to save data, in number of trials or steps. Default: 1000.
    plot_create (bool): Whether to generate and save visualizations of environment behavior. Default: False.
    plot_steps (int): Number of steps to visualize in each plot. Default: 1000.
    ext (str): Image file extension for saved plots (e.g., "png"). Default: "png".
    step_fn (Callable | None): Optional custom step function to override the environment's. Default: None.
    verbose (bool): Whether to print information when logging or saving data. Default: True.
    level (str): Logging verbosity level (e.g., "INFO", "DEBUG"). Default: "INFO".
    log_trigger (str): When to log progress ("trial" or "step"). Default: "trial".
    log_interval (int): How often to log, in trials or steps. Default: 1000.

Attributes:

    config (Config): Final validated configuration object.
    data (dict[str, list]): Collected behavioral data for each completed trial.
    cum_reward: Cumulative reward for the current trial.
    num_tr: Number of completed trials.
    t: Step counter (used when trigger is "step").
    save_dir: Directory where data and plots are saved.

Source code in neurogym/wrappers/monitor.py
def __init__(
    self,
    env: ngym.TrialEnv,
    config: ngym.Config | str | Path | None = None,
    name: str | None = None,
    trigger: str = "trial",
    interval: int = 1000,
    plot_create: bool = False,
    plot_steps: int = 1000,
    ext: str = "png",
    step_fn: Callable | None = None,
    verbose: bool = True,
    level: str = "INFO",
    log_trigger: str = "trial",
    log_interval: int = 1000,
) -> None:
    super().__init__(env)
    self.env = env
    self.step_fn = step_fn

    log_format = "<magenta>Neurogym</magenta> | <cyan>{time:YYYY-MM-DD@HH:mm:ss}</cyan> | <level>{message}</level>"

    cfg: ngym.Config
    if config is None:
        config_dict = {
            "env": {"name": env.unwrapped.__class__.__name__},
            "monitor": {
                "name": name or "Monitor",
                "trigger": trigger,
                "interval": interval,
                "plot": {
                    "create": plot_create,
                    "step": plot_steps,
                    "title": env.unwrapped.__class__.__name__,
                    "ext": ext,
                },
                "log": {
                    "verbose": verbose,
                    "format": log_format,
                    "level": level,
                    "trigger": log_trigger,
                    "interval": log_interval,
                },
            },
            "local_dir": LOCAL_DIR,
        }
        cfg = ngym.Config.model_validate(config_dict)
    elif isinstance(config, (str, Path)):
        cfg = ngym.Config(config_file=config)
    else:
        cfg = config  # type: ignore[arg-type]

    self.config: ngym.Config = cfg

    # Assign names for the environment and/or the monitor if they are empty
    if len(self.config.env.name) == 0:
        self.config.env.name = self.env.unwrapped.__class__.__name__
    if len(self.config.monitor.name) == 0:
        self.config.monitor.name = self.__class__.__name__

    self._configure_logger()

    # data to save
    self.data: dict[str, list] = {"action": [], "reward": [], "cum_reward": [], "performance": []}
    self.cum_reward = 0.0
    if self.config.monitor.trigger == "step":
        self.t = 0
    self.num_tr = 0

    # Directory for saving plots
    save_dir_name = f"{self.config.env.name}/{ngym.utils.iso_timestamp()}"
    self.save_dir = ngym.utils.ensure_dir(self.config.local_dir / save_dir_name)

    # Figures
    if self.config.monitor.plot.create:
        self.stp_counter = 0
        self.ob_mat: list = []
        self.act_mat: list = []
        self.rew_mat: list = []
        self.gt_mat: list = []
        self.perf_mat: list = []
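
Construction sketch (illustrative, not from the library docs); the task name and keyword values are assumptions chosen for illustration.

import neurogym as ngym
from neurogym.wrappers.monitor import Monitor

env = ngym.make("PerceptualDecisionMaking-v0")
env = Monitor(env, trigger="trial", interval=1000, plot_create=True, ext="png")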

reset

reset(seed=None)

Reset the environment.

Parameters:

    seed: Random seed for the environment. Default: None.

Returns:

    The initial observation from the environment reset.

Source code in neurogym/wrappers/monitor.py
def reset(self, seed=None):
    """Reset the environment.

    Args:
        seed: Random seed for the environment

    Returns:
        The initial observation from the environment reset
    """
    self.cum_reward = 0
    return super().reset(seed=seed)

step

step(action: Any, collect_data: bool = True) -> tuple[Any, float, bool, bool, dict[str, Any]]

Execute one environment step.

This method:
1. Takes a step in the environment.
2. Collects data for plotting when plot creation is enabled.
3. Saves data when a trial completes and the saving conditions are met.

Parameters:

    action (Any): The action to take in the environment. Required.
    collect_data (bool): If True, collect and save data. Default: True.

Returns:

    tuple[Any, float, bool, bool, dict[str, Any]]: Tuple of (observation, reward, terminated, truncated, info).

Source code in neurogym/wrappers/monitor.py
def step(self, action: Any, collect_data: bool = True) -> tuple[Any, float, bool, bool, dict[str, Any]]:
    """Execute one environment step.

    This method:
    1. Takes a step in the environment
    2. Collects data if sv_fig is enabled
    3. Saves data when a trial completes and saving conditions are met

    Args:
        action: The action to take in the environment
        collect_data: If True, collect and save data

    Returns:
        Tuple of (observation, reward, terminated, truncated, info)
    """
    if self.step_fn is not None:
        obs, rew, terminated, truncated, info = self.step_fn(action)
    else:
        obs, rew, terminated, truncated, info = self.env.step(action)
    self.cum_reward += rew
    if self.config.monitor.plot.create:
        self.store_data(obs, action, rew, info)
    if self.config.monitor.trigger == "step":
        self.t += 1
    if info.get("new_trial", False):
        self.num_tr += 1
        self.data["action"].append(action)
        self.data["reward"].append(rew)
        self.data["cum_reward"].append(self.cum_reward)
        self.cum_reward = 0
        for key in info:
            if key not in self.data:
                self.data[key] = [info[key]]
            else:
                self.data[key].append(info[key])

        # save data
        save = (
            self.t >= self.config.monitor.interval
            if self.config.monitor.trigger == "step"
            else self.num_tr % self.config.monitor.interval == 0
        )
        if save and collect_data:
            # Create save path with pathlib for cross-platform compatibility
            save_path = self.save_dir / f"trial_{self.num_tr}.npz"
            np.savez(save_path, **self.data)

            if self.config.monitor.log.verbose:
                print("--------------------")
                print(f"Data saved to: {save_path}")
                print(f"Number of trials: {self.num_tr}")
                print(f"Average reward: {np.mean(self.data['reward'])}")
                print(f"Average performance: {np.mean(self.data['performance'])}")
                print("--------------------")
            self.reset_data()
            if self.config.monitor.plot.create:
                self.stp_counter = 0
            if self.config.monitor.trigger == "step":
                self.t = 0
    return obs, rew, terminated, truncated, info
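
Stepping sketch (illustrative, not from the library docs): drive a monitored environment with random actions; the wrapper accumulates data per trial and saves them at the configured interval. The task name is an assumption.

import neurogym as ngym
from neurogym.wrappers.monitor import Monitor

env = Monitor(ngym.make("PerceptualDecisionMaking-v0"), interval=1000)
obs, info = env.reset()
for _ in range(10_000):
    action = env.action_space.sample()  # replace with a policy's action
    obs, reward, terminated, truncated, info = env.step(action)
    # trial boundaries are signaled through info["new_trial"]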

reset_data

reset_data() -> None

Reset all data containers to empty lists.

Source code in neurogym/wrappers/monitor.py
def reset_data(self) -> None:
    """Reset all data containers to empty lists."""
    for key in self.data:
        self.data[key] = []

store_data

store_data(obs: Any, action: Any, rew: float, info: dict[str, Any]) -> None

Store data for visualization figures.

Parameters:

    obs (Any): Current observation. Required.
    action (Any): Current action. Required.
    rew (float): Current reward. Required.
    info (dict[str, Any]): Info dictionary from the environment. Required.
Source code in neurogym/wrappers/monitor.py
def store_data(self, obs: Any, action: Any, rew: float, info: dict[str, Any]) -> None:
    """Store data for visualization figures.

    Args:
        obs: Current observation
        action: Current action
        rew: Current reward
        info: Info dictionary from environment
    """
    if self.stp_counter <= self.config.monitor.plot.step:
        self.ob_mat.append(obs)
        self.act_mat.append(action)
        self.rew_mat.append(rew)
        if "gt" in info:
            self.gt_mat.append(info["gt"])
        else:
            self.gt_mat.append(-1)
        if "performance" in info:
            self.perf_mat.append(info["performance"])
        else:
            self.perf_mat.append(-1)
        self.stp_counter += 1
    elif len(self.rew_mat) > 0:
        fname = self.save_dir / f"task_{self.num_tr:06d}.{self.config.monitor.plot.ext}"
        obs_mat = np.array(self.ob_mat)
        act_mat = np.array(self.act_mat)
        fig_(
            ob=obs_mat,
            actions=act_mat,
            gt=self.gt_mat,
            rewards=self.rew_mat,
            performance=self.perf_mat,
            fname=fname,
            name=self.config.monitor.plot.title,
        )
        self.ob_mat = []
        self.act_mat = []
        self.rew_mat = []
        self.gt_mat = []
        self.perf_mat = []

evaluate_policy

evaluate_policy(num_trials: int = 100, model: Any | None = None, verbose: bool = True) -> dict[str, float | list[float]]

Evaluates the average performance of the RL agent in the environment.

This method runs the given model (or random policy if None) on the environment for a specified number of trials and collects performance metrics.

Parameters:

    num_trials (int): Number of trials to run for evaluation. Default: 100.
    model (Any | None): The policy model to evaluate (if None, uses random actions). Default: None.
    verbose (bool): If True, prints progress information. Default: True.

Returns:

    dict: Dictionary containing performance metrics:
        - mean_performance: Average performance (if reported by the environment)
        - mean_reward: Proportion of positive rewards
        - mean_cum_reward: Average cumulative reward per trial
        - performances: List of performance values for each trial
        - rewards: List of rewards for each trial
        - cum_rewards: List of cumulative rewards for each trial

Source code in neurogym/wrappers/monitor.py
def evaluate_policy(
    self,
    num_trials: int = 100,
    model: Any | None = None,
    verbose: bool = True,
) -> dict[str, float | list[float]]:
    """Evaluates the average performance of the RL agent in the environment.

    This method runs the given model (or random policy if None) on the
    environment for a specified number of trials and collects performance
    metrics.

    Args:
        num_trials: Number of trials to run for evaluation
        model: The policy model to evaluate (if None, uses random actions)
        verbose: If True, prints progress information
    Returns:
        dict: Dictionary containing performance metrics:
            - mean_performance: Average performance (if reported by environment)
            - mean_reward: Proportion of positive rewards
            - performances: List of performance values for each trial
            - rewards: List of rewards for each trial.
    """
    # Reset environment
    obs, _ = self.env.reset()

    # Initialize hidden states
    states = None
    episode_starts = np.array([True])

    # Tracking variables
    rewards = []
    cum_reward = 0.0
    cum_rewards = []
    performances = []
    # Initialize trial count
    trial_count = 0

    # Run trials
    while trial_count < num_trials:
        if model is not None:
            action, states = model.predict(obs, state=states, episode_start=episode_starts, deterministic=True)
        else:
            action = self.env.action_space.sample()

        # Use collect_data=False to avoid saving evaluation data
        obs, reward, _, _, info = self.step(action, collect_data=False)
        # Update episode_starts after each step
        episode_starts = np.array([False])
        cum_reward += reward

        if info.get("new_trial", False):
            trial_count += 1
            rewards.append(reward)
            cum_rewards.append(cum_reward)
            cum_reward = 0.0
            if "performance" in info:
                performances.append(info["performance"])

            if verbose and trial_count % 1000 == 0:
                print(f"Completed {trial_count}/{num_trials} trials")

            # Reset states at the end of each trial
            states = None
            episode_starts = np.array([True])

    # Calculate metrics
    performance_array = np.array([p for p in performances if p != -1])
    reward_array = np.array(rewards)
    cum_reward_array = np.array(cum_rewards)

    return {
        "rewards": rewards,
        "mean_reward": float(np.mean(reward_array > 0)) if len(reward_array) > 0 else 0,
        "cum_rewards": cum_rewards,
        "mean_cum_reward": float(np.mean(cum_reward_array)) if len(cum_reward_array) > 0 else 0,
        "performances": performances,
        "mean_performance": float(np.mean(performance_array)) if len(performance_array) > 0 else 0,
    }
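
Evaluation sketch (illustrative, not from the library docs): evaluate a random policy; pass a trained agent via model= (e.g., a stable-baselines3-style model with a predict method) to evaluate it instead. The task name is an assumption.

import neurogym as ngym
from neurogym.wrappers.monitor import Monitor

env = Monitor(ngym.make("PerceptualDecisionMaking-v0"))
metrics = env.evaluate_policy(num_trials=500)  # model=None -> random actions
print(metrics["mean_performance"], metrics["mean_cum_reward"])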

plot_training_history

plot_training_history(figsize: tuple[int, int] = (12, 6), save_fig: bool = True, plot_performance: bool = True) -> Figure | None

Plot rewards and performance training history from saved data files with one data point per trial.

Parameters:

    figsize (tuple[int, int]): Figure size as a (width, height) tuple. Default: (12, 6).
    save_fig (bool): Whether to save the figure to disk. Default: True.
    plot_performance (bool): Whether to plot performance in a separate plot. Default: True.

Returns:

    The matplotlib figure object, or None if no data files are found.

Source code in neurogym/wrappers/monitor.py
def plot_training_history(
    self,
    figsize: tuple[int, int] = (12, 6),
    save_fig: bool = True,
    plot_performance: bool = True,
) -> plt.Figure | None:
    """Plot rewards and performance training history from saved data files with one data point per trial.

    Args:
        figsize: Figure size as (width, height) tuple
        save_fig: Whether to save the figure to disk
        plot_performance: Whether to plot performance in a separate plot
    Returns:
        matplotlib figure object
    """
    files = sorted(self.save_dir.glob("*.npz"))

    if not files:
        print("No data files found matching pattern: *.npz")
        return None

    print(f"Found {len(files)} data files")

    # Arrays to hold average values
    avg_rewards_per_file = []
    avg_cum_rewards_per_file = []
    avg_performances_per_file = []
    file_indices = []
    total_trials = 0

    for file in files:
        data = np.load(file, allow_pickle=True)

        if "reward" in data:
            rewards = data["reward"]
            if len(rewards) > 0:
                avg_rewards_per_file.append(np.mean(rewards))
                total_trials += len(rewards)
                file_indices.append(total_trials)

        if "cum_reward" in data:
            cum_rewards = data["cum_reward"]
            if len(cum_rewards) > 0:
                avg_cum_rewards_per_file.append(np.mean(cum_rewards))

        if "performance" in data:
            perfs = data["performance"]
            if len(perfs) > 0:
                avg_performances_per_file.append(np.mean(perfs))

    fig, axes = plt.subplots(1, 2 if plot_performance else 1, figsize=figsize)
    if not isinstance(axes, np.ndarray):
        axes = [axes]

    # 1. Rewards and Cumulative Rewards plot
    ax1 = axes[0]

    if len(avg_rewards_per_file) == len(file_indices):
        ax1.plot(file_indices, avg_rewards_per_file, "o-", color="blue", label="Avg Reward", linewidth=2)
    if len(avg_cum_rewards_per_file) == len(file_indices):
        ax1.plot(file_indices, avg_cum_rewards_per_file, "s--", color="red", label="Avg Cum Reward", linewidth=2)

    ax1.set_xlabel("Cumulative Trials")
    ax1.set_ylabel("Reward / Cumulative Reward")
    common_ylim = (-0.05, 1.05)
    ax1.set_ylim(common_ylim)
    ax1.set_title("Reward and Cumulative Reward per File")

    overall_avg_reward = np.mean(avg_rewards_per_file)
    ax1.text(
        0.05,
        0.95,
        f"Overall Avg Reward: {overall_avg_reward:.4f}",
        transform=ax1.transAxes,
        verticalalignment="top",
        bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.8},
    )

    ax1.grid(True, which="both", axis="y", linestyle="--", alpha=0.7)
    ax1.legend(loc="lower center", bbox_to_anchor=(0.5, -0.3), ncol=2)

    # 2. Optional: Performances plot
    if plot_performance and len(axes) > 1:
        ax2 = axes[1]
        if len(avg_performances_per_file) == len(file_indices):
            ax2.plot(file_indices, avg_performances_per_file, "o-", color="green", linewidth=2)
        ax2.set_xlabel("Cumulative Trials")
        ax2.set_ylabel("Average Performance (0-1)")
        ax2.set_ylim(common_ylim)
        ax2.set_title("Average Performance per File")

        overall_avg_perf = np.mean(avg_performances_per_file)
        ax2.text(
            0.05,
            0.95,
            f"Overall Avg Perf: {overall_avg_perf:.4f}",
            transform=ax2.transAxes,
            verticalalignment="top",
            bbox={"boxstyle": "round", "facecolor": "white", "alpha": 0.8},
        )

        ax2.grid(True, which="both", axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()
    fig.subplots_adjust(top=0.8)
    plt.suptitle(
        f"Training History for {self.config.env.name}\n({len(files)} data files, {total_trials} total trials)",
        fontsize=14,
    )

    if save_fig:
        save_path = self.config.local_dir / f"{self.config.env.name}_training_history.png"
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"Figure saved to {save_path}")

    return fig
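
Plotting sketch (illustrative, not from the library docs); it assumes the monitor has already written at least one .npz data file to its save directory during training, and the task name is an assumption.

import neurogym as ngym
from neurogym.wrappers.monitor import Monitor

env = Monitor(ngym.make("PerceptualDecisionMaking-v0"))
# ... training happens here, producing .npz files in env.save_dir ...
fig = env.plot_training_history(figsize=(12, 6), save_fig=True)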

noise

Noise wrapper.

Created on Thu Feb 28 15:07:21 2019

@author: molano

Noise

Noise(env, std_noise=0.1)

Bases: Wrapper

Add Gaussian noise to the observations.

Parameters:

    env: task environment to wrap. Required.
    std_noise (float): Standard deviation of the Gaussian noise added to the observations. Default: 0.1.
Source code in neurogym/wrappers/noise.py
def __init__(self, env, std_noise=0.1) -> None:
    super().__init__(env)
    self.env = env
    self.std_noise = std_noise
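
Usage sketch (illustrative, not from the library docs); the task name is an assumption.

import neurogym as ngym
from neurogym.wrappers.noise import Noise

env = Noise(ngym.make("PerceptualDecisionMaking-v0"), std_noise=0.2)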

pass_action

PassAction

PassAction(env)

Bases: Wrapper

Modifies observation by adding the previous action.

Source code in neurogym/wrappers/pass_action.py
def __init__(self, env) -> None:
    super().__init__(env)
    self.env = env
    # TODO: This is not adding one-hot
    env_oss = env.observation_space.shape[0]
    self.observation_space = spaces.Box(
        -np.inf,
        np.inf,
        shape=(env_oss + 1,),
        dtype=np.float32,
    )
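
Usage sketch (illustrative, not from the library docs); the task name is an assumption. The wrapped observation gains one extra scalar input carrying the previous action.

import neurogym as ngym
from neurogym.wrappers.pass_action import PassAction

env = PassAction(ngym.make("PerceptualDecisionMaking-v0"))
print(env.observation_space.shape)  # original observation size + 1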

pass_reward

PassReward

PassReward(env)

Bases: Wrapper

Modifies observation by adding the previous reward.

Source code in neurogym/wrappers/pass_reward.py
def __init__(self, env) -> None:
    """Modifies observation by adding the previous reward."""
    super().__init__(env)
    env_oss = env.observation_space.shape[0]
    self.observation_space = spaces.Box(
        -np.inf,
        np.inf,
        shape=(env_oss + 1,),
        dtype=np.float32,
    )
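
Usage sketch (illustrative, not from the library docs): chain PassReward with PassAction so the agent also observes its previous action and reward, a common setup in meta-learning experiments; the task name is an assumption.

import neurogym as ngym
from neurogym.wrappers.pass_action import PassAction
from neurogym.wrappers.pass_reward import PassReward

env = ngym.make("PerceptualDecisionMaking-v0")
env = PassReward(PassAction(env))  # observation grows by two scalar inputs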

reaction_time

Reaction time wrapper.

Created on Thu Feb 28 15:07:21 2019

@author: molano

ReactionTime

ReactionTime(env, urgency=0.0)

Bases: Wrapper

Allow reaction time response.

Modifies a given environment by allowing the network to act at any time after the fixation period.

Source code in neurogym/wrappers/reaction_time.py
def __init__(self, env, urgency=0.0) -> None:
    super().__init__(env)
    self.env = env
    self.urgency = urgency
    self.tr_dur = 0
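
Usage sketch (illustrative, not from the library docs): turn a fixed-duration task into a reaction-time task, per the description above; the task name is an assumption.

import neurogym as ngym
from neurogym.wrappers.reaction_time import ReactionTime

env = ReactionTime(ngym.make("PerceptualDecisionMaking-v0"), urgency=0.0)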

side_bias

SideBias

SideBias(env, probs=None, block_dur=200)

Bases: TrialWrapper

Changes the probability of ground truth.

Parameters:

    probs: Probabilities for each choice, as a numpy array of shape (n_block, n_choices); within each block the probabilities should sum to 1. Required (passing None raises an error).
    block_dur (int): Number of trials per block. Default: 200.
Source code in neurogym/wrappers/side_bias.py
def __init__(self, env, probs=None, block_dur=200) -> None:
    super().__init__(env)
    try:
        self.choices = self.task.choices
    except AttributeError as e:
        msg = "SideBias requires task to have attribute choices."
        raise AttributeError(msg) from e
    if not isinstance(self.task, ngym.TrialEnv):
        msg = "Task has to be TrialEnv."
        raise TypeError(msg)
    if probs is None:
        msg = "Please provide choices probabilities."
        raise ValueError(msg)
    if isinstance(probs, float | int):
        mat = np.eye(len(self.choices)) * probs
        mat[mat == 0] = 1 - probs
        self.choice_prob = mat
    else:
        self.choice_prob = np.array(probs)
    if self.choice_prob.shape[1] != len(self.choices):
        msg = (
            f"The number of choices {self.choice_prob.shape[1]} inferred from prob mismatches "
            f"{len(self.choices)} inferred from choices."
        )
        raise ValueError(msg)

    self.n_block = self.choice_prob.shape[0]
    self.curr_block = self.task.rng.choice(range(self.n_block))
    self.block_dur = block_dur
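
Usage sketch with two bias blocks for a two-choice task (illustrative, not from the library docs); the task name is an assumption, and the wrapped task must expose a choices attribute.

import numpy as np
import neurogym as ngym
from neurogym.wrappers.side_bias import SideBias

env = ngym.make("PerceptualDecisionMaking-v0")
probs = np.array([[0.8, 0.2],    # block 0: first choice more likely
                  [0.2, 0.8]])   # block 1: second choice more likely
env = SideBias(env, probs=probs, block_dur=200)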