Skip to content

Core

TrialEnv

TrialEnv(dt=100, num_trials_before_reset=10000000, r_tmax=0)

Bases: BaseEnv

The main Neurogym class for trial-based envs.

Source code in neurogym/core.py
def __init__(self, dt=100, num_trials_before_reset=10000000, r_tmax=0) -> None:
    """Set up the per-trial bookkeeping on top of the parent constructor.

    Args:
        dt: forwarded to the parent constructor as ``dt``.
        num_trials_before_reset: stored as ``self.num_tr_exp``.
        r_tmax: reward that ``step`` adds when it ends a trial because the
            next step would go past ``self.tmax``.
    """
    super().__init__(dt=dt)

    # Values taken from the constructor arguments.
    self.r_tmax = r_tmax
    self.num_tr_exp = num_trials_before_reset

    # Trial state.
    self.num_tr = 0  # incremented by new_trial, zeroed by reset
    self.trial: dict | None = None  # result of the latest _new_trial call
    self._tmax = 0  # Length of each trial
    self._top = self  # object whose new_trial/step get called; see set_top

    # Build flags, cleared again at the start of every new_trial.
    self._ob_built = False
    self._gt_built = False
    self._has_gt = False  # check if the task ever defined gt

    # None here; the original note says this "defaults to 0" — presumably
    # resolved where the observation is built (confirm in _init_ob).
    self._default_ob_value = None

    # Per-period tables, keyed by period name (filled by add_period).
    self.timing: dict = {}
    self._duration: dict = {}
    self.start_t: dict = {}
    self.end_t: dict = {}
    self.start_ind: dict = {}
    self.end_ind: dict = {}

seed

seed(seed=None)

Set random seed.

Source code in neurogym/core.py
def seed(self, seed=None):
    """Seed every random source owned by the environment.

    ``self.rng`` is replaced by a new ``np.random.RandomState(seed)``. The
    same seed is then handed to the action space (when the attribute exists
    and is not None) and to each value in ``self.timing``.

    Args:
        seed: value passed to every generator; None is passed through as is.

    Returns:
        A one-element list holding ``seed``.
    """
    self.rng = np.random.RandomState(seed)

    space = getattr(self, "action_space", None)
    if space is not None:
        space.seed(seed)

    for timing_spec in self.timing.values():
        # Entries without a ``seed`` attribute are skipped. Note that an
        # AttributeError raised inside ``seed`` itself is swallowed too.
        try:
            timing_spec.seed(seed)
        except AttributeError:
            pass
    return [seed]

post_step

post_step(ob, reward, terminated, truncated, info)

Optional task-specific wrapper applied at the end of step.

It allows modifying ob online (e.g., providing a specific observation for different actions made by the agent).

Source code in neurogym/core.py
def post_step(self, ob, reward, terminated, truncated, info):
    """Task-specific hook that ``step`` calls last; a pass-through by default.

    Tasks can override it to alter the step result online, e.g. to provide a
    specific observation depending on the action made by the agent.

    Returns:
        The tuple ``(ob, reward, terminated, truncated, info)``, unchanged.
    """
    step_result = (ob, reward, terminated, truncated, info)
    return step_result

new_trial

new_trial(**kwargs)

Public interface for starting a new trial.

Returns:

Name Type Description
trial

dict of trial information. Available to step function as self.trial

Source code in neurogym/core.py
def new_trial(self, **kwargs):
    """Start a new trial (public entry point).

    Clears the per-trial build state, delegates the actual trial
    construction to ``self._new_trial(**kwargs)`` and updates the counters.

    Args:
        **kwargs: forwarded unchanged to ``self._new_trial``.

    Returns:
        trial: whatever ``_new_trial`` returned (documented as a dict of
            trial information); also stored as ``self.trial`` so that the
            step function can read it.
    """
    # Only _tmax is cleared; self.tmax is left untouched so that step can
    # keep using it.
    self._tmax = 0
    self._ob_built = self._gt_built = False

    self.trial = trial_info = self._new_trial(**kwargs)
    self.num_tr += 1  # one more trial started

    # _gt_built was cleared above, so it is True here only if _new_trial
    # built a groundtruth for this trial.
    self._has_gt = self._gt_built
    return trial_info

step

step(action)

Public interface for the environment.

