Skip to content

Utils

data

Utilities for data.

Dataset

Dataset(
    env,
    env_kwargs=None,
    batch_size=1,
    seq_len=None,
    max_batch=inf,
    batch_first=False,
    cache_len=None,
)

Make an environment into an iterable dataset for supervised learning.

Create an iterator that, at each call, returns `inputs`, a numpy array of shape (sequence_length, batch_size, input_units), and `target`, a numpy array of shape (sequence_length, batch_size, output_units). The batch dimension comes first if `batch_first=True`.

Parameters:

Name Type Description Default
env

str for env id or gym.Env objects

required
env_kwargs

dict, additional kwargs for environment, if env is str

None
batch_size

int, batch size

1
seq_len

int, sequence length

None
max_batch

int, maximum number of batches the iterator yields, default infinite

inf
batch_first

bool, if True, return (batch, seq_len, n_units), default False

False
cache_len

int, number of time steps to cache (rounded up to a multiple of seq_len); inferred from the observation and action sizes if None

None
Source code in neurogym/utils/data.py
def __init__(
    self,
    env,
    env_kwargs=None,
    batch_size=1,
    seq_len=None,
    max_batch=np.inf,
    batch_first=False,
    cache_len=None,
) -> None:
    """Wrap ``batch_size`` copies of an environment and pre-allocate the cache.

    Args:
        env: environment id (``str``, built with ``gym.make``) or a ``gym.Env``
            instance (deep-copied once per batch element).
        env_kwargs: extra kwargs for ``gym.make``; only used when ``env`` is a str.
        batch_size: number of environment copies, i.e. the batch size.
        seq_len: length of each returned sequence; 1000 when None.
        max_batch: maximum number of batches, stored as ``self.max_batch``.
        batch_first: if True, arrays are laid out (batch, time, ...) instead of
            (time, batch, ...).
        cache_len: number of time steps to cache; inferred from the
            observation/action sizes when None, then rounded up to a multiple
            of ``seq_len``.

    Raises:
        TypeError: if ``env`` is neither a str nor a ``gym.Env``.
    """
    if isinstance(env, gym.Env):
        self.envs = [copy.deepcopy(env) for _ in range(batch_size)]
    elif isinstance(env, str):
        make_kwargs = {} if env_kwargs is None else env_kwargs
        self.envs = [gym.make(env, **make_kwargs) for _ in range(batch_size)]
    else:
        msg = f"{type(env)=} must be `gym.Env` or `str`."
        raise TypeError(msg)

    for member in self.envs:
        member.reset()
    self.seed()  # defined elsewhere in the class

    first = self.envs[0]
    self.env = first
    self.batch_size = batch_size
    self.batch_first = batch_first

    if seq_len is None:
        seq_len = 1000  # TODO: infer sequence length from task

    obs_shape = first.observation_space.shape
    action_shape = first.action_space.shape
    # Scalar actions have shape (); presumably expanded by the caching code
    # elsewhere in the class — confirm.
    self._expand_action = len(action_shape) == 0

    if cache_len is None:
        # Heuristic budget of ~1e5 cached values shared by the batch
        # (probably too low).
        cache_len = 1e5 / (np.prod(obs_shape) + np.prod(action_shape)) / batch_size
    # Round up to a whole number of sequences (always at least one).
    cache_len = int((1 + cache_len // seq_len) * seq_len)

    self.seq_len = seq_len
    self._cache_len = cache_len

    sample_dims = [batch_size, seq_len] if batch_first else [seq_len, batch_size]
    cache_dims = [batch_size, cache_len] if batch_first else [cache_len, batch_size]

    self.inputs_shape = [*sample_dims, *obs_shape]
    self.target_shape = [*sample_dims, *action_shape]
    self._cache_inputs_shape = [*cache_dims, *obs_shape]
    self._cache_target_shape = [*cache_dims, *action_shape]

    self._inputs = np.zeros(self._cache_inputs_shape, dtype=first.observation_space.dtype)
    self._target = np.zeros(self._cache_target_shape, dtype=first.action_space.dtype)

    self._cache()  # presumably fills _inputs/_target; defined elsewhere

    self._i_batch = 0
    self.max_batch = max_batch

info

Formatting information about envs and wrappers.

info

info(env=None, show_code=False)

Script to get envs info.

Source code in neurogym/utils/info.py
def info(env=None, show_code=False):
    """Return a Markdown description of a registered environment.

    Args:
        env: registered environment id, passed to ``ngym.make``.
            NOTE(review): defaults to None but is forwarded to ``ngym.make``
            unchanged; an id is presumably required in practice.
        show_code: if True, append the source code of the environment class.

    Returns:
        str: the text produced by ``env_string`` for the unwrapped
        environment, optionally followed by its source code.
    """
    env_name = env
    # ngym.make can add wrappers (e.g. an OrderEnforcer); describe the bare task.
    env = ngym.make(env).unwrapped
    string = env_string(env)
    if show_code:
        string += """\n#### Source code #### \n\n"""
        # ALL_ENVS maps an env id to a "module.path:ClassName" reference.
        env_ref = ALL_ENVS[env_name]
        from_, class_ = env_ref.split(":")
        imported = getattr(__import__(from_, fromlist=[class_]), class_)
        string += inspect.getsource(imported) + "\n\n"
    return string

info_wrapper

info_wrapper(wrapper=None, show_code=False)

Script to get wrappers info.

Source code in neurogym/utils/info.py
def info_wrapper(wrapper=None, show_code=False):
    """Return a Markdown description of a registered wrapper.

    Args:
        wrapper: key into ``ALL_WRAPPERS``, whose value is a
            "module.path:ClassName" reference.
        show_code: if True, append the wrapper class's source code.

    Returns:
        str: heading, description, optional reference paper, any extra
        metadata entries (listed as "Input parameters"), and optionally the
        source code.
    """
    wrapp_ref = ALL_WRAPPERS[wrapper]
    from_, class_ = wrapp_ref.split(":")
    imported = getattr(__import__(from_, fromlist=[class_]), class_)
    metadata = imported.metadata
    if not isinstance(metadata, dict):
        metadata = {}

    string = f"### {wrapper}\n\n"
    paper_name = metadata.get("paper_name")
    paper_link = metadata.get("paper_link")
    wrapper_description = metadata.get("description") or "Missing description"
    string += f"Logic: {wrapper_description}\n\n"
    if paper_name is not None:
        string += "Reference paper \n\n"
        if paper_link is None:
            string += f"{paper_name}\n\n"
        else:
            string += f"[{paper_name}]({paper_link})\n\n"

    # Extra (non-default) metadata entries, kept in declaration order. A set
    # difference would make the listing order vary between interpreter runs
    # because of string hash randomisation.
    default_keys = set(METADATA_DEF_KEYS)
    other_info = [key for key in metadata if key not in default_keys]
    if other_info:
        string += "Input parameters: \n\n"
        for key in other_info:
            string += f"{key} : {metadata[key]}\n\n"

    if show_code:
        string += """\n#### Source code #### \n\n"""
        string += inspect.getsource(imported) + "\n\n"

    return string

all_tags

all_tags(verbose=0)

Script to get all tags.

Source code in neurogym/utils/info.py
def all_tags(verbose=0):
    """Collect the metadata tags of every registered environment.

    Args:
        verbose: if truthy, print the collected tags in sorted order.

    Returns:
        set: union of the ``"tags"`` lists found in each environment's
        metadata. Environments that fail to build are reported on stdout and
        skipped.
    """
    tags = []
    for env_name in sorted(all_envs()):
        try:
            env = ngym.make(env_name)
            metadata = env.metadata
            tags += metadata.get("tags", [])
        except Exception as e:  # noqa: BLE001, PERF203 # FIXME: unclear which error is expected here.
            # Exception rather than BaseException, so Ctrl-C / SystemExit still
            # stop the loop instead of being reported as an env failure.
            print("Failure in ", env_name)
            print(e)
    tags = set(tags)
    if verbose:
        print("\nTAGS:\n")
        # Sorted for reproducible output; key=str tolerates non-string tags.
        for tag in sorted(tags, key=str):
            print(tag)
    return tags

plotting

Plotting functions.

plot_env

plot_env(
    env,
    num_steps=200,
    num_trials=None,
    def_act=None,
    model=None,
    name=None,
    legend=True,
    ob_traces=None,
    fig_kwargs=None,
    fname=None,
)

Plot environment with agent.

Parameters:

Name Type Description Default
env

already built neurogym task or name of it

required
num_steps

number of steps to run the task

200
num_trials

if not None, the number of trials to run

None
def_act

if not None (and model=None), the task will be run with the specified action

None
model

if not None, the task will be run with the actions predicted by model, which so far is assumed to be created and trained with the stable-baselines3 toolbox: (https://stable-baselines3.readthedocs.io/en/master/)

None
name

title shown above the top (observation) panel

None
legend

whether to show the legend for actions panel or not.

True
ob_traces

if non-empty, observations will be plotted as traces, labelled by the entries of ob_traces

None
fig_kwargs

figure properties accepted by matplotlib.pyplot.subplots()

None
fname

if not None, save fig or movie to fname

None
Source code in neurogym/utils/plotting.py
def plot_env(
    env,
    num_steps=200,
    num_trials=None,
    def_act=None,
    model=None,
    name=None,
    legend=True,
    ob_traces=None,
    fig_kwargs=None,
    fname=None,
):
    """Plot environment with agent.

    Args:
        env: already built neurogym task or name of it
        num_steps: number of steps to run the task
        num_trials: if not None, the number of trials to run
        def_act: if not None (and model=None), the task will be run with the
                 specified action
        model: if not None, the task will be run with the actions predicted by
               model, which so far is assumed to be created and trained with the
               stable-baselines3 toolbox:
                   (https://stable-baselines3.readthedocs.io/en/master/)
        name: figure title; defaults to the class name of the unwrapped
              environment
        legend: whether to show the legend for actions panel or not.
        ob_traces: if non-empty, observations will be plotted as traces,
                    labelled by the entries of ob_traces
        fig_kwargs: figure properties accepted by matplotlib.pyplot.subplots().
        fname: if not None, save fig or movie to fname

    Returns:
        Whatever ``fig_`` returns for the recorded run.
    """
    # We don't use monitor here because:
    # 1) env could be already prewrapped with monitor
    # 2) monitor will save data and so the function will need a folder

    if fig_kwargs is None:
        fig_kwargs = {}
    if ob_traces is None:
        ob_traces = []
    if isinstance(env, str):
        env = gym.make(env)
    if name is None:
        # gym.make (and user wrappers such as a monitor) wrap the task, so the
        # outermost class is a wrapper; title the plot with the actual task.
        name = type(getattr(env, "unwrapped", env)).__name__
    data = run_env(
        env=env,
        num_steps=num_steps,
        num_trials=num_trials,
        def_act=def_act,
        model=model,
    )

    return fig_(
        data["ob"],
        data["actions"],
        gt=data["gt"],
        rewards=data["rewards"],
        legend=legend,
        performance=data["perf"],
        states=data["states"],
        name=name,
        ob_traces=ob_traces,
        fig_kwargs=fig_kwargs,
        env=env,
        fname=fname,
    )

fig_

fig_(
    ob,
    actions,
    gt=None,
    rewards=None,
    performance=None,
    states=None,
    legend=True,
    ob_traces=None,
    name="",
    fname=None,
    fig_kwargs=None,
    env=None,
)

Visualize a run in a simple environment.

Parameters:

Name Type Description Default
ob

np array of observation (n_step, n_unit)

required
actions

np array of action (n_step, n_unit)

required
gt

np array of ground truth

None
rewards

np array of rewards

None
performance

np array of performance

None
states

np array of network states

None
name

title shown above the top (observation) panel

''
fname

if not None or empty, where to save the figure

None
legend

whether to show the legend for actions panel or not.

True
ob_traces

None or list. If a non-empty list, observations will be plotted as traces, labelled by the entries of ob_traces

None
fig_kwargs

figure properties accepted by matplotlib.pyplot.subplots()

None
env

environment class for extra information

None
Source code in neurogym/utils/plotting.py
def fig_(
    ob,
    actions,
    gt=None,
    rewards=None,
    performance=None,
    states=None,
    legend=True,
    ob_traces=None,
    name="",
    fname=None,
    fig_kwargs=None,
    env=None,
):
    """Send a recorded run to the plotting routine that matches its shape.

    A 2-D ``ob`` (steps x units) is drawn by ``plot_env_1dbox``. A 4-D ``ob``
    (presumably a stack of image frames) is animated by ``plot_env_3dbox``,
    which only receives ``ob``, ``fname`` and ``env``.

    Args:
        ob: observations; converted with ``np.array``.
        actions: actions; converted with ``np.array``.
        gt: optional ground truth (2-D case only).
        rewards: optional rewards (2-D case only).
        performance: optional performance (2-D case only).
        states: optional network states (2-D case only).
        legend: whether to draw legends (2-D case only).
        ob_traces: None or a list of labels; if a non-empty list,
            observations are drawn as traces.
        name: title shown above the observation panel (2-D case only).
        fname: where to save the figure or movie; nothing is saved when falsy.
        fig_kwargs: kwargs for ``matplotlib.pyplot.subplots``.
        env: environment used for axis annotations and frame timing.

    Returns:
        Whatever the selected plotting routine returns.

    Raises:
        ValueError: if ``ob`` is neither 2-D nor 4-D.
    """
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    ob = np.array(ob)
    actions = np.array(actions)

    n_dims = ob.ndim
    if n_dims == 4:
        return plot_env_3dbox(ob, fname=fname, env=env)
    if n_dims != 2:
        msg = f"{ob.shape=} not supported."
        raise ValueError(msg)

    return plot_env_1dbox(
        ob,
        actions,
        gt=gt,
        rewards=rewards,
        performance=performance,
        states=states,
        legend=legend,
        ob_traces=ob_traces,
        name=name,
        fname=fname,
        fig_kwargs=fig_kwargs,
        env=env,
    )

plot_env_1dbox

plot_env_1dbox(
    ob,
    actions,
    gt=None,
    rewards=None,
    performance=None,
    states=None,
    legend=True,
    ob_traces=None,
    name="",
    fname=None,
    fig_kwargs=None,
    env=None,
)

Plot environment with 1-D Box observation space.

Source code in neurogym/utils/plotting.py
def plot_env_1dbox(
    ob,
    actions,
    gt=None,
    rewards=None,
    performance=None,
    states=None,
    legend=True,
    ob_traces=None,
    name="",
    fname=None,
    fig_kwargs=None,
    env=None,
):
    """Plot environment with 1-D Box observation space.

    Draws vertically stacked panels: observations, actions (with optional
    ground truth), then optional rewards, performance and network states.

    Args:
        ob: 2-D numpy array of observations, indexed ``ob[step, unit]``.
        actions: numpy array of actions, one entry (or row) per step.
        gt: optional ground truth, 1-D or 2-D (one column per output).
        rewards: optional per-step rewards.
        performance: optional per-step performance; entries equal to -1 are
            excluded from the mean shown in the panel title.
        states: optional 2-D numpy array of network states ``[step, unit]``;
            only the second half of the units is displayed.
        legend: whether to draw legends.
        ob_traces: optional list of labels; if non-empty, observations are
            drawn as line traces (one label per column of ``ob``) instead of an
            image.
        name: if non-empty, the observation panel is titled ``"<name> env"``.
        fname: if truthy, save the figure there (``.png`` appended unless it
            already ends in ``.png``/``.svg``) and close it.
        fig_kwargs: kwargs for ``plt.subplots``; when empty, a shared-x figure
            of size ``(5, 1.2 * n_rows)`` is used.
        env: optional environment; ``observation_space.name``,
            ``action_space.name`` and ``rewards`` label the y-ticks if present.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``ob`` is not 2-D, or ``ob_traces`` does not have one
            label per observation column.
    """
    if fig_kwargs is None:
        fig_kwargs = {}
    if len(ob.shape) != 2:
        msg = "ob has to be 2-dimensional."
        raise ValueError(msg)
    # Validate before creating the figure so a bad call does not leave an
    # orphaned figure open in pyplot.
    if ob_traces and len(ob_traces) != ob.shape[1]:
        msg = f"Please provide label for each of the {ob.shape[1]} traces in the observations."
        raise ValueError(msg)
    steps = np.arange(ob.shape[0])  # XXX: +1? 1st ob doesn't have action/gt

    # One panel each for observations and actions, plus one per optional series.
    n_row = 2
    n_row += rewards is not None
    n_row += performance is not None
    n_row += states is not None

    gt_colors = "gkmcry"
    if not fig_kwargs:
        fig_kwargs = {"sharex": True, "figsize": (5, n_row * 1.2)}

    f, axes = plt.subplots(n_row, 1, **fig_kwargs)
    i_ax = 0
    # ob
    ax = axes[i_ax]
    i_ax += 1
    if ob_traces:
        yticks = []
        for ind_tr, tr in enumerate(ob_traces):
            ax.plot(ob[:, ind_tr], label=tr)
            # Each trace's label sits at its mean height on the y axis.
            yticks.append(np.mean(ob[:, ind_tr]))
        if legend:
            ax.legend()
        ax.set_xlim([-0.5, len(steps) - 0.5])
        ax.set_yticks(yticks)
        ax.set_yticklabels(ob_traces)
    else:
        ax.imshow(ob.T, aspect="auto", origin="lower")
        if env and hasattr(env.observation_space, "name"):
            # Plot environment annotation: one tick per named group of units.
            yticks = []
            yticklabels = []
            for key, val in env.observation_space.name.items():
                yticks.append((np.min(val) + np.max(val)) / 2)
                yticklabels.append(key)
            ax.set_yticks(yticks)
            ax.set_yticklabels(yticklabels)
        else:
            ax.set_yticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        ax.spines["right"].set_visible(False)

    if name:
        ax.set_title(f"{name} env")
    ax.set_ylabel("Obs.")
    ax.set_xticks([])
    # actions
    ax = axes[i_ax]
    i_ax += 1
    # Multi-dimensional actions are drawn by matplotlib as one line per column;
    # no dedicated handling yet.
    ax.plot(steps, actions, marker="+", label="Actions")
    if gt is not None:
        gt = np.array(gt)
        if len(gt.shape) > 1:
            for ind_gt in range(gt.shape[1]):
                ax.plot(
                    steps,
                    gt[:, ind_gt],
                    # Cycle the palette rather than raise IndexError when
                    # there are more outputs than colours.
                    f"--{gt_colors[ind_gt % len(gt_colors)]}",
                    label=f"Ground truth {ind_gt}",
                )
        else:
            ax.plot(steps, gt, f"--{gt_colors[0]}", label="Ground truth")
    ax.set_xlim([-0.5, len(steps) - 0.5])
    ax.set_ylabel("Act.")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    if legend:
        ax.legend()
    if env and hasattr(env.action_space, "name"):
        # Plot environment annotation
        yticks = []
        yticklabels = []
        for key, val in env.action_space.name.items():
            yticks.append((np.min(val) + np.max(val)) / 2)
            yticklabels.append(key)
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticklabels)
    if n_row > 2:
        ax.set_xticks([])
    # rewards
    if rewards is not None:
        ax = axes[i_ax]
        i_ax += 1
        ax.plot(steps, rewards, "r", label="Rewards")
        ax.set_ylabel("Rew.")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        if legend:
            ax.legend()
        ax.set_xlim([-0.5, len(steps) - 0.5])

        if env and hasattr(env, "rewards") and env.rewards is not None:
            # Plot environment annotation
            yticks = []
            yticklabels = []

            if isinstance(env.rewards, dict):
                for key, val in env.rewards.items():
                    yticks.append(val)
                    yticklabels.append(f"{key[:4]} {val:0.2f}")
            else:
                for val in env.rewards:
                    yticks.append(val)
                    yticklabels.append(f"{val:0.2f}")

            ax.set_yticks(yticks)
            ax.set_yticklabels(yticklabels)
    if n_row > 3:
        ax.set_xticks([])
    # performance
    if performance is not None:
        ax = axes[i_ax]
        i_ax += 1
        ax.plot(steps, performance, "k", label="Performance")
        ax.set_ylabel("Performance")
        performance = np.array(performance)
        # -1 entries are skipped when averaging.
        mean_perf = np.mean(performance[performance != -1])
        ax.set_title(f"Mean performance: {np.round(mean_perf, 2)}")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        if legend:
            ax.legend()
        ax.set_xlim([-0.5, len(steps) - 0.5])

    # states
    if states is not None:
        ax.set_xticks([])
        ax = axes[i_ax]
        i_ax += 1
        # Draw on this panel explicitly: plt.imshow targets pyplot's "current"
        # axes, which only happens to be this (last-created) panel.
        ax.imshow(states[:, int(states.shape[1] / 2) :].T, aspect="auto")
        ax.set_title("Activity")
        ax.set_ylabel("Neurons")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    ax.set_xlabel("Steps")
    plt.tight_layout()
    if fname:
        fname = str(fname)
        if not (fname.endswith((".png", ".svg"))):
            fname += ".png"
        f.savefig(fname, dpi=300)
        plt.close(f)
    return f

plot_env_3dbox

plot_env_3dbox(ob, fname='', env=None) -> None

Plot environment with 3-D Box observation space.

Source code in neurogym/utils/plotting.py
def plot_env_3dbox(ob, fname="", env=None) -> "animation.FuncAnimation":
    """Plot environment with 3-D Box observation space as an animation.

    Args:
        ob: numpy array of frames indexed ``ob[i]`` (``fig_`` passes a 4-D
            array); cast to ``uint8`` before display.
        fname: if truthy, save the animation to this path with the ffmpeg
            writer (``.mp4`` is appended when missing).
        env: optional environment whose ``dt`` is used as the frame interval
            in milliseconds; 50 when no env is given.

    Returns:
        The ``FuncAnimation``. Matplotlib only renders an animation while a
        reference to it is alive, so keep the returned object when displaying
        it interactively.
    """
    ob = ob.astype(np.uint8)  # TODO: Temporary
    fig = plt.figure()
    ax = fig.add_axes((0.1, 0.1, 0.8, 0.8))
    ax.axis("off")
    im = ax.imshow(ob[0], animated=True)

    def animate(i, *args, **kwargs):
        im.set_array(ob[i])
        return (im,)

    interval = env.dt if env is not None else 50
    ani = animation.FuncAnimation(fig, animate, frames=ob.shape[0], interval=interval)
    if fname:
        # At least 1 fps: for dt > 1000 ms, int(1000 / dt) would be 0.
        writer = animation.writers["ffmpeg"](fps=max(1, int(1000 / interval)))
        fname = str(fname)
        if not fname.endswith(".mp4"):
            fname += ".mp4"
        ani.save(fname, writer=writer, dpi=300)
    return ani

ngym_random

TruncExp

TruncExp(vmean, vmin=0, vmax=inf)
Source code in neurogym/utils/ngym_random.py
def __init__(self, vmean, vmin=0, vmax=np.inf) -> None:
    """Store the parameters of a truncated exponential distribution.

    Args:
        vmean: mean of the exponential (presumably forwarded to ``trunc_exp``).
        vmin: lower truncation bound.
        vmax: upper truncation bound.
    """
    self.vmean, self.vmin, self.vmax = vmean, vmin, vmax
    # Unseeded generator; call ``seed`` to make draws reproducible.
    self.rng = np.random.RandomState()

seed

seed(seed=None) -> None

Seed the PRNG of this space.

Source code in neurogym/utils/ngym_random.py
def seed(self, seed=None) -> None:
    """Replace ``self.rng`` with a new ``np.random.RandomState(seed)``.

    With ``seed=None`` numpy picks a non-reproducible seed.
    """
    generator = np.random.RandomState(seed)
    self.rng = generator

trunc_exp

trunc_exp(rng, vmean, vmin=0, vmax=inf)

Function for generating period durations.

Source code in neurogym/utils/ngym_random.py
def trunc_exp(rng, vmean, vmin=0, vmax=np.inf):
    """Draw an exponential sample with mean ``vmean`` restricted to ``[vmin, vmax)``.

    Uses rejection sampling: ``rng.exponential(vmean)`` is redrawn until the
    value falls inside the interval. If ``vmin >= vmax`` no draw is made and
    ``vmax`` is returned.

    NOTE(review): the loop can run for a very long time if ``[vmin, vmax)``
    holds negligible probability mass (e.g. ``vmin`` far above ``vmean``).
    """
    # ">=" rather than "==": vmin may be raised to dt and end up above vmax.
    if vmin >= vmax:
        return vmax
    sample = rng.exponential(vmean)
    while not (vmin <= sample < vmax):
        sample = rng.exponential(vmean)
    return sample

random_number_fn

random_number_fn(dist, args, rng)

Return a random number generating function from a distribution.

Source code in neurogym/utils/ngym_random.py
def random_number_fn(dist, args, rng):
    """Return a random number generating function from a distribution."""
    if dist == "uniform":
        return lambda: rng.uniform(*args)
    if dist == "choice":
        return lambda: rng.choice(args)
    if dist == "truncated_exponential":
        return lambda: trunc_exp(rng, *args)
    if dist == "constant":
        return lambda: args
    msg = f"Unknown distribution: {dist}."
    raise ValueError(msg)

random_number_name

random_number_name(dist, args)

Return a string explaining the dist and args.

Source code in neurogym/utils/ngym_random.py
def random_number_name(dist, args):
    """Describe a ``random_number_fn`` distribution spec in plain English.

    Args:
        dist: one of "uniform", "choice", "truncated_exponential", "constant".
        args: distribution arguments, as passed to ``random_number_fn``.

    Returns:
        str: e.g. ``"uniform between 100 and 200"`` or ``"constant 500"``.

    Raises:
        ValueError: for an unknown ``dist``.
    """
    if dist == "uniform":
        return f"{dist} between {args[0]} and {args[1]}"
    if dist == "choice":
        return f"{dist} within {args}"
    if dist == "truncated_exponential":
        # args is (vmean[, vmin[, vmax]]); mention only the bounds given.
        string = f"truncated exponential with mean {args[0]}"
        if len(args) > 1:
            string += f", min {args[1]}"
        if len(args) > 2:
            string += f", max {args[2]}"
        return string
    if dist == "constant":
        # Was f"dist{args}", which rendered e.g. "dist500".
        return f"{dist} {args}"
    msg = f"Unknown distribution: {dist}."
    raise ValueError(msg)

scheduler

Trial scheduler class.

BaseSchedule

BaseSchedule(n)

Base schedule.

Parameters:

Name Type Description Default
n

int, number of conditions to schedule

required
Source code in neurogym/utils/scheduler.py
def __init__(self, n) -> None:
    """Set up counters for a schedule over ``n`` conditions.

    Args:
        n: number of conditions to schedule.
    """
    self.n = n
    # Trials seen overall, and trials seen within the current condition.
    self.total_count = self.count = 0
    self.i = 0  # presumably the current condition index; starts at 0
    self.rng = np.random.RandomState()

SequentialSchedule

SequentialSchedule(n)

Bases: BaseSchedule

Sequential schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n) -> None:
    """Create a sequential schedule over ``n`` conditions.

    All state is initialised by ``BaseSchedule.__init__``.
    """
    super().__init__(n=n)

RandomSchedule

RandomSchedule(n)

Bases: BaseSchedule

Random schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n) -> None:
    """Create a random schedule over ``n`` conditions.

    All state is initialised by ``BaseSchedule.__init__``.
    """
    super().__init__(n=n)

