Source code for neurogym.envs.reaching

"""Reaching to target."""

import numpy as np

import neurogym as ngym
from neurogym import spaces

from neurogym.utils import tasktools


# TODO: Ground truth and action have different space,
# making it difficult for SL and RL to work together
# TODO: Need to clean up this task
class Reaching1D(ngym.TrialEnv):
    r"""Reaching to the stimulus.

    The agent is shown a stimulus during the fixation period. The stimulus
    encodes a one-dimensional variable such as a movement direction. At the
    end of the fixation period, the agent needs to respond by reaching
    towards the stimulus direction.
    """
    metadata = {
        'paper_link': 'https://science.sciencemag.org/content/233/4771/1416',
        'paper_name': 'Neuronal population coding of movement direction',
        'tags': ['motor', 'steps action space']
    }

    def __init__(self, dt=100, rewards=None, timing=None, dim_ring=16):
        super().__init__(dt=dt)
        # Rewards
        self.rewards = {'correct': +1., 'fail': -0.1}
        if rewards:
            self.rewards.update(rewards)

        self.timing = {
            'fixation': 500,
            'reach': 500}
        if timing:
            self.timing.update(timing)

        # action and observation spaces
        name = {'self': range(dim_ring, 2*dim_ring),
                'target': range(dim_ring)}
        self.observation_space = spaces.Box(
            -np.inf, np.inf, shape=(2*dim_ring,), dtype=np.float32, name=name)
        name = {'fixation': 0, 'left': 1, 'right': 2}
        self.action_space = spaces.Discrete(3, name=name)

        self.theta = np.arange(0, 2*np.pi, 2*np.pi/dim_ring)
        self.state = np.pi
        self.dim_ring = dim_ring

    def _new_trial(self, **kwargs):
        # Trial
        self.state = np.pi
        trial = {
            'ground_truth': self.rng.uniform(0, np.pi*2)
        }
        trial.update(kwargs)

        # Periods
        self.add_period(['fixation', 'reach'])

        target = np.cos(self.theta - trial['ground_truth'])
        self.add_ob(target, 'reach', where='target')

        self.set_groundtruth(np.pi, 'fixation')
        self.set_groundtruth(trial['ground_truth'], 'reach')
        self.dec_per_dur = (self.end_ind['reach'] - self.start_ind['reach'])

        return trial

    def _step(self, action):
        if action == 1:
            self.state += 0.05
        elif action == 2:
            self.state -= 0.05
        self.state = np.mod(self.state, 2*np.pi)

        gt = self.gt_now
        if self.in_period('fixation'):
            reward = 0
        else:
            reward = np.max((
                self.rewards['correct']
                - tasktools.circular_dist(self.state - gt),
                self.rewards['fail']))
            norm_rew = ((reward - self.rewards['fail'])
                        / (self.rewards['correct'] - self.rewards['fail']))
            self.performance += norm_rew / self.dec_per_dur

        return self.ob_now, reward, False, {'new_trial': False}
    def post_step(self, ob, reward, done, info):
        """Modify observation."""
        ob[self.dim_ring:] = np.cos(self.theta - self.state)
        return ob, reward, done, info
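

# Illustrative usage sketch (not part of the original module). It assumes the
# classic Gym-style API that these TrialEnv tasks expose: reset() returns an
# observation and step(action) returns (ob, reward, done, info). Action 0
# leaves the reaching state unchanged; actions 1 and 2 rotate it in small steps.
def _demo_reaching1d(n_steps=20):
    env = Reaching1D(dt=100, dim_ring=16)
    ob = env.reset()
    total_reward = 0.0
    for _ in range(n_steps):
        action = env.action_space.sample()  # random policy, for illustration only
        ob, reward, done, info = env.step(action)
        total_reward += reward
    return total_reward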
class Reaching1DWithSelfDistraction(ngym.TrialEnv):
    r"""Reaching with self distraction.

    In this task, the reaching state itself generates strong inputs that
    overshadow the actual target input. This task is inspired by behavior in
    electric fish, where the electric sensing organ is distracted by
    discharges from its own electric organ during active sensing. Similar
    phenomena occur in bats.
    """
    metadata = {
        'description': '''The agent has to reproduce the angle indicated
            by the observation. Furthermore, the reaching state itself
            generates strong inputs that overshadow the actual target
            input.''',
        'paper_link': None,
        'paper_name': None,
        'tags': ['motor', 'steps action space']
    }

    def __init__(self, dt=100, rewards=None, timing=None):
        super().__init__(dt=dt)
        # Rewards
        self.rewards = {'correct': +1., 'fail': -0.1}
        if rewards:
            self.rewards.update(rewards)

        self.timing = {
            'fixation': 500,
            'reach': 500}
        if timing:
            self.timing.update(timing)

        # action and observation spaces
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(-np.inf, np.inf, shape=(32,),
                                            dtype=np.float32)

        self.theta = np.arange(0, 2*np.pi, 2*np.pi/32)
        self.state = np.pi

    def _new_trial(self, **kwargs):
        # ---------------------------------------------------------------------
        # Trial
        # ---------------------------------------------------------------------
        self.state = np.pi
        trial = {
            'ground_truth': self.rng.uniform(0, np.pi*2)
        }
        trial.update(kwargs)

        # ---------------------------------------------------------------------
        # Periods
        # ---------------------------------------------------------------------
        self.add_period('fixation')
        self.add_period('reach', after='fixation')

        ob = self.view_ob('reach')
        # Signal is weaker than the self-distraction
        ob += np.cos(self.theta - trial['ground_truth']) * 0.3

        self.set_groundtruth(np.pi, 'fixation')
        self.set_groundtruth(trial['ground_truth'], 'reach')
        self.dec_per_dur = (self.end_ind['reach'] - self.start_ind['reach'])

        return trial

    def _step(self, action):
        if action == 1:
            self.state += 0.05
        elif action == 2:
            self.state -= 0.05
        self.state = np.mod(self.state, 2*np.pi)

        gt = self.gt_now
        if self.in_period('fixation'):
            reward = 0
        else:
            reward = np.max((
                self.rewards['correct']
                - tasktools.circular_dist(self.state - gt),
                self.rewards['fail']))
            norm_rew = ((reward - self.rewards['fail'])
                        / (self.rewards['correct'] - self.rewards['fail']))
            self.performance += norm_rew / self.dec_per_dur

        return self.ob_now, reward, False, {'new_trial': False}
    def post_step(self, ob, reward, done, info):
        """Modify observation."""
        ob += np.cos(self.theta - self.state)
        return ob, reward, done, info
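

# Illustrative sketch (not part of the original module) of why this task is a
# "self distraction": _new_trial adds the target with amplitude 0.3, while
# post_step adds a unit-amplitude cosine centred on the current reaching state,
# so the self-generated input dominates each observation. The loop below is a
# minimal example under the same Gym-style API assumption as above.
def _demo_self_distraction(n_steps=10):
    env = Reaching1DWithSelfDistraction(dt=100)
    ob = env.reset()
    peak_angles = []
    for _ in range(n_steps):
        ob, reward, done, info = env.step(env.action_space.sample())
        # The peak of the observation roughly tracks the self state rather
        # than the weaker (0.3) target signal.
        peak_angles.append(env.theta[np.argmax(ob)])
    return peak_angles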