Skip to content

Utils

data

Utilities for data.

Dataset

Dataset(
    env: str | Env,
    env_kwargs: dict[str, Any] | None = None,
    batch_size: int = 1,
    seq_len: int | None = None,
    max_batch: int | None = None,
    batch_first: bool = False,
    cache_len: int | None = None,
)

Make an environment into an iterable dataset for supervised learning.

Create an iterator that at each call returns a pair: inputs, a numpy array of shape (sequence_length, batch_size, input_units), and target, a numpy array of shape (sequence_length, batch_size, output_units).

Parameters:

Name Type Description Default
env str | Env

str for env id or gym.Env objects

required
env_kwargs dict[str, Any] | None

dict, additional kwargs for environment, if env is str

None
batch_size int

int, batch size

1
seq_len int | None

int, sequence length

None
max_batch int | None

int, maximum number of batch for iterator, default infinite

None
batch_first bool

bool, if True, return (batch, seq_len, n_units), default False

False
cache_len int | None

int, default length of caching

None
Source code in neurogym/utils/data.py
def __init__(
    self,
    env: str | gym.Env,
    env_kwargs: dict[str, Any] | None = None,
    batch_size: int = 1,
    seq_len: int | None = None,
    max_batch: int | None = None,
    batch_first: bool = False,
    cache_len: int | None = None,
) -> None:
    """Turn an environment into an iterable supervised-learning dataset.

    Args:
        env: env id string or a `gym.Env` instance (deep-copied per batch slot).
        env_kwargs: extra kwargs passed to `make` when `env` is a string.
        batch_size: number of independent environment copies run in parallel.
        seq_len: sequence length of each yielded batch; defaults to 1000.
        max_batch: maximum number of batches to yield; default effectively infinite.
        batch_first: if True, arrays are (batch, seq_len, ...) instead of
            (seq_len, batch, ...).
        cache_len: length of the internal trajectory cache; inferred from the
            observation/action sizes when None.

    Raises:
        TypeError: if `env` is neither a string nor a `gym.Env`.
        ValueError: if the observation or action space has no shape.
    """
    if max_batch is None:
        # Effectively no limit on the number of batches.
        max_batch = int(1e31)
    if not isinstance(env, str | gym.Env):
        msg = f"{type(env)=} must be `gym.Env` or `str`."
        raise TypeError(msg)
    if isinstance(env, gym.Env):
        # One independent copy of the given env per batch slot.
        self.envs = [deepcopy(env) for _ in range(batch_size)]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        self.envs = [make(env, **env_kwargs) for _ in range(batch_size)]
    for env_ in self.envs:
        env_.reset()
    self.seed()

    # Keep the first copy as the reference env for spaces and dtypes.
    env = self.envs[0]
    self.env = env
    self.batch_size = batch_size
    self.batch_first = batch_first

    if seq_len is None:
        # TODO: infer sequence length from task
        seq_len = 1000

    obs_shape = env.observation_space.shape
    action_shape = env.action_space.shape
    if obs_shape is None or action_shape is None:
        msg = "The observation and action spaces must have a shape."
        raise ValueError(msg)

    # Scalar (shapeless) actions need an extra axis when stored in arrays.
    self._expand_action = len(action_shape) == 0

    if cache_len is None:
        # Infer cache len
        # Target roughly 1e5 scalar values per env copy in the cache.
        obs_size = int(np.prod(obs_shape))
        action_size = int(np.prod(action_shape))
        cache_len = int(1e5 / (obs_size + action_size) / batch_size)
    # Round the cache length up to a whole number of sequences.
    cache_len = (1 + (cache_len // seq_len)) * seq_len

    self.seq_len = seq_len
    self._cache_len = cache_len

    if batch_first:
        shape1, shape2 = [batch_size, seq_len], [batch_size, cache_len]
    else:
        shape1, shape2 = [seq_len, batch_size], [cache_len, batch_size]

    self.inputs_shape = shape1 + list(obs_shape)
    self.target_shape = shape1 + list(action_shape)
    self._cache_inputs_shape = shape2 + list(obs_shape)
    self._cache_target_shape = shape2 + list(action_shape)

    self._inputs = np.zeros(
        self._cache_inputs_shape,
        dtype=env.observation_space.dtype,
    )
    self._target = np.zeros(self._cache_target_shape, dtype=env.action_space.dtype)

    # Pre-fill the cache with an initial rollout.
    self._cache()

    self._i_batch = 0
    self.max_batch = max_batch

info

Formatting information about envs and wrappers.

show_all_tasks

show_all_tasks(tag: str | None = None) -> None

Show all available tasks in neurogym.

Parameters:

Name Type Description Default
tag str | None

If provided, only show tasks with this tag.

None
Source code in neurogym/utils/info.py
def show_all_tasks(tag: str | None = None) -> None:
    """Show all available tasks in neurogym.

    Args:
        tag: If provided, only show tasks with this tag.
    """
    header = f"Available tasks with tag '{tag}':" if tag else "Available tasks:"
    logger.info(header, color="green")
    for task in all_envs(tag=tag):
        logger.info(task)

show_all_wrappers

show_all_wrappers() -> None

Show all available wrappers in neurogym.

Source code in neurogym/utils/info.py
def show_all_wrappers() -> None:
    """Show all available wrappers in neurogym."""
    logger.info("Available wrappers:", color="green")
    for wrapper_name in all_wrappers():
        logger.info(wrapper_name)

show_all_tags

show_all_tags()

Show all available tags in neurogym.

Source code in neurogym/utils/info.py
def show_all_tags() -> None:
    """Show all available tags in neurogym."""
    # Annotated `-> None` for consistency with the sibling `show_all_*` helpers.
    logger.info("Available tags:", color="green")
    for tag in all_tags():
        logger.info(tag)

show_info

show_info(obj_: str | Env) -> None

Show information about an environment or a wrapper.

Using the built-in logger.

Parameters:

Name Type Description Default
obj_ str | Env

the environment or wrapper to show information about.

required
Source code in neurogym/utils/info.py
def show_info(obj_: str | gym.Env) -> None:
    """Show information about an environment or a wrapper.

    Using the built-in logger.

    Args:
        obj_: the environment or wrapper to show information about.
    """
    # Env instances are handled directly; everything else must be a string id.
    if isinstance(obj_, gym.Env):
        _env_info(obj_)
        return

    if not isinstance(obj_, str):
        msg = f"Expected a str or gym.Env, got {type(obj_)}"
        raise TypeError(msg)

    if obj_ in ALL_ENVS:
        _env_info(env=make(obj_))
    elif obj_ in ALL_WRAPPERS:
        _wrap_info(obj_)
    else:
        msg = f"Unknown environment or wrapper: {obj_}"
        raise ValueError(msg)

plotting

Plotting functions.

plot_env

plot_env(
    env: TrialEnv | Env,
    num_steps: int = 200,
    num_trials: int | None = None,
    def_act: int | None = None,
    model=None,
    name: str | None = None,
    legend: bool = True,
    ob_traces: list | None = None,
    fig_kwargs: dict | None = None,
    fname: str | None = None,
    plot_performance: bool = True,
    plot_config: PlotConfig | None = None,
)

Plot environment with agent.

Parameters:

Name Type Description Default
env TrialEnv | Env

Already built neurogym task or its name.

required
num_steps int

Number of steps to run the task for.

200
num_trials int | None

If not None, the number of trials to run.

None
def_act int | None

If not None (and model=None), the task will be run with the specified action.

None
model

If not None, the task will be run with the actions predicted by model, which so far is assumed to be created and trained with the stable-baselines3 toolbox: (https://stable-baselines3.readthedocs.io/en/master/)

None
name str | None

Title to show on the rewards panel.

None
legend bool

Whether to show the legend for actions panel or not.

True
ob_traces list | None

If a non-empty list is provided, observations will be plotted as traces, with the labels specified by ob_traces.

None
fig_kwargs dict | None

Figure properties admitted by matplotlib.pyplot.subplots() function

None
fname str | None

If not None, save fig or movie to fname.

None
plot_performance bool

Whether to show the performance subplot (default: True).

True
plot_config PlotConfig | None

Plot configuration (experimental). If set to None, the global configuration is used.

None
Source code in neurogym/utils/plotting.py
@suppress_during_pytest(
    ValueError,
    message="This may be due to a small sample size; please increase to get reasonable results.",
)
def plot_env(
    env: TrialEnv | Env,
    num_steps: int = 200,
    num_trials: int | None = None,
    def_act: int | None = None,
    model=None,
    name: str | None = None,
    legend: bool = True,
    ob_traces: list | None = None,
    fig_kwargs: dict | None = None,
    fname: str | None = None,
    plot_performance: bool = True,
    plot_config: PlotConfig | None = None,
):
    """Plot environment with agent.

    Args:
        env: Already built neurogym task or its name.
        num_steps: Number of steps to run the task for.
        num_trials: If not None, the number of trials to run.
        def_act: If not None (and model=None), the task will be run with the
            specified action.
        model: If not None, the task will be run with the actions predicted by
            model, which so far is assumed to be created and trained with the
            stable-baselines3 toolbox:
            (https://stable-baselines3.readthedocs.io/en/master/)
        name: Title to show on the rewards panel.
        legend: Whether to show the legend for actions panel or not.
        ob_traces: If a non-empty list is provided, observations will be plotted
            as traces, with the labels specified by ob_traces.
        fig_kwargs: Figure properties admitted by matplotlib.pyplot.subplots() function
        fname: If not None, save fig or movie to fname.
        plot_performance: Whether to show the performance subplot (default: True).
        plot_config: Plot configuration (experimental). If set to None, the global configuration is used.
    """
    # Monitoring is deliberately not used here: the env may already be wrapped
    # with a monitor, and a monitor would need a folder to save its data to.
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    ob_traces = [] if ob_traces is None else ob_traces
    if isinstance(env, str):
        env = make(env, disable_env_checker=True)
    if name is None:
        name = type(env).__name__
    if plot_config is None:
        plot_config = config.plot

    run_data = run_env(
        env=env,
        num_steps=num_steps,
        num_trials=num_trials,
        def_act=def_act,
        model=model,
    )

    # A step whose action ended a trial (!= -1) is followed by a trial start
    # (+1); plotting uses 1-based step numbering, hence the second +1.
    trial_starts_axis = np.where(np.array(run_data["actions_end_of_trial"]) != -1)[0] + 2

    return visualize_run(
        run_data["ob"],
        run_data["actions"],
        gt=run_data["gt"],
        rewards=run_data["rewards"],
        legend=legend,
        performance=run_data["perf"] if plot_performance else None,
        states=run_data["states"],
        name=name,
        ob_traces=ob_traces,
        fig_kwargs=fig_kwargs,
        env=env,
        initial_ob=run_data["initial_ob"],
        fname=fname,
        trial_starts=trial_starts_axis,
        plot_config=plot_config,
    )

run_env

run_env(
    env: TrialEnv | Env,
    num_steps: int = 200,
    num_trials: int | None = None,
    def_act: int | None = None,
    model=None,
) -> dict

Run the given environment with the given model, a preset action, or randomly sampled actions.

Parameters:

Name Type Description Default
env TrialEnv | Env

A NeuroGym environment.

required
num_steps int

Number of steps to run the task for.

200
num_trials int | None

Number of trials to run.

None
def_act int | None

Preset action to pass to the environment.

None
model

Model (agent) learning from the environment.

None

Returns:

Type Description
dict

A dictionary containing the results of the trial.

Source code in neurogym/utils/plotting.py
def run_env(
    env: TrialEnv | Env,
    num_steps: int = 200,
    num_trials: int | None = None,
    def_act: int | None = None,
    model=None,
) -> dict:
    """Run the given environment with a model, a preset action, or random actions.

    Args:
        env: A NeuroGym environment.
        num_steps: Number of steps to run the task for (ignored when
            `num_trials` is given).
        num_trials: Number of trials to run.
        def_act: Preset action to pass to the environment.
        model: Model (agent) learning from the environment.

    Returns:
        A dictionary containing the results of the trial.
    """
    observations = []
    ob_cum = []
    state_mat = []
    rewards = []
    actions = []
    actions_end_of_trial = []
    gt = []
    perf = []
    # SB3 vectorized envs return only the observation from reset().
    if _SB3_INSTALLED and isinstance(env, DummyVecEnv):
        ob = env.reset()
    else:
        ob, _ = env.reset()

    initial_ob = ob.copy()

    ob_cum_temp = ob.copy()

    # Initialize hidden states
    states = None
    episode_starts = np.array([True])

    if num_trials is not None:
        num_steps = int(1e5)  # Overwrite num_steps value

    trial_count = 0
    for _ in range(int(num_steps)):
        if model is not None:
            # Recurrent policies need the hidden state threaded through predict().
            if _SB3_INSTALLED and isinstance(model.policy, RecurrentActorCriticPolicy):
                action, states = model.predict(ob, state=states, episode_start=episode_starts, deterministic=True)
            else:
                action, _ = model.predict(ob, deterministic=True)
            if isinstance(action, float | int):
                action = [action]
            if (states is not None) and (len(states) > 0):
                state_mat.append(states)
        elif def_act is not None:
            action = def_act
        else:
            action = env.action_space.sample()
        if _SB3_INSTALLED and isinstance(env, DummyVecEnv):
            ob, rew, terminated, info = env.step(action)
        else:
            ob, rew, terminated, _truncated, info = env.step(action)
        # Update episode_starts after each step
        episode_starts = np.array([False])
        ob_cum_temp += ob
        ob_cum.append(ob_cum_temp.copy())
        # Vectorized envs return per-env lists; unwrap the first (only) env.
        if isinstance(info, list):
            info = info[0]
            ob_aux = ob[0]
            # TODO: Fix these and remove the ignore directives
            rew = rew[0]  # type: ignore[index]
            terminated = terminated[0]  # type: ignore[index]
            action = action[0]
        else:
            ob_aux = ob

        if terminated:
            env.reset()
        observations.append(ob_aux)
        rewards.append(rew)
        actions.append(action)
        if "gt" in info:
            gt.append(info["gt"])
        else:
            gt.append(0)

        if info["new_trial"]:
            # Record the trial-ending action and its performance score.
            actions_end_of_trial.append(action)
            perf.append(info["performance"])
            ob_cum_temp = np.zeros_like(ob_cum_temp)
            trial_count += 1
            # Reset states at the end of each trial
            states = None
            episode_starts = np.array([True])
            if num_trials is not None and trial_count >= num_trials:
                break
        else:
            # -1 marks steps that did not end a trial.
            actions_end_of_trial.append(-1)
            perf.append(-1)

    if model is not None and len(state_mat) > 0:  # noqa: SIM108
        # states = np.array(state_mat)  # noqa: ERA001
        # states = states[:, 0, :]  # noqa: ERA001
        states = None  # TODO: Fix this
    else:
        states = None

    return {
        "ob": np.array(observations).astype(float),
        "ob_cum": np.array(ob_cum).astype(float),
        "rewards": rewards,
        "actions": actions,
        "perf": perf,
        "actions_end_of_trial": actions_end_of_trial,
        "gt": gt,
        "states": states,
        "initial_ob": initial_ob,
    }

visualize_run

visualize_run(
    ob: ndarray,
    actions: ndarray,
    gt: ndarray | None = None,
    rewards: ndarray | None = None,
    performance: ndarray | None = None,
    states: ndarray | None = None,
    legend: bool = True,
    ob_traces: list | None = None,
    name: str = "",
    fname: str | None = None,
    fig_kwargs: dict | None = None,
    env: TrialEnv | Env | None = None,
    initial_ob: ndarray | None = None,
    trial_starts: list | None = None,
    plot_config: PlotConfig | None = None,
) -> None

Visualize a run in a simple environment.

Parameters:

Name Type Description Default
ob ndarray

NumPy array of observation (n_step, n_unit).

required
actions ndarray

NumPy array of action (n_step, n_unit)

required
gt ndarray | None

NumPy array of ground truth.

None
rewards ndarray | None

NumPy array of rewards.

None
performance ndarray | None

NumPy array of performance (if set to None performance plotting will be skipped).

None
states ndarray | None

NumPy array of network states

None
name str

Title to show on the rewards panel and name to save figure.

''
fname str | None

Optional name for the file where the figure should be saved.

None
legend bool

Whether to show the legend for actions panel.

True
ob_traces list | None

If a non-empty list is provided, observations will be plotted as traces, with the labels specified by ob_traces

None
fig_kwargs dict | None

Figure properties admitted by matplotlib.pyplot.subplots() function

None
env TrialEnv | Env | None

Environment class for extra information

None
initial_ob ndarray | None

Initial observation to be used to align with actions

None
trial_starts list | None

List of trial start indices, 1-based

None
plot_config PlotConfig | None

Optional plot configuration.

None
Source code in neurogym/utils/plotting.py
@suppress_during_pytest(
    ValueError,
    message="This may be due to a small sample size; please increase to get reasonable results.",
)
def visualize_run(
    ob: np.ndarray,
    actions: np.ndarray,
    gt: np.ndarray | None = None,
    rewards: np.ndarray | None = None,
    performance: np.ndarray | None = None,
    states: np.ndarray | None = None,
    legend: bool = True,
    ob_traces: list | None = None,
    name: str = "",
    fname: str | None = None,
    fig_kwargs: dict | None = None,
    env: TrialEnv | Env | None = None,
    initial_ob: np.ndarray | None = None,
    trial_starts: list | None = None,
    plot_config: PlotConfig | None = None,
) -> None:
    """Visualize a run in a simple environment.

    Args:
        ob: NumPy array of observation (n_step, n_unit).
        actions: NumPy array of action (n_step, n_unit)
        gt: NumPy array of ground truth.
        rewards: NumPy array of rewards.
        performance: NumPy array of performance (if set to `None` performance plotting will be skipped).
        states: NumPy array of network states
        name: Title to show on the rewards panel and name to save figure.
        fname: Optional name for the file where the figure should be saved.
        legend: Whether to show the legend for actions panel.
        ob_traces: If a non-empty list is provided, observations will be plotted
            as traces, with the labels specified by ob_traces
        fig_kwargs: Figure properties admitted by matplotlib.pyplot.subplots() function
        env: Environment class for extra information
        initial_ob: Initial observation to be used to align with actions
        trial_starts: List of trial start indices, 1-based
        plot_config: Optional plot configuration.
    """
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    plot_config = config.plot if plot_config is None else plot_config

    ob = np.array(ob)
    actions = np.array(actions)

    if initial_ob is None:
        initial_ob = ob[0].copy()
    # Prepend the initial observation and drop the trailing one so each
    # observation lines up with the action taken on it.
    ob = np.insert(ob, 0, initial_ob, axis=0)[:-1]

    ndim = ob.ndim
    if ndim == 2:
        # 1-D Box observations: full multi-panel figure.
        return plot_env_1dbox(  # type: ignore[no-any-return]
            ob,
            actions,
            gt=gt,
            rewards=rewards,
            performance=performance,
            states=states,
            legend=legend,
            ob_traces=ob_traces,
            name=name,
            fname=fname,
            fig_kwargs=fig_kwargs,
            env=env,
            trial_starts=trial_starts,
            plot_config=plot_config,
        )
    if ndim == 4:
        # Image observations: render as an animation instead.
        return plot_env_3dbox(ob, fname=fname, env=env)  # type: ignore[no-any-return]

    msg = f"{ob.shape=} not supported."
    raise ValueError(msg)

plot_env_1dbox

plot_env_1dbox(
    ob: ndarray,
    actions: ndarray,
    gt: ndarray | None = None,
    rewards: ndarray | None = None,
    performance: ndarray | None = None,
    states: ndarray | None = None,
    legend: bool = True,
    ob_traces: list | None = None,
    name: str = "",
    fname: str | None = None,
    fig_kwargs: dict | None = None,
    env: TrialEnv | Env | None = None,
    trial_starts: ndarray | None = None,
    plot_config: PlotConfig | None = None,
) -> Figure | None

Plot environment with 1-D Box observation space.

Parameters:

Name Type Description Default
ob ndarray

Array of observed values.

required
actions ndarray

Array of actions.

required
gt ndarray | None

Array of ground truth values.

None
rewards ndarray | None

Array of reward values.

None
performance ndarray | None

Array of performance values.

None
states ndarray | None

Array of state values.

None
legend bool

Legend toggle.

True
ob_traces list | None

List of observation traces.

None
name str

Name of the environment.

''
fname str | None

Name of the file to save the plot to.

None
fig_kwargs dict | None

Figure configuration.

None
env TrialEnv | Env | None

A NeuroGym environment.

None
trial_starts ndarray | None

List of trial start times.

None
plot_config PlotConfig | None

Plot configuration.

None

Raises:

Type Description
ValueError

Raised if the array of observed values is not 2D.

Returns:

Type Description
Figure | None

A Matplotlib figure.

Source code in neurogym/utils/plotting.py
@suppress_during_pytest(
    ValueError,
    message="This may be due to a small sample size; please increase to get reasonable results.",
)
def plot_env_1dbox(
    ob: np.ndarray,
    actions: np.ndarray,
    gt: np.ndarray | None = None,
    rewards: np.ndarray | None = None,
    performance: np.ndarray | None = None,
    states: np.ndarray | None = None,
    legend: bool = True,
    ob_traces: list | None = None,
    name: str = "",
    fname: str | None = None,
    fig_kwargs: dict | None = None,
    env: TrialEnv | Env | None = None,
    trial_starts: np.ndarray | None = None,
    plot_config: PlotConfig | None = None,
) -> plt.Figure | None:
    """Plot environment with 1-D Box observation space.

    Args:
        ob: Array of observed values.
        actions: Array of actions.
        gt: Array of ground truth values.
        rewards: Array of reward values.
        performance: Array of performance values.
        states: Array of state values.
        legend: Legend toggle.
        ob_traces: List of observation traces.
        name: Name of the environment.
        fname: Name of the file to save the plot to.
        fig_kwargs: Figure configuration.
        env: A NeuroGym environment.
        trial_starts: List of trial start times.
        plot_config: Plot configuration.

    Raises:
        ValueError: Raised if the array of observed values is not 2D.

    Returns:
        A Matplotlib figure.
    """
    if plot_config is None:
        plot_config = config.plot

    if len(ob.shape) != 2:
        msg = "ob must be 2-dimensional."
        raise ValueError(msg)
    # Steps are 1-based on the plot axis.
    steps = np.arange(1, ob.shape[0] + 1)

    # One panel each for observation and action, plus one per optional extra.
    n_row = 2  # observation and action
    for extra in [rewards, performance, states]:
        n_row += int(extra is not None)

    gt_colors = "gkmcry"
    if not fig_kwargs:
        fig_kwargs = {
            "sharex": True,
            "figsize": (12, n_row * 1.4),
        }
    f, axes = plt.subplots(n_row, 1, **fig_kwargs)
    # i_ax walks down the panels in a fixed order: obs, actions, rewards,
    # performance, states — matching the n_row count computed above.
    i_ax = 0

    legend_props = {
        "loc": "upper right",
        "bbox_to_anchor": (1.01, 1.0, 0.15, 0),
        "mode": "expand",
        "fontsize": 8,
    }

    xticks = np.arange(0, len(steps) + 1, 10)
    xtick_labels = [f"{x}" for x in xticks]

    # Observations
    ax = axes[i_ax]
    i_ax += 1
    _plot_observations(
        ax,
        steps,
        ob,
        ob_traces,
        env=env,
        trial_starts=trial_starts,
        xticks=xticks,
        xtick_labels=xtick_labels,
        title="Observation traces",
        name=name,
        legend=legend,
        legend_props=legend_props,
        plot_config=plot_config,
    )

    # Plot actions
    ax = axes[i_ax]
    i_ax += 1
    _plot_actions(
        ax,
        steps,
        actions,
        gt,
        env=env,
        trial_starts=trial_starts,
        xticks=xticks,
        xtick_labels=xtick_labels,
        gt_colors=gt_colors,
        legend=legend,
        legend_props=legend_props,
        plot_config=plot_config,
    )

    # Plot rewards if provided
    if rewards is not None:
        ax = axes[i_ax]
        i_ax += 1
        _plot_rewards(
            ax,
            steps,
            rewards,
            env=env,
            trial_starts=trial_starts,
            xticks=xticks,
            xtick_labels=xtick_labels,
            legend=legend,
            legend_props=legend_props,
            plot_config=plot_config,
        )

    # Plot performance if provided
    if performance is not None:
        ax = axes[i_ax]
        i_ax += 1
        _plot_performance(
            ax,
            steps,
            performance,
            trial_starts=trial_starts,
            xticks=xticks,
            xtick_labels=xtick_labels,
            legend=legend,
            legend_props=legend_props,
            plot_config=plot_config,
        )

    # Plot states if provided
    if states is not None:
        ax = axes[i_ax]
        i_ax += 1
        _plot_states(
            ax,
            states,
            performance,
            rewards,
            xticks=xticks,
            xtick_labels=xtick_labels,
            plot_config=plot_config,
        )

    # `ax` is the last panel drawn, so the x-label lands on the bottom row.
    ax.set_xlabel("Steps", fontproperties=plot_config.font.label)
    f.align_ylabels(axes)
    f.tight_layout()

    if fname:
        fname = str(fname)
        if not (fname.endswith((".png", ".svg"))):
            fname += ".png"
        f.savefig(fname, dpi=300)
        plt.close(f)

    return f  # type: ignore[no-any-return]

plot_env_3dbox

plot_env_3dbox(
    ob: ndarray,
    fname: str = "",
    env: TrialEnv | Env | None = None,
) -> None

Plot environment with 3-D Box observation space.

Parameters:

Name Type Description Default
ob ndarray

Array of observation values.

required
fname str

File name to save the figure to.

''
env TrialEnv | Env | None

A NeuroGym environment.

None
Source code in neurogym/utils/plotting.py
@suppress_during_pytest(
    ValueError,
    message="This may be due to a small sample size; please increase to get reasonable results.",
)
def plot_env_3dbox(
    ob: np.ndarray,
    fname: str = "",
    env: TrialEnv | Env | None = None,
) -> None:
    """Plot environment with 3-D Box observation space as an animation.

    Args:
        ob: Array of observation values.
        fname: File name to save the figure to.
        env: A NeuroGym environment.
    """
    ob = ob.astype(np.uint8)  # TODO: Temporary
    fig = plt.figure()
    ax = fig.add_axes((0.1, 0.1, 0.8, 0.8))
    ax.axis("off")
    im = ax.imshow(ob[0], animated=True)

    def draw_frame(frame_idx, *args, **kwargs):
        # Swap the displayed image in place for each animation frame.
        im.set_array(ob[frame_idx])
        return (im,)

    # Use the env's time step as the frame interval when available.
    frame_interval = env.dt if isinstance(env, TrialEnv) else 50
    ani = animation.FuncAnimation(fig, draw_frame, frames=ob.shape[0], interval=frame_interval)
    if fname:
        writer = animation.writers["ffmpeg"](fps=int(1000 / frame_interval))
        out_path = str(fname)
        if not out_path.endswith(".mp4"):
            out_path += ".mp4"
        ani.save(out_path, writer=writer, dpi=300)

ngym_random

TruncExp

TruncExp(vmean, vmin=0, vmax=inf)
Source code in neurogym/utils/ngym_random.py
def __init__(self, vmean, vmin=0, vmax=np.inf) -> None:
    """Truncated-exponential sampler state.

    Args:
        vmean: mean of the underlying exponential distribution.
        vmin: lower truncation bound.
        vmax: upper truncation bound.
    """
    # Own PRNG so sampling can be seeded independently via `seed()`.
    self.rng = np.random.RandomState()
    self.vmean = vmean
    self.vmin = vmin
    self.vmax = vmax

seed

seed(seed=None) -> None

Seed the PRNG of this space.

Source code in neurogym/utils/ngym_random.py
def seed(self, seed=None) -> None:
    """Seed the PRNG of this space.

    Args:
        seed: optional integer seed; None draws fresh entropy.
    """
    # Replace the generator wholesale so state is fully reproducible.
    fresh_rng = np.random.RandomState(seed)
    self.rng = fresh_rng

trunc_exp

trunc_exp(rng, vmean, vmin=0, vmax=inf)

Function for generating period durations.

Source code in neurogym/utils/ngym_random.py
def trunc_exp(rng, vmean, vmin=0, vmax=np.inf):
    """Function for generating period durations."""
    # Degenerate interval (vmin may grow as large as dt): short-circuit
    # without touching the RNG.
    if vmin >= vmax:
        return vmax
    # Rejection-sample until a draw lands inside [vmin, vmax).
    while True:
        draw = rng.exponential(vmean)
        if vmin <= draw < vmax:
            return draw

random_number_fn

random_number_fn(dist, args, rng)

Return a random number generating function from a distribution.

Source code in neurogym/utils/ngym_random.py
def random_number_fn(dist, args, rng):
    """Return a random number generating function from a distribution."""
    # Dispatch table of zero-argument samplers closing over `args` and `rng`.
    samplers = {
        "uniform": lambda: rng.uniform(*args),
        "choice": lambda: rng.choice(args),
        "truncated_exponential": lambda: trunc_exp(rng, *args),
        "constant": lambda: args,
    }
    if dist not in samplers:
        msg = f"Unknown distribution: {dist}."
        raise ValueError(msg)
    return samplers[dist]

random_number_name

random_number_name(dist, args)

Return a string explaining the dist and args.

Source code in neurogym/utils/ngym_random.py
def random_number_name(dist, args):
    """Return a string explaining the dist and args.

    Args:
        dist: distribution identifier ("uniform", "choice",
            "truncated_exponential", or "constant").
        args: distribution parameters, interpreted per `random_number_fn`.

    Returns:
        Human-readable description of the distribution.

    Raises:
        ValueError: if `dist` is not a known distribution name.
    """
    if dist == "uniform":
        return f"{dist} between {args[0]} and {args[1]}"
    if dist == "choice":
        return f"{dist} within {args}"
    if dist == "truncated_exponential":
        string = f"truncated exponential with mean {args[0]}"
        if len(args) > 1:
            string += f", min {args[1]}"
        if len(args) > 2:
            string += f", max {args[2]}"
        return string
    if dist == "constant":
        # Bug fix: previously returned the literal prefix "dist" (e.g. "dist5");
        # format like the other branches as "<dist> <args>".
        return f"{dist} {args}"
    msg = f"Unknown distribution: {dist}."
    raise ValueError(msg)

scheduler

Trial scheduler class.

BaseSchedule

BaseSchedule(n)

Base schedule.

Parameters:

Name Type Description Default
n

int, number of conditions to schedule

required
Source code in neurogym/utils/scheduler.py
def __init__(self, n) -> None:
    """Base schedule over `n` conditions.

    Args:
        n: int, number of conditions to schedule.
    """
    self.n = n
    self.rng = np.random.RandomState()
    # Counters: total steps taken, steps within the current condition,
    # and the index of the current condition (starts at 0).
    self.total_count = 0
    self.count = 0
    self.i = 0

SequentialSchedule

SequentialSchedule(n)

Bases: BaseSchedule

Sequential schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n) -> None:
    # Pure delegation: sequential stepping logic lives in other methods.
    super().__init__(n)

RandomSchedule

RandomSchedule(n)

Bases: BaseSchedule

Random schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n) -> None:
    # Pure delegation: random selection logic lives in other methods.
    super().__init__(n)

SequentialBlockSchedule

SequentialBlockSchedule(n, block_lens)

Bases: BaseSchedule

Sequential block schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n, block_lens) -> None:
    """Sequential block schedule.

    Args:
        n: int, number of conditions to schedule.
        block_lens: per-condition block lengths; must contain exactly `n` entries.

    Raises:
        ValueError: if `block_lens` does not contain `n` entries.
    """
    # Validate before mutating any state so a failed construction does not
    # leave a partially initialized instance behind.
    if len(block_lens) != n:
        msg = f"{len(block_lens)=} must be equal to {n=}."
        raise ValueError(msg)
    super().__init__(n)
    self.block_lens = block_lens