SequentialBlockSchedule

SequentialBlockSchedule(n, block_lens)

Bases: BaseSchedule

Sequential block schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n, block_lens) -> None:
    """Create a sequential block schedule over ``n`` conditions.

    Args:
        n: number of conditions.
        block_lens: one block length per condition.

    Raises:
        ValueError: if ``len(block_lens) != n``.
    """
    super().__init__(n)
    if len(block_lens) != n:
        msg = f"{len(block_lens)=} must be equal to {n=}."
        raise ValueError(msg)
    self.block_lens = block_lens

RandomBlockSchedule

RandomBlockSchedule(n, block_lens)

Bases: BaseSchedule

Random block schedules.

Source code in neurogym/utils/scheduler.py
def __init__(self, n, block_lens) -> None:
    """Create a random block schedule over ``n`` conditions.

    Args:
        n: number of conditions.
        block_lens: one block length per condition.

    Raises:
        ValueError: if ``len(block_lens) != n``.
    """
    super().__init__(n)
    if len(block_lens) != n:
        msg = f"{len(block_lens)=} must be equal to {n=}."
        raise ValueError(msg)
    self.block_lens = block_lens

spaces

Box

Box(low, high, name=None, **kwargs)

Bases: Box

Thin wrapper of gymnasium.spaces.Box.

