Skip to content

Utils

data

Utilities for data.

Dataset

Dataset(env, env_kwargs=None, batch_size=1, seq_len=None, max_batch=inf, batch_first=False, cache_len=None)

Make an environment into an iterable dataset for supervised learning.

Create an iterator that at each call returns `inputs`, a numpy array of shape (sequence_length, batch_size, input_units), and `target`, a numpy array of shape (sequence_length, batch_size, output_units). When batch_first=True the first two axes are swapped.

Parameters:

Name Type Description Default
env

str (environment id) or gym.Env object

required
env_kwargs

dict, additional kwargs for environment, if env is str

None
batch_size

int, batch size

1
seq_len

int, sequence length

None
max_batch

int, maximum number of batches the iterator yields, default infinite

inf
batch_first

bool, if True, return (batch, seq_len, n_units), default False

False
cache_len

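int, number of time steps to cache; if None it is inferred from the observation/action sizes and the batch size. The value is always rounded up to a multiple of seq_len.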

None
Source code in neurogym/utils/data.py
def __init__(
    self,
    env,
    env_kwargs=None,
    batch_size=1,
    seq_len=None,
    max_batch=np.inf,
    batch_first=False,
    cache_len=None,
) -> None:
    """Create ``batch_size`` copies of an environment and pre-fill a data cache.

    Args:
        env: environment id (``str``, built with ``gym.make``) or a ``gym.Env``
            instance (deep-copied once per batch element).
        env_kwargs: extra keyword arguments for ``gym.make``; only used when
            ``env`` is a string.
        batch_size: number of independent environment copies, i.e. batch size.
        seq_len: length of each returned sequence; 1000 when None.
        max_batch: maximum number of batches the iterator yields.
        batch_first: if True, arrays are laid out as (batch, time, ...);
            otherwise as (time, batch, ...).
        cache_len: number of time steps to cache. If None it is derived from a
            budget of 1e5 values divided by the per-step size
            (prod(obs_shape) + prod(action_shape)) and by ``batch_size``.
            In both cases it is then set to the smallest multiple of
            ``seq_len`` strictly greater than the given/derived value.

    Raises:
        TypeError: if ``env`` is neither a ``str`` nor a ``gym.Env``.
        ValueError: if the observation or action space has no shape.
    """
    if not isinstance(env, str | gym.Env):
        msg = f"{type(env)=} must be `gym.Env` or `str`."
        raise TypeError(msg)
    # One independent environment per batch element.
    if isinstance(env, gym.Env):
        self.envs = [copy.deepcopy(env) for _ in range(batch_size)]
    else:
        if env_kwargs is None:
            env_kwargs = {}
        self.envs = [gym.make(env, **env_kwargs) for _ in range(batch_size)]
    for env_ in self.envs:
        env_.reset()
    # self.seed is defined elsewhere in the class (not visible here);
    # presumably it seeds the environments in self.envs — verify.
    self.seed()

    # The first copy is the reference environment for spaces and shapes below.
    env = self.envs[0]
    self.env = env
    self.batch_size = batch_size
    self.batch_first = batch_first

    if seq_len is None:
        # TODO: infer sequence length from task
        seq_len = 1000

    obs_shape = env.observation_space.shape
    action_shape = env.action_space.shape
    if obs_shape is None or action_shape is None:
        msg = "The observation and action spaces must have a shape."
        raise ValueError(msg)

    # True for scalar action spaces (shape ()); presumably read elsewhere in the
    # class to add a trailing axis to the targets — TODO confirm.
    self._expand_action = len(action_shape) == 0

    if cache_len is None:
        # Infer cache len
        cache_len = 1e5  # Probably too low
        # Split the value budget across the per-step size and the batch.
        cache_len /= np.prod(obs_shape) + np.prod(action_shape)
        cache_len /= batch_size
    # Smallest multiple of seq_len strictly greater than cache_len.
    cache_len = int((1 + (cache_len // seq_len)) * seq_len)

    self.seq_len = seq_len
    self._cache_len = cache_len

    # shape1: returned batch; shape2: cache buffer (time axis = cache_len).
    if batch_first:
        shape1, shape2 = [batch_size, seq_len], [batch_size, cache_len]
    else:
        shape1, shape2 = [seq_len, batch_size], [cache_len, batch_size]

    self.inputs_shape = shape1 + list(obs_shape)
    self.target_shape = shape1 + list(action_shape)
    self._cache_inputs_shape = shape2 + list(obs_shape)
    self._cache_target_shape = shape2 + list(action_shape)

    # Preallocated cache buffers, using the spaces' dtypes.
    self._inputs = np.zeros(
        self._cache_inputs_shape,
        dtype=env.observation_space.dtype,
    )
    self._target = np.zeros(self._cache_target_shape, dtype=env.action_space.dtype)

    # self._cache is defined elsewhere in the class; presumably it fills
    # self._inputs / self._target by stepping the envs — verify.
    self._cache()

    # Iteration state: number of batches yielded so far and the upper limit.
    self._i_batch = 0
    self.max_batch = max_batch

info

Formatting information about envs and wrappers.

info

info(env=None, show_code=False)

Return a text description of a registered environment, optionally followed by the source code of its class.

Source code in neurogym/utils/info.py
def info(env=None, show_code=False):
    """Return a text description of a registered environment.

    Args:
        env: registered environment id. It is passed to ``ngym.make`` and, when
            ``show_code`` is True, used as a key into ``ALL_ENVS``.
        show_code: if True, append the source code of the environment class.

    Returns:
        str: the output of ``env_string`` for the unwrapped environment,
        optionally followed by a "Source code" section.
    """
    env_name = env
    # make() can add extra wrappers (e.g. an OrderEnforcer); describe the bare env.
    env = ngym.make(env_name).unwrapped
    string = env_string(env)
    if show_code:
        string += """\n#### Source code #### \n\n"""
        # ALL_ENVS maps the id to a "module:ClassName" entry-point string.
        module_name, class_name = ALL_ENVS[env_name].split(":")
        env_class = getattr(__import__(module_name, fromlist=[class_name]), class_name)
        string += inspect.getsource(env_class) + "\n\n"
    return string

info_wrapper

info_wrapper(wrapper=None, show_code=False)

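Return a Markdown-formatted description of a registered wrapper (description, reference paper, extra metadata), optionally followed by its source code.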

Source code in neurogym/utils/info.py
def info_wrapper(wrapper=None, show_code=False):
    """Return a Markdown-formatted description of a registered wrapper.

    Args:
        wrapper: key into ``ALL_WRAPPERS``, whose value is a
            "module:ClassName" entry-point string.
        show_code: if True, append the source code of the wrapper class.

    Returns:
        str: a "### <wrapper>" header followed by the logic description, the
        reference paper (linked when a link is available), any non-default
        metadata entries, and optionally the class source code.
    """
    module_name, class_name = ALL_WRAPPERS[wrapper].split(":")
    imported = getattr(__import__(module_name, fromlist=[class_name]), class_name)
    # A missing or non-dict metadata attribute is treated as empty metadata.
    metadata = getattr(imported, "metadata", None)
    if not isinstance(metadata, dict):
        metadata = {}

    string = f"### {wrapper}\n\n"
    paper_name = metadata.get("paper_name", None)
    paper_link = metadata.get("paper_link", None)
    wrapper_description = metadata.get("description", None) or "Missing description"
    string += f"Logic: {wrapper_description}\n\n"
    if paper_name is not None:
        string += "Reference paper \n\n"
        if paper_link is None:
            string += f"{paper_name}\n\n"
        else:
            string += f"[{paper_name}]({paper_link})\n\n"

    # Extra (non-default) metadata entries, listed in the order they were
    # defined. Iterating a set difference would give an order that can change
    # between runs because string hashing is randomized.
    default_keys = set(METADATA_DEF_KEYS)
    other_info = [key for key in metadata if key not in default_keys]
    if other_info:
        string += "Input parameters: \n\n"
        for key in other_info:
            string += f"{key} : {metadata[key]}\n\n"

    if show_code:
        string += """\n#### Source code #### \n\n"""
        string += inspect.getsource(imported) + "\n\n"

    return string

all_tags

all_tags(verbose=0)

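Return the set of tags declared in the metadata of all registered environments; environments that fail to build are reported and skipped.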

Source code in neurogym/utils/info.py
def all_tags(verbose=0):
    """Return the set of tags declared in the metadata of all registered envs.

    Environments that fail to build are reported on stdout and skipped.

    Args:
        verbose: if truthy, also print the collected tags.

    Returns:
        set: union of every environment's ``metadata["tags"]`` entries.
    """
    tags = set()
    for env_name in sorted(all_envs()):
        try:
            env = ngym.make(env_name)
            tags.update(env.metadata.get("tags", []))
        # Exception, not BaseException: Ctrl-C / SystemExit must still stop the loop.
        except Exception as e:  # noqa: BLE001, PERF203 # FIXME: unclear which error is expected here.
            print("Failure in ", env_name)
            print(e)
    if verbose:
        print("\nTAGS:\n")
        for tag in sorted(tags, key=str):
            print(tag)
    return tags

plotting

Plotting functions.

plot_env

plot_env(env, num_steps=200, num_trials=None, def_act=None, model=None, name=None, legend=True, ob_traces=None, fig_kwargs=None, fname=None, plot_performance=True)

Plot environment with agent.

Parameters:

Name Type Description Default
env

already built neurogym task or name of it

required
num_steps

number of steps to run the task

200
num_trials

if not None, the number of trials to run

None
def_act

if not None (and model=None), the task will be run with the specified action

None
model

if not None, the task will be run with the actions predicted by model, which so far is assumed to be created and trained with the stable-baselines3 toolbox: (https://stable-baselines3.readthedocs.io/en/master/)

None
name

title to show on the rewards panel

None
legend

whether to show the legend for actions panel or not

True
ob_traces

if a non-empty list, observations will be plotted as traces, labeled by the entries of ob_traces

None
fig_kwargs

figure properties admitted by matplotlib.pyplot.subplots() function

None
fname

if not None, save fig or movie to fname

None
plot_performance

whether to show the performance subplot (default: True)

True
Source code in neurogym/utils/plotting.py
def plot_env(
    env,
    num_steps=200,
    num_trials=None,
    def_act=None,
    model=None,
    name=None,
    legend=True,
    ob_traces=None,
    fig_kwargs=None,
    fname=None,
    plot_performance=True,
):
    """Run an environment (optionally driven by an agent) and plot the run.

    Args:
        env: an already-built neurogym task, or its id (built with ``gym.make``).
        num_steps: number of steps to run the task.
        num_trials: if not None, the number of trials to run.
        def_act: if not None (and model is None), the action taken at every step.
        model: if not None, actions are predicted by this model, currently
            assumed to be created and trained with stable-baselines3
            (https://stable-baselines3.readthedocs.io/en/master/).
        name: title passed on to the figure; defaults to the class name of ``env``.
        legend: whether to show legends.
        ob_traces: labels for drawing observations as traces; None/empty draws
            them as an image.
        fig_kwargs: keyword arguments for ``matplotlib.pyplot.subplots()``.
        fname: if not None, save the figure (or movie) to this path.
        plot_performance: whether to include the performance panel.

    Returns:
        Whatever ``fig_`` returns for the recorded run.
    """
    # Monitor is not used on purpose: env may already be wrapped in one, and it
    # saves data to disk, which would require an output folder.
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    ob_traces = [] if ob_traces is None else ob_traces
    if isinstance(env, str):
        env = gym.make(env)
    name = type(env).__name__ if name is None else name

    data = run_env(
        env=env,
        num_steps=num_steps,
        num_trials=num_trials,
        def_act=def_act,
        model=model,
    )

    # 0-based indices whose "actions_end_of_trial" entry is not -1. One step
    # later is the 0-based step on which the next trial starts, and one more
    # converts to the 1-based step axis used for plotting.
    end_flags = np.array(data["actions_end_of_trial"])
    trial_starts_axis = np.where(end_flags != -1)[0] + 2

    return fig_(
        data["ob"],
        data["actions"],
        gt=data["gt"],
        rewards=data["rewards"],
        legend=legend,
        performance=data["perf"] if plot_performance else None,
        states=data["states"],
        name=name,
        ob_traces=ob_traces,
        fig_kwargs=fig_kwargs,
        env=env,
        fname=fname,
        trial_starts=trial_starts_axis,
    )

fig_

fig_(ob, actions, gt=None, rewards=None, performance=None, states=None, legend=True, ob_traces=None, name='', fname=None, fig_kwargs=None, env=None, trial_starts=None)

Visualize a run in a simple environment.

Parameters:

Name Type Description Default
ob

np array of observation (n_step, n_unit)

required
actions

np array of action (n_step, n_unit)

required
gt

np array of ground truth

None
rewards

np array of rewards

None
performance

np array of performance (if set to None performance plotting will be skipped)

None
states

np array of network states

None
name

title to show on the rewards panel and name to save figure

''
fname

if != '', where to save the figure

None
legend

whether to show the legend for actions panel or not

True
ob_traces

None or list. If a list, observations will be plotted as traces, labeled by the entries of ob_traces

None
fig_kwargs

figure properties admitted by matplotlib.pyplot.subplots() function

None
env

environment class for extra information

None
trial_starts

list of trial start indices, 1-based

None
Source code in neurogym/utils/plotting.py
def fig_(
    ob,
    actions,
    gt=None,
    rewards=None,
    performance=None,
    states=None,
    legend=True,
    ob_traces=None,
    name="",
    fname=None,
    fig_kwargs=None,
    env=None,
    trial_starts=None,
):
    """Visualize a run in a simple environment.

    Dispatches on the number of dimensions of ``ob``. 2-D observations are
    plotted by ``plot_env_1dbox``. 4-D observations (one image per step) are
    animated by ``plot_env_3dbox``.

    Args:
        ob: np array of observations (n_step, n_unit), or (n_step, H, W, C).
        actions: np array of actions (n_step, n_unit).
        gt: np array of ground truth.
        rewards: np array of rewards.
        performance: np array of performance (``None`` skips that panel).
        states: np array of network states.
        legend: whether to show the legend for the actions panel.
        ob_traces: None or list of labels; if a list, observations are
            plotted as traces with these labels.
        name: title for the plot.
        fname: if set, where to save the figure / movie.
        fig_kwargs: figure properties for ``matplotlib.pyplot.subplots()``.
        env: environment used for extra labelling information.
        trial_starts: list of 1-based trial start steps.

    Raises:
        ValueError: if ``ob`` is neither 2-D nor 4-D.
    """
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    ob, actions = np.array(ob), np.array(actions)

    n_dims = len(ob.shape)
    if n_dims == 4:
        return plot_env_3dbox(ob, fname=fname, env=env)
    if n_dims != 2:
        msg = f"{ob.shape=} not supported."
        raise ValueError(msg)

    return plot_env_1dbox(
        ob,
        actions,
        gt=gt,
        rewards=rewards,
        performance=performance,
        states=states,
        legend=legend,
        ob_traces=ob_traces,
        name=name,
        fname=fname,
        fig_kwargs=fig_kwargs,
        env=env,
        trial_starts=trial_starts,
    )

plot_env_1dbox

plot_env_1dbox(ob, actions, gt=None, rewards=None, performance=None, states=None, legend=True, ob_traces=None, name='', fname=None, fig_kwargs=None, env=None, trial_starts=None)

Plot environment with 1-D Box observation space.

Source code in neurogym/utils/plotting.py
def plot_env_1dbox(
    ob,
    actions,
    gt=None,
    rewards=None,
    performance=None,
    states=None,
    legend=True,
    ob_traces=None,
    name="",
    fname=None,
    fig_kwargs=None,
    env=None,
    trial_starts=None,
):
    """Plot environment with 1-D Box observation space.

    Panels are stacked vertically on a common 1-based step axis:
    observations, actions (with ground truth), and optionally rewards,
    performance and network states.

    Args:
        ob: array (n_step, n_unit) of observations.
        actions: array of actions, one entry/row per step.
        gt: optional ground truth, 1-D or (n_step, n_gt).
        rewards: optional per-step rewards.
        performance: optional per-step performance. Entries equal to -1 are
            excluded from the mean shown in the panel title.
        states: optional array of network states; only the second half of its
            columns is displayed.
        legend: whether to draw legends.
        ob_traces: optional list with one label per observation unit. When
            given, observations are drawn as line traces instead of an image.
        name: if non-empty, used as "<name> env" title of the observation panel.
        fname: if set, save the figure there (".png" appended unless it ends in
            ".png"/".svg") and close it.
        fig_kwargs: kwargs for ``plt.subplots``; defaults to a shared-x layout.
        env: optional environment whose space names / rewards label the y-axes.
        trial_starts: optional 1-based steps at which trials start, drawn as
            dashed vertical lines.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``ob`` is not 2-D or ``ob_traces`` has the wrong length.
    """
    if fig_kwargs is None:
        fig_kwargs = {}
    if len(ob.shape) != 2:
        msg = "ob has to be 2-dimensional."
        raise ValueError(msg)
    n_steps = ob.shape[0]
    # Every panel is drawn against this 1-based step axis so the panels line up.
    steps = np.arange(1, n_steps + 1)
    xticks = np.arange(0, n_steps, 5)

    n_row = 2  # observation and action
    n_row += rewards is not None
    n_row += performance is not None
    n_row += states is not None

    gt_colors = "gkmcry"
    if not fig_kwargs:
        fig_kwargs = {"sharex": True, "figsize": (6, n_row * 1.2)}
    f, axes = plt.subplots(n_row, 1, **fig_kwargs)
    i_ax = 0

    # Plot observation
    ax = axes[i_ax]
    i_ax += 1
    if ob_traces:
        if len(ob_traces) != ob.shape[1]:
            msg = f"Please provide label for each of the {ob.shape[1]} traces in the observations."
            raise ValueError(msg)

        # Plot all traces first, on the same step axis as the other panels
        for ind_tr, tr in enumerate(ob_traces):
            ax.plot(steps, ob[:, ind_tr], label=tr)

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        # One y-tick at the mean level of the fixation trace (if any) and one at
        # the mean level of the remaining (stimulus) traces.
        yticks = []
        yticklabels = []

        # Find fixation index (if exists)
        fix_idx = next((i for i, tr in enumerate(ob_traces) if "fix" in tr.lower()), None)

        if fix_idx is not None:
            yticks.append(np.mean(ob[:, fix_idx]))
            yticklabels.append("Fix. Cue")

            # All other indices are stimuli
            stim_means = [np.mean(ob[:, i]) for i in range(len(ob_traces)) if i != fix_idx]
            if stim_means:
                yticks.append(np.mean(stim_means))
                yticklabels.append("Stimuli")
        else:
            # No fixation, all are stimuli
            yticks.append(np.mean([np.mean(ob[:, i]) for i in range(len(ob_traces))]))
            yticklabels.append("Stimuli")

        if legend:
            ax.legend(loc="upper right")
        if trial_starts is not None:
            for t_start in trial_starts:
                ax.axvline(t_start, linestyle="--", color="grey", alpha=0.7)
        ax.set_xlim([0.5, len(steps) + 1])
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticklabels)
    else:
        # Image column i holds step i + 1. Shift the horizontal extent by one so
        # it lines up with `steps`; the default extent would centre column i on
        # x=i. The vertical extent is the default one, so row k stays centred on
        # y=k, which the named-dimension ticks below rely on.
        ax.imshow(
            ob.T,
            aspect="auto",
            origin="lower",
            extent=(0.5, n_steps + 0.5, -0.5, ob.shape[1] - 0.5),
        )
        if env and hasattr(env.observation_space, "name"):
            # Plot environment annotation
            yticks = []
            yticklabels = []
            for key, val in env.observation_space.name.items():
                yticks.append((np.min(val) + np.max(val)) / 2)
                yticklabels.append(key)
            ax.set_yticks(yticks)
            ax.set_yticklabels(yticklabels)
        else:
            ax.set_yticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        ax.spines["right"].set_visible(False)

    if name:
        ax.set_title(f"{name} env")
    ax.set_ylabel("Obs.")
    # Show step numbers on x-axis
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks)
    # Add gray background grid with white lines
    _set_grid_style(ax)

    # Plot actions (multi-dimensional actions give one line per column)
    ax = axes[i_ax]
    i_ax += 1
    ax.plot(steps, actions, marker="+", label="Actions")
    if gt is not None:
        gt = np.array(gt)
        if len(gt.shape) > 1:
            for ind_gt in range(gt.shape[1]):
                # Cycle through the colour palette if there are many columns.
                ax.plot(
                    steps,
                    gt[:, ind_gt],
                    f"--{gt_colors[ind_gt % len(gt_colors)]}",
                    label=f"Ground truth {ind_gt}",
                )
        else:
            ax.plot(steps, gt, f"--{gt_colors[0]}", label="Ground truth")
    if trial_starts is not None:
        for t_start in trial_starts:
            ax.axvline(t_start, linestyle="--", color="grey", alpha=0.7)
    ax.set_xlim([0.5, len(steps) + 1])
    ax.set_ylabel("Act.")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    if legend:
        ax.legend(loc="upper right")
    if env and hasattr(env.action_space, "name"):
        yticks = []
        yticklabels = []
        for key, val in env.action_space.name.items():
            if isinstance(val, list | tuple | np.ndarray):
                for v in val:
                    yticks.append(v)
                    yticklabels.append(f"{key}_{v}")
            else:  # single int
                yticks.append(val)
                yticklabels.append(key)
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticklabels)
    # Show step numbers on x-axis
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks)
    # Add gray background grid with white lines
    _set_grid_style(ax)

    # Plot rewards if provided
    if rewards is not None:
        ax = axes[i_ax]
        i_ax += 1
        ax.plot(steps, rewards, "r", label="Rewards")
        ax.set_ylabel("Rew.")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        if legend:
            ax.legend(loc="upper right")
        if trial_starts is not None:
            for t_start in trial_starts:
                ax.axvline(t_start, linestyle="--", color="grey", alpha=0.7)
        ax.set_xlim([0.5, len(steps) + 1])

        if env and hasattr(env, "rewards") and env.rewards is not None:
            yticks = []
            yticklabels = []

            if isinstance(env.rewards, dict):
                for key, val in env.rewards.items():
                    yticks.append(val)
                    yticklabels.append(f"{key[:5].title()} {val:0.2f}")
            else:
                for val in env.rewards:
                    yticks.append(val)
                    yticklabels.append(f"{val:0.2f}")

            ax.set_yticks(yticks)
            ax.set_yticklabels(yticklabels)
        # Show step numbers on x-axis
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticks)
        # Add gray background grid with white lines
        _set_grid_style(ax)

    # Plot performance if provided
    if performance is not None:
        ax = axes[i_ax]
        i_ax += 1
        performance = np.asarray(performance)
        ax.plot(steps, performance, "k", label="Performance")
        ax.set_ylabel("Performance")
        # Entries equal to -1 are excluded from the mean (presumably steps
        # without a performance value). Guard the empty case so np.mean does not
        # warn; the title then reads "nan", as before.
        valid_perf = performance[performance != -1]
        mean_perf = np.mean(valid_perf) if valid_perf.size else np.nan
        ax.set_title(f"Mean performance: {np.round(mean_perf, 2)}")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        if legend:
            ax.legend(loc="upper right")
        if trial_starts is not None:
            for t_start in trial_starts:
                ax.axvline(t_start, linestyle="--", color="grey", alpha=0.7)
        ax.set_xlim([0.5, len(steps) + 1])
        # Add gray background grid with white lines
        _set_grid_style(ax)

    # Plot states if provided
    if states is not None:
        if performance is not None or rewards is not None:
            # Show step numbers on x-axis
            ax.set_xticks(xticks)
            ax.set_xticklabels(xticks)
        ax = axes[i_ax]
        i_ax += 1
        states = np.asarray(states)
        # Draw on this panel explicitly; plt.imshow would target whichever axes
        # matplotlib currently considers "current".
        ax.imshow(states[:, states.shape[1] // 2 :].T, aspect="auto")
        ax.set_title("Activity")
        ax.set_ylabel("Neurons")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    ax.set_xlabel("Steps")
    plt.tight_layout()
    if fname:
        fname = str(fname)
        if not (fname.endswith((".png", ".svg"))):
            fname += ".png"
        f.savefig(fname, dpi=300)
        plt.close(f)
    return f

plot_env_3dbox

plot_env_3dbox(ob, fname='', env=None) -> None

Plot environment with 3-D Box observation space.

Source code in neurogym/utils/plotting.py
def plot_env_3dbox(ob, fname="", env=None) -> None:
    """Animate an environment with a 3-D Box observation space.

    Each ``ob[i]`` is shown with ``imshow`` as one animation frame. The frame
    interval passed to ``FuncAnimation`` is ``env.dt`` (50 without an env).
    When ``fname`` is given, the animation is written with the ffmpeg writer
    to ``fname`` (".mp4" appended if missing) and the figure is closed.
    """
    ob = ob.astype(np.uint8)  # TODO: Temporary
    fig = plt.figure()
    ax = fig.add_axes((0.1, 0.1, 0.8, 0.8))
    ax.axis("off")
    im = ax.imshow(ob[0], animated=True)

    def animate(i, *args, **kwargs):
        im.set_array(ob[i])
        return (im,)

    interval = env.dt if env is not None else 50
    ani = animation.FuncAnimation(fig, animate, frames=ob.shape[0], interval=interval)
    if fname:
        # int(1000 / interval) is 0 for intervals above 1000, which is not a
        # valid frame rate; use at least 1 fps.
        fps = max(1, int(1000 / interval))
        writer = animation.writers["ffmpeg"](fps=fps)
        fname = str(fname)
        if not fname.endswith(".mp4"):
            fname += ".mp4"
        ani.save(fname, writer=writer, dpi=300)
        # Release the figure once saved, as the 1-D plotting path does.
        plt.close(fig)

ngym_random

TruncExp

TruncExp(vmean, vmin=0, vmax=inf)
Source code in neurogym/utils/ngym_random.py
def __init__(self, vmean, vmin=0, vmax=np.inf) -> None:
    """Store the parameters of a truncated exponential distribution.

    Args:
        vmean: mean of the underlying exponential.
        vmin: lower bound.
        vmax: upper bound.
    """
    self.vmean, self.vmin, self.vmax = vmean, vmin, vmax
    # Unseeded generator; it can be re-created with a fixed seed via ``seed``.
    self.rng = np.random.RandomState()

seed

seed(seed=None) -> None

Re-create the random number generator (self.rng), optionally with a fixed seed.

Source code in neurogym/utils/ngym_random.py
def seed(self, seed=None) -> None:
    """Replace ``self.rng`` with a new RandomState built from ``seed``.

    Passing None gives an unseeded generator.
    """
    fresh_rng = np.random.RandomState(seed)
    self.rng = fresh_rng

trunc_exp

trunc_exp(rng, vmean, vmin=0, vmax=inf)

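Sample a duration from an exponential distribution with mean vmean, truncated to [vmin, vmax) by rejection sampling; returns vmax when vmin >= vmax.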

Source code in neurogym/utils/ngym_random.py
def trunc_exp(rng, vmean, vmin=0, vmax=np.inf):
    """Draw an exponential sample with mean ``vmean`` restricted to [vmin, vmax).

    Samples are drawn by rejection: ``rng.exponential(vmean)`` is redrawn until
    it falls inside the interval. A degenerate interval (``vmin >= vmax``)
    returns ``vmax`` directly. Accepting ``vmin > vmax`` avoids failures when
    vmin ends up as large as dt.
    """
    if vmax <= vmin:
        return vmax
    sample = rng.exponential(vmean)
    while not (vmin <= sample < vmax):
        sample = rng.exponential(vmean)
    return sample

random_number_fn

random_number_fn(dist, args, rng)

Return a random number generating function from a distribution.

Source code in neurogym/utils/ngym_random.py
def random_number_fn(dist, args, rng):
    """Return a zero-argument callable that draws from ``dist`` using ``rng``.

    Supported ``dist`` values and how ``args`` is used:
      - "uniform": ``rng.uniform(*args)``, i.e. (low, high)
      - "choice": ``rng.choice(args)``
      - "truncated_exponential": ``trunc_exp(rng, *args)``, i.e.
        (vmean[, vmin[, vmax]])
      - "constant": always returns ``args`` itself

    Raises:
        ValueError: for any other ``dist``.
    """
    samplers = (
        ("uniform", lambda: rng.uniform(*args)),
        ("choice", lambda: rng.choice(args)),
        ("truncated_exponential", lambda: trunc_exp(rng, *args)),
        ("constant", lambda: args),
    )
    for dist_name, sampler in samplers:
        if dist == dist_name:
            return sampler
    msg = f"Unknown distribution: {dist}."
    raise ValueError(msg)

random_number_name

random_number_name(dist, args)

Return a string explaining the dist and args.

Source code in neurogym/utils/ngym_random.py
def random_number_name(dist, args):
    """Return a human-readable description of a distribution and its arguments.

    Args:
        dist: one of "uniform", "choice", "truncated_exponential", "constant".
        args: the arguments that ``random_number_fn`` would receive.

    Returns:
        str, e.g. "uniform between 1 and 2" or "constant 5".

    Raises:
        ValueError: for an unknown ``dist``.
    """
    if dist == "uniform":
        return f"{dist} between {args[0]} and {args[1]}"
    if dist == "choice":
        return f"{dist} within {args}"
    if dist == "truncated_exponential":
        string = f"truncated exponential with mean {args[0]}"
        if len(args) > 1:
            string += f", min {args[1]}"
        if len(args) > 2:
            string += f", max {args[2]}"
        return string
    if dist == "constant":
        # Interpolate the distribution name like the other branches do.
        return f"{dist} {args}"
    msg = f"Unknown distribution: {dist}."
    raise ValueError(msg)

scheduler

Trial scheduler class.

BaseSchedule

BaseSchedule(n)

Base schedule.

Parameters:

Name Type Description Default
n

int, number of conditions to schedule

required
Source code in neurogym/utils/scheduler.py
def __init__(self, n) -> None:
    """Set up the counters shared by all schedules.

    Args:
        n: number of conditions to schedule.
    """
    self.n = n
    # All counters start at zero: `total_count` over the whole schedule,
    # `count` within the current condition, and `i` (presumably the index of
    # the current condition).
    self.total_count = self.count = self.i = 0
    self.rng = np.random.RandomState()

SequentialSchedule

SequentialSchedule(n)

Bases: BaseSchedule

Sequential schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n) -> None:
    """Create a sequential schedule over ``n`` conditions."""
    super().__init__(n=n)

RandomSchedule

RandomSchedule(n)

Bases: BaseSchedule

Random schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n) -> None:
    """Create a random schedule over ``n`` conditions."""
    super().__init__(n=n)

SequentialBlockSchedule

SequentialBlockSchedule(n, block_lens)

Bases: BaseSchedule

Sequential block schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n, block_lens) -> None:
    """Create a sequential block schedule.

    Args:
        n: number of conditions to schedule.
        block_lens: sequence with exactly one entry per condition.

    Raises:
        ValueError: if ``len(block_lens) != n``.
    """
    super().__init__(n)
    self.block_lens = block_lens
    if len(block_lens) == n:
        return
    msg = f"{len(block_lens)=} must be equal to {n=}."
    raise ValueError(msg)

RandomBlockSchedule

RandomBlockSchedule(n, block_lens)

Bases: BaseSchedule

Random block schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n, block_lens) -> None:
    """Create a random block schedule.

    Args:
        n: number of conditions to schedule.
        block_lens: sequence with exactly one entry per condition.

    Raises:
        ValueError: if ``len(block_lens) != n``.
    """
    super().__init__(n)
    self.block_lens = block_lens
    if len(block_lens) == n:
        return
    msg = f"{len(block_lens)=} must be equal to {n=}."
    raise ValueError(msg)

spaces

Box

Box(low, high, name=None, **kwargs)

Bases: Box

Thin wrapper of gymnasium.spaces.Box.

Allow the user to give names to each dimension of the Box.

Parameters:

Name Type Description Default
low, high, **kwargs

see gymnasium.spaces.Box

required
name

dict describing the name of different dimensions

None
Example usage

observation_space = Box(low=0, high=1, name={'fixation': 0, 'stimulus': [1, 2]})

Source code in neurogym/utils/spaces.py
def __init__(self, low, high, name=None, **kwargs) -> None:
    """Build the Box and optionally attach names to its dimensions.

    Args:
        low, high, **kwargs: forwarded unchanged to ``gymnasium.spaces.Box``.
        name: optional dict mapping a label to the index (or indices) it
            names, e.g. ``{'fixation': 0, 'stimulus': [1, 2]}``. When None,
            no ``name`` attribute is set.

    Raises:
        TypeError: if ``name`` is neither a dict nor None.
    """
    super().__init__(low, high, **kwargs)
    if name is None:
        return
    if not isinstance(name, dict):
        msg = f"{type(name)=} must be `dict` or `NoneType`."
        raise TypeError(msg)
    self.name = name

Discrete

Discrete(n, name=None, **kwargs)

Bases: Discrete

Thin wrapper of gymnasium.spaces.Discrete.

Allow the user to give names to the values of the Discrete space.

Parameters:

Name Type Description Default
n, **kwargs

see gymnasium.spaces.Discrete

required
name

dict describing the name of different dimensions

None
Example usage

observation_space = Discrete(n=3, name={'fixation': 0, 'stimulus': [1, 2]})

Source code in neurogym/utils/spaces.py
def __init__(self, n, name=None, **kwargs) -> None:
    """Build the Discrete space and optionally attach names to its values.

    Args:
        n: number of elements in the space.
        name: optional dict mapping a label to the value (or values) it names,
            e.g. ``{'fixation': 0, 'stimulus': [1, 2]}``. When None, no
            ``name`` attribute is set.
        **kwargs: forwarded to ``gymnasium.spaces.Discrete`` (e.g. ``seed``,
            ``start``), in the same way ``Box`` forwards its extra arguments.

    Raises:
        TypeError: if ``name`` is neither a dict nor None.
    """
    # Extra keyword arguments must reach the gymnasium base class; dropping
    # them would silently ignore e.g. a requested ``start`` or ``seed``.
    super().__init__(n, **kwargs)
    if isinstance(name, dict):
        self.name = name
    elif name is not None:
        msg = f"{type(name)=} must be `dict` or `NoneType`."
        raise TypeError(msg)

tasktools

to_map

to_map(*args)

Return an OrderedDict mapping each input (given as separate arguments or as a single list) to its position index.

Source code in neurogym/utils/tasktools.py
def to_map(*args):
    """Return an OrderedDict mapping each given item to its position index.

    Items can be passed as separate arguments (``to_map("a", "b")``) or as a
    single list (``to_map(["a", "b"])``). With no arguments an empty
    OrderedDict is returned. A repeated item keeps its first position in the
    ordering but maps to the index of its last occurrence.
    """
    # Guard on `args` so that calling with no arguments does not raise IndexError.
    var_list = args[0] if args and isinstance(args[0], list) else args
    return OrderedDict((v, i) for i, v in enumerate(var_list))

get_idx

get_idx(t, start_end)

Return the indices of t whose values lie in [start, end).

Source code in neurogym/utils/tasktools.py
def get_idx(t, start_end):
    """Return the indices of ``t`` whose values lie in [start, end).

    Args:
        t: numpy array of time points.
        start_end: (start, end) pair; start is inclusive, end exclusive.

    Returns:
        list of indices into ``t``.
    """
    start, end = start_end
    in_period = (t >= start) & (t < end)
    return list(np.nonzero(in_period)[0])

get_periods_idx

get_periods_idx(dt, periods)

Build a time grid from 0 to periods["tmax"] with step dt, and the grid indices falling in each period.

Source code in neurogym/utils/tasktools.py
def get_periods_idx(dt, periods):
    """Build a time grid and the grid indices that fall in each period.

    Args:
        dt: time step of the grid.
        periods: dict with a "tmax" entry (end of the grid) and, for every
            other key, a (start, end) pair.

    Returns:
        (t, idx): ``t`` is an evenly spaced grid from 0 to ``periods["tmax"]``
        (inclusive). ``idx`` maps each non-"tmax" key to the list of indices
        of ``t`` that lie in [start, end).
    """
    tmax = periods["tmax"]
    ratio = tmax / dt
    # tmax / dt can land just below an integer through floating-point error
    # (e.g. 0.3 / 0.1 == 2.9999999999999996). Truncating that would drop a grid
    # point and stretch the spacing, so snap near-integers; otherwise truncate.
    nearest = int(round(ratio))
    n_steps = nearest if np.isclose(ratio, nearest) else int(ratio)
    t = np.linspace(0, tmax, n_steps + 1)

    return t, {k: get_idx(t, v) for k, v in periods.items() if k != "tmax"}

minmax_number

minmax_number(dist, args)

Given input to the random_number_fn function, return min and max.

Source code in neurogym/utils/tasktools.py
def minmax_number(dist, args):
    """Given the inputs of ``random_number_fn``, return the (min, max) range.

    For "truncated_exponential", ``args`` is (vmean[, vmin[, vmax]]). Omitted
    bounds default to 0 and ``np.inf``, the same defaults ``trunc_exp`` uses.

    Raises:
        ValueError: for an unknown ``dist``.
    """
    if dist == "uniform":
        return args[0], args[1]
    if dist == "choice":
        return np.min(args), np.max(args)
    if dist == "truncated_exponential":
        vmin = args[1] if len(args) > 1 else 0
        vmax = args[2] if len(args) > 2 else np.inf
        return vmin, vmax
    if dist == "constant":
        return args, args
    msg = f"Unknown distribution: {dist}."
    raise ValueError(msg)

circular_dist

circular_dist(original_dist)

Get the distance in periodic boundary conditions.

Source code in neurogym/utils/tasktools.py
def circular_dist(original_dist):
    """Return the shorter arc length for an angular difference (radians).

    Computes ``min(|d|, 2*pi - |d|)`` element-wise. For ``|d| <= 2*pi`` the
    result lies in [0, pi]; larger differences are not reduced modulo 2*pi.
    """
    magnitude = abs(original_dist)
    return np.minimum(magnitude, 2 * np.pi - magnitude)

correct_2AFC

correct_2AFC(perf)

Compute the fraction of trials with a decision and the fraction of decisions that were correct.

Source code in neurogym/utils/tasktools.py
def correct_2AFC(perf):  # noqa: N802
    """Summarize two-alternative forced-choice performance.

    Args:
        perf: object exposing ``n_trials``, ``n_decision`` and ``n_correct``.

    Returns:
        (p_decision, p_correct): ``n_decision / n_trials`` and
        ``divide(n_correct, n_decision)``.
    """
    n_decision = perf.n_decision
    p_decision = n_decision / perf.n_trials
    # `divide` is defined elsewhere (presumably a zero-safe division — verify).
    return p_decision, divide(perf.n_correct, n_decision)