RandomBlockSchedule

RandomBlockSchedule(n, block_lens)

Bases: BaseSchedule

Random block schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n, block_lens) -> None:
    """Random block schedule.

    Args:
        n: int, number of conditions to schedule.
        block_lens: per-condition block lengths; must contain exactly `n` entries.

    Raises:
        ValueError: if `block_lens` does not contain `n` entries.
    """
    # Validate before mutating any state so a failed construction does not
    # leave a partially initialized instance behind.
    if len(block_lens) != n:
        msg = f"{len(block_lens)=} must be equal to {n=}."
        raise ValueError(msg)
    super().__init__(n)
    self.block_lens = block_lens

spaces

Box

Box(low, high, name=None, **kwargs)

Bases: Box

Thin wrapper of gymnasium.spaces.Box.

Allow the user to give names to each dimension of the Box.

Parameters:

Name Type Description Default
low, high, kwargs

see gymnasium.spaces.Box

required
name

dict describing the name of different dimensions

None
Example usage

observation_space = Box(low=0, high=1, name={'fixation': 0, 'stimulus': [1, 2]})

Source code in neurogym/utils/spaces.py
def __init__(self, low, high, name=None, **kwargs) -> None:
    """Box space with optionally named dimensions.

    Args:
        low, high, kwargs: see gymnasium.spaces.Box.
        name: dict mapping dimension names to indices, or None.
    """
    super().__init__(low, high, **kwargs)
    if name is None:
        return
    if not isinstance(name, dict):
        msg = f"{type(name)=} must be `dict` or `NoneType`."
        raise TypeError(msg)
    self.name = name