Allow the user to give names to each dimension of the Box.

Parameters:

Name Type Description Default
low, high, kwargs

see gymnasium.spaces.Box

required
name

dict describing the name of different dimensions

None
Example usage

observation_space = Box(low=0, high=1, name={'fixation': 0, 'stimulus': [1, 2]})

Source code in neurogym/utils/spaces.py
def __init__(self, low, high, name=None, **kwargs) -> None:
    """Build a ``gymnasium.spaces.Box`` that can carry dimension names.

    Args:
        low: forwarded to ``gymnasium.spaces.Box``.
        high: forwarded to ``gymnasium.spaces.Box``.
        name: optional dict mapping a label to an index or list of indices,
            e.g. ``{"fixation": 0, "stimulus": [1, 2]}``; stored as
            ``self.name`` only when given.
        **kwargs: forwarded to ``gymnasium.spaces.Box``.

    Raises:
        TypeError: if ``name`` is neither a dict nor None.
    """
    super().__init__(low, high, **kwargs)
    if name is None:
        return
    if not isinstance(name, dict):
        msg = f"{type(name)=} must be `dict` or `NoneType`."
        raise TypeError(msg)
    self.name = name

Discrete

Discrete(n, name=None, **kwargs)

Bases: Discrete

Thin wrapper of gymnasium.spaces.Discrete.

