#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Ready-set-go task."""
import numpy as np
import neurogym as ngym
from neurogym import spaces
[docs]
class ReadySetGo(ngym.TrialEnv):
r"""Agents have to measure and produce different time intervals.
A stimulus is briefly shown during a ready period, then again during a
set period. The ready and set periods are separated by a measure period,
the duration of which is randomly sampled on each trial. The agent is
required to produce a response after the set cue such that the interval
between the response and the set cue is as close as possible to the
duration of the measure period.
Args:
gain: Controls the measure that the agent has to produce. (def: 1, int)
prod_margin: controls the interval around the ground truth production
time within which the agent receives proportional reward
"""
metadata = {
'paper_link': 'https://www.sciencedirect.com/science/article/pii/' +
'S0896627318304185',
'paper_name': '''Flexible Sensorimotor Computations through Rapid
Reconfiguration of Cortical Dynamics''',
'tags': ['timing', 'go-no-go', 'supervised']
}
def __init__(self, dt=80, rewards=None, timing=None, gain=1,
prod_margin=0.2):
super().__init__(dt=dt)
self.prod_margin = prod_margin
self.gain = gain
# Rewards
self.rewards = {'abort': -0.1, 'correct': +1., 'fail': 0.}
if rewards:
self.rewards.update(rewards)
self.timing = {
'fixation': 100,
'ready': 83,
'measure': lambda: self.rng.uniform(800, 1500),
'set': 83}
if timing:
self.timing.update(timing)
self.abort = False
# set action and observation space
name = {'fixation': 0, 'ready': 1, 'set': 2}
self.observation_space = spaces.Box(-np.inf, np.inf, shape=(3,),
dtype=np.float32, name=name)
name = {'fixation': 0, 'go': 1}
self.action_space = spaces.Discrete(2, name=name) # (fixate, go)
def _new_trial(self, **kwargs):
measure = self.sample_time('measure')
trial = {
'measure': measure,
'gain': self.gain
}
trial.update(kwargs)
trial['production'] = measure * trial['gain']
self.add_period(['fixation', 'ready'])
self.add_period('measure', duration=measure, after='fixation')
self.add_period('set', after='measure')
self.add_period('production', duration=2*trial['production'],
after='set')
self.add_ob(1, where='fixation')
self.set_ob(0, 'production', where='fixation')
self.add_ob(1, 'ready', where='ready')
self.add_ob(1, 'set', where='set')
# set ground truth
gt = np.zeros((int(2*trial['production']/self.dt),))
gt[int(trial['production']/self.dt)] = 1
self.set_groundtruth(gt, 'production')
return trial
def _step(self, action):
trial = self.trial
reward = 0
ob = self.ob_now
gt = self.gt_now
new_trial = False
if self.in_period('fixation'):
if action != 0:
new_trial = self.abort
reward = self.rewards['abort']
if self.in_period('production'):
if action == 1:
new_trial = True # terminate
# time from end of measure:
t_prod = self.t - self.end_t['measure']
eps = abs(t_prod - trial['production'])
# actual production time
eps_threshold = self.prod_margin*trial['production']+25
if eps > eps_threshold:
reward = self.rewards['fail']
else:
reward = (1. - eps/eps_threshold)**1.5
reward = max(reward, 0.1)
reward *= self.rewards['correct']
self.performance = 1
return ob, reward, False, {'new_trial': new_trial, 'gt': gt}
[docs]
class MotorTiming(ngym.TrialEnv):
"""Agents have to produce different time intervals
using different effectors (actions).
Args:
prod_margin: controls the interval around the ground truth production
time within which the agent receives proportional reward
"""
# TODO: different actions not implemented
metadata = {
'paper_link': 'https://www.nature.com/articles/s41593-017-0028-6',
'paper_name': '''Flexible timing by temporal scaling of
cortical responses''',
'tags': ['timing', 'go-no-go', 'supervised']
}
def __init__(self, dt=80, rewards=None, timing=None, prod_margin=0.2):
super().__init__(dt=dt)
self.prod_margin = prod_margin
self.production_ind = [0, 1]
self.intervals = [800, 1500]
# Rewards
self.rewards = {'abort': -0.1, 'correct': +1., 'fail': 0.}
if rewards:
self.rewards.update(rewards)
self.timing = {
'fixation': 500, # XXX: not specified
'cue': lambda: self.rng.uniform(1000, 3000),
'set': 50}
if timing:
self.timing.update(timing)
self.abort = False
# set action and observation space
self.action_space = spaces.Discrete(2) # (fixate, go)
# Fixation, Interval indicator x2, Set
self.observation_space = spaces.Box(-np.inf, np.inf, shape=(4,),
dtype=np.float32)
def _new_trial(self, **kwargs):
trial = {
'production_ind': self.rng.choice(self.production_ind)
}
trial.update(kwargs)
trial['production'] = self.intervals[trial['production_ind']]
self.add_period(['fixation', 'cue', 'set'])
self.add_period('production', duration=2*trial['production'],
after='set')
self.set_ob([1, 0, 0, 0], 'fixation')
ob = self.view_ob('cue')
ob[:, 0] = 1
ob[:, trial['production_ind']+1] = 1
ob = self.view_ob('set')
ob[:, 0] = 1
ob[:, trial['production_ind'] + 1] = 1
ob[:, 3] = 1
# set ground truth
gt = np.zeros((int(2*trial['production']/self.dt),))
gt[int(trial['production']/self.dt)] = 1
self.set_groundtruth(gt, 'production')
return trial
def _step(self, action):
# ---------------------------------------------------------------------
# Reward and inputs
# ---------------------------------------------------------------------
trial = self.trial
reward = 0
ob = self.ob_now
gt = self.gt_now
new_trial = False
if self.in_period('fixation'):
if action != 0:
new_trial = self.abort
reward = self.rewards['abort']
if self.in_period('production'):
if action == 1:
new_trial = True # terminate
t_prod = self.t - self.end_t['set'] # time from end of measure
eps = abs(t_prod - trial['production'])
# actual production time
eps_threshold = self.prod_margin*trial['production']+25
if eps > eps_threshold:
reward = self.rewards['fail']
else:
reward = (1. - eps/eps_threshold)**1.5
reward = max(reward, 0.1)
reward *= self.rewards['correct']
self.performance = 1
return ob, reward, False, {'new_trial': new_trial, 'gt': gt}
[docs]
class OneTwoThreeGo(ngym.TrialEnv):
r"""Agents reproduce time intervals based on two samples.
Args:
prod_margin: controls the interval around the ground truth production
time within which the agent receives proportional reward
"""
metadata = {
'paper_link': 'https://www.nature.com/articles/s41593-019-0500-6',
'paper_name': "Internal models of sensorimotor integration "
"regulate cortical dynamics",
'tags': ['timing', 'go-no-go', 'supervised']
}
def __init__(self, dt=80, rewards=None, timing=None, prod_margin=0.2):
super().__init__(dt=dt)
self.prod_margin = prod_margin
# Rewards
self.rewards = {'abort': -0.1, 'correct': +1., 'fail': 0.}
if rewards:
self.rewards.update(rewards)
self.timing = {
'fixation': ngym.random.TruncExp(400, 100, 800),
'target': ngym.random.TruncExp(1000, 500, 1500),
's1': 100,
'interval1': (600, 700, 800, 900, 1000),
's2': 100,
'interval2': 0,
's3': 100,
'interval3': 0,
'response': 1000}
if timing:
self.timing.update(timing)
self.abort = False
# set action and observation space
name = {'fixation': 0, 'stimulus': 1, 'target': 2}
self.observation_space = spaces.Box(-np.inf, np.inf, shape=(3,),
dtype=np.float32, name=name)
name = {'fixation': 0, 'go': 1}
self.action_space = spaces.Discrete(2, name=name)
def _new_trial(self, **kwargs):
interval = self.sample_time('interval1')
trial = {
'interval': interval,
}
trial.update(kwargs)
self.add_period(['fixation', 'target', 's1'])
self.add_period('interval1', duration=interval, after='s1')
self.add_period('s2', after='interval1')
self.add_period('interval2', duration=interval, after='s2')
self.add_period('s3', after='interval2')
self.add_period('interval3', duration=interval, after='s3')
self.add_period('response', after='interval3')
self.add_ob(1, where='fixation')
self.add_ob(1, ['s1', 's2', 's3'], where='stimulus')
self.add_ob(1, where='target')
self.set_ob(0, 'fixation', where='target')
# set ground truth
self.set_groundtruth(1, period='response')
return trial
def _step(self, action):
# ---------------------------------------------------------------------
# Reward and inputs
# ---------------------------------------------------------------------
trial = self.trial
reward = 0
ob = self.ob_now
gt = self.gt_now
new_trial = False
if self.in_period('interval3') or self.in_period('response'):
if action == 1:
new_trial = True # terminate
# time from end of measure:
t_prod = self.t - self.end_t['s3']
eps = abs(t_prod - trial['interval']) # actual production time
eps_threshold = self.prod_margin*trial['interval']+25
if eps > eps_threshold:
reward = self.rewards['fail']
else:
reward = (1. - eps/eps_threshold)**1.5
reward = max(reward, 0.1)
reward *= self.rewards['correct']
self.performance = 1
else:
if action != 0:
new_trial = self.abort
reward = self.rewards['abort']
return ob, reward, False, {'new_trial': new_trial, 'gt': gt}