Source code in neurogym/core.py
def step(self, action):
    """Public interface for the environment.

    Runs the task-specific ``self._step(action)``, advances the within-trial
    clock by one ``dt`` and, when the trial is over, starts the next trial
    through ``self._top.new_trial()``.

    Args:
        action: passed unchanged to ``self._step``.

    Returns:
        The result of ``self.post_step(ob, reward, terminated, truncated,
        info)``. ``info`` always carries a "new_trial" key; it gets "gt" when
        ``self._has_gt`` is set and ``_step`` did not provide one, and it gets
        "performance" and "trial" on the step that ends a trial.
    """
    ob, reward, terminated, truncated, info = self._step(action)

    # _step may omit the flag; treat a missing flag as "trial continues".
    if "new_trial" not in info:
        info["new_trial"] = False

    if self._has_gt and "gt" not in info:
        # If gt is built, default gt to gt_now
        # must run before incrementing t
        info["gt"] = self.gt_now

    self.t += self.dt  # increment within trial time count
    self.t_ind += 1

    # One more step would go past tmax: end the trial now and add the
    # r_tmax reward, unless _step already ended it.
    if self.t + self.dt > self.tmax and not info["new_trial"]:
        info["new_trial"] = True
        reward += self.r_tmax

    # TODO: new_trial happens after step, so trial indx precedes obs change
    if info["new_trial"]:
        # Report the finished trial's performance before it is cleared below.
        info["performance"] = self.performance
        self.t = self.t_ind = 0  # Reset within trial time count
        # Go through _top (see set_top) rather than calling self.new_trial.
        trial = self._top.new_trial()
        self.performance = 0
        info["trial"] = trial
    # OBNOW is defined elsewhere in this module — presumably a sentinel that
    # _step returns to request the stored observation. The lookup happens
    # after the reset above, so on a trial boundary it reads index 0 of the
    # new trial.
    if ob is OBNOW:
        ob = self.ob[self.t_ind]
    return self.post_step(ob, reward, terminated, truncated, info)

reset

reset(seed=None, options=None)

Reset the environment.

Parameters:

Name Type Description Default
seed

random seed, overwrites self.seed if not None

None
options

additional options used to reset the env. Can include 'step_fn' and 'no_step'. step_fn can be a function or None. If function, overwrite original self.step method. no_step is a bool. If True, no step is taken and observation randomly sampled. It defaults to False.

None
Source code in neurogym/core.py
def reset(self, seed=None, options=None):
    """Reset the environment and return the first observation.

    Args:
        seed: forwarded to the parent class's ``reset``.
        options: optional mapping of extra reset options. Recognised keys:
            `step_fn`: callable or None. When given, it is called instead of
            ``self._top.step`` to produce the initial observation.
            `no_step`: bool, defaults to False. If True, no step is taken
            and the observation is sampled from ``self.observation_space``.

    Returns:
        A tuple ``(ob, {})``: the initial observation and an empty info dict.
    """
    super().reset(seed=seed)

    # Restart the within-trial clock and the trial counter.
    self.t = self.t_ind = 0
    self.num_tr = 0

    # A missing or empty ``options`` behaves like "no option given".
    opts = options if options else {}
    step_fn = opts.get("step_fn")
    skip_step = opts.get("no_step", False)

    self._top.new_trial()

    # The action space is re-seeded with a fixed 0 before any sampling below,
    # whatever ``seed`` was passed in.
    self.action_space.seed(0)
    if skip_step:
        return self.observation_space.sample(), {}

    # have to also call step() to get the initial ob since some wrappers
    # modify step() but not new_trial()
    stepper = self._top.step if step_fn is None else step_fn
    first_ob, _reward, _terminated, _truncated, _info = stepper(self.action_space.sample())
    return first_ob, {}

render

render(mode='human') -> None

Plots relevant variables/parameters.

Source code in neurogym/core.py
def render(self, mode="human") -> None:
    """Placeholder for plotting relevant variables/parameters.

    This implementation has no body beyond the docstring: ``mode`` is
    ignored, nothing is drawn and None is returned. Presumably tasks that
    need plotting override it.
    """

set_top

set_top(wrapper) -> None

Set top to be wrapper.