Allow the user to give names to the values of the Discrete space.

Parameters:

Name Type Description Default
n, kwargs

see gymnasium.spaces.Discrete

required
name

dict describing the name of different dimensions

None
Example usage

observation_space = Discrete(n=3, name={'fixation': 0, 'stimulus': [1, 2]})

Source code in neurogym/utils/spaces.py
def __init__(self, n, name=None, **kwargs) -> None:
    """Build a ``gymnasium.spaces.Discrete`` that can carry value names.

    Args:
        n: number of elements, forwarded to ``gymnasium.spaces.Discrete``.
        name: optional dict mapping a label to a value or list of values,
            e.g. ``{"fixation": 0, "stimulus": [1, 2]}``; stored as
            ``self.name`` only when given.
        **kwargs: forwarded to ``gymnasium.spaces.Discrete`` (e.g. ``seed``,
            ``start``).

    Raises:
        TypeError: if ``name`` is neither a dict nor None.
    """
    # Forward kwargs as Box does; previously they were accepted but silently
    # dropped, so e.g. ``start`` or ``seed`` had no effect.
    super().__init__(n, **kwargs)
    if isinstance(name, dict):
        self.name = name
    elif name is not None:
        msg = f"{type(name)=} must be `dict` or `NoneType`."
        raise TypeError(msg)

