NeuroGym with Reinforcement Learning (stable-baselines3)¶
NeuroGym is a toolkit that allows training any network model on many established neuroscience tasks using techniques such as standard Supervised Learning or Reinforcement Learning (RL). In this notebook we will use RL to train an LSTM network on the classical Random Dots Motion (RDM) task (Britten et al. 1992).
We first show how to install the relevant toolboxes. We then show how to build the task of interest (in this example, the RDM task), wrap it with the pass-reward wrapper in one line, and visualize the structure of the final task. Finally, we train an LSTM network on the task using the recurrent PPO algorithm (Schulman et al. 2017) implemented in the stable-baselines3 / sb3_contrib toolbox, and plot the results.
It is straightforward to change the code to train a network on any other available task or with a different RL algorithm (e.g. A2C or DQN).
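The sketch below (not part of the original notebook) shows one way to list the task names that NeuroGym registers with gymnasium, so you can pick a different task to plug into the code that follows; filtering the registry by the "neurogym" entry-point prefix is an assumption about how the package registers its environments.
# Illustrative sketch: list the environments that NeuroGym registers with gymnasium.
# Filtering on the 'neurogym' entry-point prefix is an assumption about how the
# package registers its tasks.
import gymnasium as gym
import neurogym as ngym  # importing neurogym registers its tasks with gymnasium

neurogym_ids = [
    env_id
    for env_id, spec in gym.envs.registry.items()
    if isinstance(spec.entry_point, str) and spec.entry_point.startswith('neurogym')
]
print(neurogym_ids)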
Installation¶
Google Colab: Uncomment and execute the cell below when running this notebook on Google Colab.
Local: Follow these instructions when running this notebook locally.
# ! pip install neurogym[rl]
Import libraries¶
import warnings

import gymnasium as gym
import neurogym as ngym
from IPython.display import clear_output
from neurogym.wrappers import pass_reward

warnings.filterwarnings('ignore')
clear_output()
Task¶
Here we build the Random Dots Motion task, specifying the duration of each trial period (fixation, stimulus, decision), and wrap it with the pass-reward wrapper, which appends the previous reward to the observation. We then plot the structure of the task in a figure that shows:
- The observations received by the agent (top panel).
- The actions taken by a random agent and the correct action at each timestep (second panel).
- The rewards provided by the environment at each timestep (third panel).
- The performance of the agent at each trial (bottom panel).
# Task name
name = 'PerceptualDecisionMaking-v0'
# task specification (here we only specify the duration of the different trial periods)
timing = {
'fixation': ('constant', 300),
'stimulus': ('constant', 500),
'decision': ('constant', 300),
}
kwargs = {'dt': 100, 'timing': timing}
# build task
env = gym.make(name, **kwargs)
# print task properties
print(env)
# wrap task with pass-reward wrapper
env = pass_reward.PassReward(env)
# plot example trials with random agent
_ = ngym.utils.plot_env(
env,
fig_kwargs={'figsize': (12, 12)},
num_steps=100,
ob_traces=['Fixation cue', 'Stim 1', 'Stim 2', 'Previous reward'],
)
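As a quick, illustrative check (not part of the original notebook), comparing the observation space before and after wrapping should show that PassReward extends each observation by one entry holding the previous reward; the exact shapes depend on the task.
# Illustrative check (assumed behavior): PassReward is expected to append one
# entry, the previous reward, to the observation vector.
raw_env = gym.make(name, **kwargs)
wrapped_env = pass_reward.PassReward(raw_env)
print('raw observation shape:    ', raw_env.observation_space.shape)
print('wrapped observation shape:', wrapped_env.observation_space.shape)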
Train a network¶
# these values are set low for testing purposes. To get a better sense of the package, we recommend setting
# `total_timesteps = 100_000`
total_timesteps = 500
log_interval = 500
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env import DummyVecEnv
# wrap the task in a vectorized environment, as required by stable-baselines3
env = DummyVecEnv([lambda: env])
model = RecurrentPPO(
policy="MlpLstmPolicy",
env=env,
verbose=1,
)
model.learn(total_timesteps=total_timesteps, log_interval=log_interval)
env.close()
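As mentioned in the introduction, other RL algorithms can be swapped in with minimal changes. The sketch below (not part of the original notebook) trains the same wrapped task with A2C from stable-baselines3 and a feed-forward policy; recurrent (LSTM) policies are only available through sb3_contrib's RecurrentPPO.
# Illustrative sketch: train the same task with A2C instead of RecurrentPPO.
# 'MlpPolicy' is a feed-forward policy; LSTM policies require sb3_contrib.
from stable_baselines3 import A2C

a2c_env = DummyVecEnv([lambda: pass_reward.PassReward(gym.make(name, **kwargs))])
a2c_model = A2C(policy='MlpPolicy', env=a2c_env, verbose=1)
a2c_model.learn(total_timesteps=total_timesteps, log_interval=log_interval)
a2c_env.close()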
Visualize results¶
env = gym.make(name, **kwargs)
# print task properties
print(env)
# wrap task with pass-reward wrapper
env = pass_reward.PassReward(env)
env = DummyVecEnv([lambda: env])
# plot example trials with the trained agent
_ = ngym.utils.plot_env(
env,
fig_kwargs={'figsize': (12, 12)},
num_steps=100,
ob_traces=['Fixation cue', 'Stim 1', 'Stim 2', 'Previous reward'],
model=model,
)
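To quantify performance beyond the example trials plotted above, the minimal sketch below (not part of the original notebook) rolls out the trained policy and reports the mean reward per step; for a recurrent policy the LSTM state and the episode-start flag must be passed back into model.predict at every step, and the number of evaluation steps here is arbitrary.
# Illustrative evaluation sketch: roll out the trained recurrent policy and
# report the mean reward per step.
import numpy as np

eval_env = DummyVecEnv([lambda: pass_reward.PassReward(gym.make(name, **kwargs))])
obs = eval_env.reset()
lstm_states = None
episode_starts = np.ones((eval_env.num_envs,), dtype=bool)
step_rewards = []
for _ in range(1000):  # arbitrary number of evaluation steps
    action, lstm_states = model.predict(
        obs, state=lstm_states, episode_start=episode_starts, deterministic=True
    )
    obs, rewards, dones, infos = eval_env.step(action)
    episode_starts = dones
    step_rewards.append(rewards[0])
print(f'Mean reward per step: {np.mean(step_rewards):.3f}')
eval_env.close()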