Source code in neurogym/core.py
def set_top(self, wrapper) -> None:
    """Set top to be wrapper.

    ``self._top`` starts out as ``self`` and is the object whose
    ``new_trial`` and ``step`` methods are invoked from ``step`` and
    ``reset``.

    Args:
        wrapper: object to store as ``self._top`` — presumably the outermost
            wrapper around this env; confirm with the callers.
    """
    self._top = wrapper

add_period

add_period(period, duration=None, before=None, after=None, last_period=False) -> None

Add a period.

Parameters:

Name Type Description Default
period

string or list of strings, name of the period

required
duration

float or None, duration of the period; if None, it is inferred from timing_fn

None
before

(optional) str, name of period that this period is before

None
after

(optional) str, name of period that this period is after or float, time of period start

None
last_period

bool, default False. If True, this is the last period, which will generate self.tmax, self.tind, and self.ob

False
Source code in neurogym/core.py
def add_period(
    self,
    period,
    duration=None,
    before=None,
    after=None,
    last_period=False,
) -> None:
    """Register one period, or a chain of consecutive periods, in the trial.

    For a single period the start/end times and the matching step indices are
    recorded under the period name, and ``self._tmax`` / ``self.tmax`` are
    extended to cover it.

    Args:
        period: string, or sequence of strings, naming the period(s). A
            sequence is added one by one, each period starting where the
            previous one ends.
        duration: duration of the period. If None it is drawn with
            ``self.sample_time(period)``. For a sequence of periods, a
            sequence of the same length (or None).
        before: (optional) str, name of an existing period; the new period
            ends where that one starts. Only used when ``after`` is None, and
            not forwarded when ``period`` is a sequence.
        after: (optional) str, name of an existing period the new one starts
            after, or a number giving the start time directly. For a sequence
            it applies to the first period only.
        last_period: bool, default False. For a sequence it is forwarded to
            the call for the final period; it has no other effect in this
            method.

    Raises:
        InvalidOperationError: if the observation has already been built.
        ValueError: if ``period`` and ``duration`` are sequences of different
            lengths.
    """
    if self._ob_built:
        msg = "Cannot add period after ob is built, i.e. after running add_ob."
        raise InvalidOperationError(msg)

    if not isinstance(period, str):
        # Sequence of periods: validate, then chain them one after another.
        if duration is None:
            duration = [None] * len(period)
        elif len(duration) != len(period):
            msg = f"{len(duration)=} and {len(period)=} must be the same."
            raise ValueError(msg)

        final_idx = len(period) - 1
        # The first period is anchored by ``after``; every following one is
        # anchored to its predecessor.
        self.add_period(period[0], duration=duration[0], after=after)
        for idx in range(1, final_idx + 1):
            self.add_period(
                period[idx],
                duration=duration[idx],
                after=period[idx - 1],
                last_period=(idx == final_idx) and last_period,
            )
        return

    # Single period from here on.
    if duration is None:
        duration = self.sample_time(period)
    self._duration[period] = duration

    if after is not None:
        # A string names a period to follow; anything else is a start time.
        start = self.end_t[after] if isinstance(after, str) else after
    elif before is not None:
        start = self.start_t[before] - duration
    else:
        start = 0  # no anchor given: the period starts the trial
    end = start + duration

    self.start_t[period] = start
    self.end_t[period] = end
    # Step indices are obtained by truncating time / dt.
    self.start_ind[period] = int(start / self.dt)
    self.end_ind[period] = int(end / self.dt)

    # tmax is _tmax truncated down to a whole number of dt steps.
    self._tmax = max(self._tmax, end)
    self.tmax = int(self._tmax / self.dt) * self.dt

view_ob

view_ob(period=None)

View observation of a period.

Source code in neurogym/core.py
def view_ob(self, period=None):
    """Return the observation, either whole or restricted to one period.

    The observation is initialised first (via ``self._init_ob()``) if it has
    not been built yet.

    Args:
        period: name of an added period, or None for the full observation.

    Returns:
        ``self.ob`` when ``period`` is None, otherwise the slice of
        ``self.ob`` between the period's start and end indices.
    """
    if not self._ob_built:
        self._init_ob()

    ob = self.ob
    if period is not None:
        window = slice(self.start_ind[period], self.end_ind[period])
        return ob[window]
    return ob

add_ob

add_ob(value, period=None, where=None) -> None

Add value to observation.