tasktools

to_map

to_map(*args)

Produces ordered dict from given inputs.

Source code in neurogym/utils/tasktools.py
def to_map(*args):
    """Map each given item to its position, as an ``OrderedDict``.

    Items may be passed as separate positional arguments or as a single list,
    e.g. ``to_map("a", "b")`` and ``to_map(["a", "b"])`` both give
    ``OrderedDict([("a", 0), ("b", 1)])``. A repeated item keeps its first
    position but maps to its last index. With no arguments an empty
    ``OrderedDict`` is returned (previously this raised IndexError).
    """
    var_list = args[0] if args and isinstance(args[0], list) else args
    return OrderedDict((item, index) for index, item in enumerate(var_list))

get_idx

get_idx(t, start_end)

Auxiliary function for defining task periods.

Source code in neurogym/utils/tasktools.py
def get_idx(t, start_end):
    """Return the indices of ``t`` whose values lie in ``[start, end)``.

    Args:
        t: numpy array of time points.
        start_end: ``(start, end)`` pair; ``start`` is inclusive and ``end``
            exclusive.

    Returns:
        list: matching indices along the first axis (numpy integers).
    """
    lo, hi = start_end
    in_window = (t >= lo) & (t < hi)
    return list(np.nonzero(in_window)[0])

get_periods_idx

