"""Random dot motion task."""
import numpy as np
import neurogym as ngym
from neurogym import spaces
[docs]
class ProbabilisticReasoning(ngym.TrialEnv):
"""Probabilistic reasoning.
The agent is shown a sequence of stimuli. Each stimulus is associated
with a certain log-likelihood of the correct response being one choice
versus the other. The final log-likelihood of the target response being,
for example, option 1, is the sum of all log-likelihood associated with
the presented stimuli. A delay period separates each stimulus, so the
agent is encouraged to lean the log-likelihood association and integrate
these values over time within a trial.
Args:
shape_weight: array-like, evidence weight of each shape
n_loc: int, number of location of show shapes
"""
metadata = {
'paper_link': 'https://www.nature.com/articles/nature05852',
'paper_name': 'Probabilistic reasoning by neurons',
'tags': ['perceptual', 'two-alternative', 'supervised']
}
def __init__(self, dt=100, rewards=None, timing=None, shape_weight=None,
n_loc=4):
super().__init__(dt=dt)
# The evidence weight of each stimulus
if shape_weight is not None:
self.shape_weight = shape_weight
else:
self.shape_weight = [-10, -0.9, -0.7, -0.5, -0.3,
0.3, 0.5, 0.7, 0.9, 10]
self.n_shape = len(self.shape_weight)
dim_shape = self.n_shape
# Shape representation needs to be fixed cross-platform
self.shapes = np.eye(self.n_shape, dim_shape)
self.n_loc = n_loc
# Rewards
self.rewards = {'abort': -0.1, 'correct': +1., 'fail': 0.}
if rewards:
self.rewards.update(rewards)
self.timing = {'fixation': 500,
'delay': lambda: self.rng.uniform(450, 550),
'decision': 500
}
for i_loc in range(n_loc):
self.timing['stimulus'+str(i_loc)] = 500
if timing:
self.timing.update(timing)
self.abort = False
name = {'fixation': 0}
start = 1
for i_loc in range(n_loc):
name['loc' + str(i_loc)] = range(start, start + dim_shape)
start += dim_shape
self.observation_space = spaces.Box(
-np.inf, np.inf, shape=(1 + dim_shape*n_loc,),
dtype=np.float32, name=name)
name = {'fixation': 0, 'choice': [1, 2]}
self.action_space = spaces.Discrete(3, name=name)
def _new_trial(self, **kwargs):
# Trial info
trial = {
'locs': self.rng.choice(range(self.n_loc),
size=self.n_loc, replace=False),
'shapes': self.rng.choice(range(self.n_shape),
size=self.n_loc, replace=True),
}
trial.update(kwargs)
locs = trial['locs']
shapes = trial['shapes']
log_odd = sum([self.shape_weight[shape] for shape in shapes])
p = 1. / (10**(-log_odd) + 1.)
ground_truth = int(self.rng.rand() < p)
trial['log_odd'] = log_odd
trial['ground_truth'] = ground_truth
# Periods
periods = ['fixation']
periods += ['stimulus'+str(i) for i in range(self.n_loc)]
periods += ['delay', 'decision']
self.add_period(periods)
# Observations
self.add_ob(1, where='fixation')
self.set_ob(0, 'decision', where='fixation')
for i_loc in range(self.n_loc):
loc = locs[i_loc]
shape = shapes[i_loc]
periods = ['stimulus'+str(j) for j in range(i_loc, self.n_loc)]
self.add_ob(self.shapes[shape], periods, where='loc'+str(loc))
# Ground truth
self.set_groundtruth(ground_truth, period='decision', where='choice')
return trial
def _step(self, action):
new_trial = False
# rewards
reward = 0
gt = self.gt_now
# observations
if self.in_period('decision'):
if action != 0:
new_trial = True
if action == gt:
reward += self.rewards['correct']
self.performance = 1
else:
reward += self.rewards['fail']
else:
if action != 0: # action = 0 means fixating
new_trial = self.abort
reward += self.rewards['abort']
return self.ob_now, reward, False, {'new_trial': new_trial, 'gt': gt}