Parameters:

Name Type Description Default
value

array-like (ob_space.shape, ...)

required
period

string, must be name of an added period

None
where

string or np array, location of stimulus to be added

None
Source code in neurogym/core.py
def add_ob(self, value, period=None, where=None) -> None:
    """Add value to observation.

    Thin public wrapper: forwards its arguments to ``self._add_ob`` with
    ``reset=False``.

    Args:
        value: array-like (ob_space.shape, ...)
        period: string, must be name of an added period
        where: string or np array, location of stimulus to be added
    """
    # reset=False presumably means the value is added onto the existing
    # observation rather than replacing it — confirm in _add_ob.
    self._add_ob(value, period, where, reset=False)

set_groundtruth

set_groundtruth(value, period=None, where=None) -> None

Set groundtruth value.

Source code in neurogym/core.py
def set_groundtruth(self, value, period=None, where=None) -> None:
    """Write a groundtruth value for one period, several periods or all steps.

    The groundtruth is initialised first (via ``self._init_gt()``) if it has
    not been built yet.

    Args:
        value: value to store. When ``where`` is given it is first translated
            through ``self.action_space.name[where][value]``.
        period: name of an added period, an iterable of such names, or None
            to fill the whole groundtruth.
        where: (optional) key into ``self.action_space.name``.
    """
    if not self._gt_built:
        self._init_gt()

    if where is not None:
        # TODO: Only works for Discrete action_space, make it work for Box
        value = self.action_space.name[where][value]  # type: ignore[attr-defined]

    if period is None:
        self.gt[:] = value
        return

    if isinstance(period, str):
        self.gt[self.start_ind[period] : self.end_ind[period]] = value
        return

    # Several periods: ``value`` is already translated, so ``where`` is
    # deliberately not passed on.
    for name in period:
        self.set_groundtruth(value, name)

view_groundtruth

view_groundtruth(period)

View groundtruth of a period.

Source code in neurogym/core.py
def view_groundtruth(self, period):
    """Return the groundtruth restricted to one period.

    The groundtruth is initialised first (via ``self._init_gt()``) if it has
    not been built yet.

    Args:
        period: name of an added period.

    Returns:
        The slice of ``self.gt`` between the period's start and end indices.
    """
    if not self._gt_built:
        self._init_gt()

    gt = self.gt
    window = slice(self.start_ind[period], self.end_ind[period])
    return gt[window]

in_period

in_period(period, t=None)

Check if current time or time t is in period.

Source code in neurogym/core.py
def in_period(self, period, t=None):
    """Tell whether a time falls inside a period.

    Args:
        period: name of an added period.
        t: time to test; None (the default) means the current time
            ``self.t``.

    Returns:
        True when the period's start time <= t < its end time, i.e. the start
        is included and the end is excluded.
    """
    # Compare against None explicitly so that t=0 is not replaced by self.t.
    now = self.t if t is None else t
    return self.start_t[period] <= now < self.end_t[period]

BaseEnv

BaseEnv(dt=100)

Bases: Env

The base Neurogym class to include dt.

Source code in neurogym/core.py
def __init__(self, dt=100) -> None:
    """Initialise timing, performance and random-state attributes.

    Args:
        dt: stored as ``self.dt``.
    """
    super().__init__()
    self.dt = dt

    # Clock: elapsed time and the matching step index, both starting at 0.
    self.t = 0
    self.t_ind = 0
    self.tmax = 10000  # maximum time steps

    self.performance = 0
    self.rewards: dict = {}

    # Unseeded generator; replaced by seed().
    self.rng = np.random.RandomState()

seed

seed(seed=None)

Set random seed.

Source code in neurogym/core.py
def seed(self, seed=None):
    """Set random seed.

    Replaces ``self.rng`` with a new ``np.random.RandomState(seed)`` and
    seeds the action space with the same value when one is available.

    Args:
        seed: value passed to the generators; None is passed through as is.

    Returns:
        A one-element list holding ``seed``.
    """
    self.rng = np.random.RandomState(seed)
    # The action space may not have been assigned yet, so look it up
    # defensively instead of raising AttributeError; TrialEnv.seed applies
    # the same guard.
    action_space = getattr(self, "action_space", None)
    if action_space is not None:
        action_space.seed(seed)
    return [seed]