get_periods_idx(dt, periods)

Function for defining task periods.

Source code in neurogym/utils/tasktools.py
def get_periods_idx(dt, periods):
    """Build a time grid and the grid indices falling inside each period.

    Args:
        dt: nominal time step. The grid has ``int(tmax / dt) + 1`` evenly
            spaced points from 0 to ``tmax``, so its spacing equals ``dt``
            only when ``tmax`` is a multiple of ``dt``.
        periods: dict with a ``"tmax"`` entry plus one ``(start, end)`` pair
            per named period.

    Returns:
        tuple: ``(t, idx)`` where ``t`` is the grid and ``idx`` maps every
        period name (except ``"tmax"``) to ``get_idx(t, (start, end))``.
    """
    tmax = periods["tmax"]
    t = np.linspace(0, tmax, int(tmax / dt) + 1)
    idx = {}
    for key, window in periods.items():
        if key == "tmax":
            continue
        idx[key] = get_idx(t, window)
    return t, idx

minmax_number

minmax_number(dist, args)

Given input to the random_number_fn function, return min and max.

Source code in neurogym/utils/tasktools.py
def minmax_number(dist, args):
    """Return the ``(min, max)`` range of a ``random_number_fn`` distribution spec.

    Args:
        dist: "uniform", "choice", "truncated_exponential" or "constant".
        args: distribution arguments. For "truncated_exponential" they are
            ``(vmean[, vmin[, vmax]])``; missing bounds default to
            ``trunc_exp``'s ``vmin=0`` and ``vmax=np.inf``.

    Returns:
        tuple: ``(min, max)``.

    Raises:
        ValueError: for an unknown ``dist``.
    """
    if dist == "uniform":
        return args[0], args[1]
    if dist == "choice":
        return np.min(args), np.max(args)
    if dist == "truncated_exponential":
        # The bounds are optional (see trunc_exp's defaults); indexing
        # args[1]/args[2] unconditionally raised IndexError for (vmean,).
        vmin = args[1] if len(args) > 1 else 0
        vmax = args[2] if len(args) > 2 else np.inf
        return vmin, vmax
    if dist == "constant":
        return args, args
    msg = f"Unknown distribution: {dist}."
    raise ValueError(msg)