Discrete

Discrete(n, name=None, **kwargs)

Bases: Discrete

Thin wrapper of gymnasium.spaces.Discrete.

Allow the user to give names to each dimension of the Discrete space.

Parameters:

Name Type Description Default
n, kwargs

see gymnasium.spaces.Discrete

required
name

dict describing the name of different dimensions

None
Example usage

observation_space = Discrete(n=3, name={'fixation': 0, 'stimulus': [1, 2]})

Source code in neurogym/utils/spaces.py
def __init__(self, n, name=None, **kwargs) -> None:
    """Discrete space with optionally named dimensions.

    Args:
        n: number of elements; see gymnasium.spaces.Discrete.
        name: dict mapping dimension names to indices, or None.
        kwargs: forwarded to gymnasium.spaces.Discrete (e.g. `seed`, `start`).
    """
    # Bug fix: kwargs were previously accepted but silently dropped; forward
    # them so options like `start=` and `seed=` actually reach gymnasium.
    super().__init__(n, **kwargs)
    if isinstance(name, dict):
        self.name = name
    elif name is not None:
        msg = f"{type(name)=} must be `dict` or `NoneType`."
        raise TypeError(msg)

tasktools

to_map

to_map(*args)

Produces ordered dict from given inputs.

Source code in neurogym/utils/tasktools.py
def to_map(*args):
    """Produces ordered dict from given inputs."""
    # Accept either a single list argument or the values directly as varargs.
    values = args[0] if isinstance(args[0], list) else args
    return OrderedDict((value, index) for index, value in enumerate(values))