circular_dist

circular_dist(original_dist)

Get the distance in periodic boundary conditions.

Source code in neurogym/utils/tasktools.py
def circular_dist(original_dist):
    """Return the shortest distance on a circle of circumference ``2*pi``.

    Works elementwise on numpy arrays and on scalars (radians). The absolute
    difference is first wrapped into ``[0, 2*pi)``, so inputs with magnitude
    of ``2*pi`` or more no longer yield negative distances. For
    ``|d| < 2*pi`` the result is unchanged.
    """
    wrapped = np.mod(abs(original_dist), 2 * np.pi)
    return np.minimum(wrapped, 2 * np.pi - wrapped)

correct_2AFC

correct_2AFC(perf)

Compute the fraction of trials with a decision and the fraction of decisions that were correct.

Source code in neurogym/utils/tasktools.py
def correct_2AFC(perf):  # noqa: N802
    """Return ``(p_decision, p_correct)`` for a 2AFC performance record.

    ``p_decision`` is ``perf.n_decision / perf.n_trials``. ``p_correct`` is
    ``divide(perf.n_correct, perf.n_decision)``. ``divide`` is defined
    elsewhere; presumably it guards a zero denominator — confirm.
    NOTE(review): ``n_trials == 0`` is not guarded here.
    """
    n_decision = perf.n_decision
    return n_decision / perf.n_trials, divide(perf.n_correct, n_decision)