get_idx

get_idx(t, start_end)

Auxiliary function for defining task periods.

Source code in neurogym/utils/tasktools.py
def get_idx(t, start_end):
    """Auxiliary function for defining task periods."""
    start, end = start_end
    # Half-open interval [start, end) over the time vector.
    in_period = (start <= t) & (t < end)
    return list(np.where(in_period)[0])

get_periods_idx

get_periods_idx(dt, periods)

Function for defining task periods.

Source code in neurogym/utils/tasktools.py
def get_periods_idx(dt, periods):
    """Function for defining task periods."""
    tmax = periods["tmax"]
    # Time vector from 0 to tmax inclusive, stepped by dt.
    t = np.linspace(0, tmax, int(tmax / dt) + 1)
    # Map every named period (everything except "tmax") to its time indices.
    period_indices = {name: get_idx(t, bounds) for name, bounds in periods.items() if name != "tmax"}
    return t, period_indices

minmax_number

minmax_number(dist, args)

Given input to the random_number_fn function, return min and max.

Source code in neurogym/utils/tasktools.py
def minmax_number(dist, args):
    """Given input to the random_number_fn function, return min and max."""
    # A constant's range is itself.
    if dist == "constant":
        return args, args
    # uniform args are already (min, max).
    if dist == "uniform":
        return args[0], args[1]
    if dist == "choice":
        return np.min(args), np.max(args)
    # truncated_exponential args are (mean, min, max).
    if dist == "truncated_exponential":
        return args[1], args[2]
    msg = f"Unknown distribution: {dist}."
    raise ValueError(msg)

circular_dist

circular_dist(original_dist)

Get the distance in periodic boundary conditions.

Source code in neurogym/utils/tasktools.py
def circular_dist(original_dist):
    """Get the distance in periodic boundary conditions."""
    # Shortest way around a circle: direct distance or the complement to 2*pi.
    magnitude = abs(original_dist)
    return np.minimum(magnitude, 2 * np.pi - magnitude)

correct_2AFC

correct_2AFC(perf)

Computes performance.

Source code in neurogym/utils/tasktools.py
def correct_2AFC(perf):  # noqa: N802
    """Computes performance."""
    p_decision = perf.n_decision / perf.n_trials
    p_correct = divide(perf.n_correct, perf.n_decision)

    return p_decision, p_correct