NeuroGym with RL
Reinforcement learning example with stable-baselines3¶
NeuroGym is a toolkit that allows training any network model on many established neuroscience tasks using techniques such as standard Supervised Learning or Reinforcement Learning (RL). In this notebook we use RL to train a network on the classical Random Dots Motion (RDM) task (Britten et al. 1992).
We first show how to install the relevant toolboxes. We then show how to build the task of interest (here, the RDM task), wrap it with the pass-reward wrapper in one line, and visualize the structure of the final task. Finally, we train a network on the task using the A2C algorithm (Mnih et al. 2016) implemented in the stable-baselines3 toolbox, and plot the results. Note that stable-baselines3 itself only ships feedforward policies; training a recurrent LSTM policy, as in earlier versions of this example, requires RecurrentPPO from the sb3-contrib package and is sketched in the training section.
It is straightforward to change the code to train a network on any other available task or with a different RL algorithm (e.g. PPO, DQN).
Installation¶
Google Colab: Uncomment and execute the cell below when running this notebook on Google Colab.
Local: Follow these instructions when running this notebook locally.
# ! pip install neurogym[rl]
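After installing, a quick sanity check is to import the main dependencies. The snippet below is only a convenience check; it assumes the packages expose a __version__ attribute (standard for gymnasium and stable-baselines3).
# Sanity check: these imports should succeed once the installation is complete
import gymnasium
import stable_baselines3
import neurogym
print('gymnasium', gymnasium.__version__)
print('stable-baselines3', stable_baselines3.__version__)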
Task¶
Here we build the Random Dots Motion task, specifying the duration of each trial period (fixation, stimulus, decision), and wrap it with the pass-reward wrapper, which appends the previous reward to the observation. We then plot the structure of the task in a figure that shows:
- The observations received by the agent (top panel).
- The actions taken by a random agent and the correct action at each timestep (second panel).
- The rewards provided by the environment at each timestep (third panel).
- The performance of the agent at each trial (bottom panel).
import gymnasium as gym
import neurogym as ngym
from neurogym.wrappers import pass_reward
import warnings
warnings.filterwarnings('ignore')
# Task name
name = 'PerceptualDecisionMaking-v0'
# task specification (here we only specify the duration of the different trial periods)
timing = {
    'fixation': ('constant', 300),
    'stimulus': ('constant', 500),
    'decision': ('constant', 300),
}
kwargs = {'dt': 100, 'timing': timing}
# build task
env = gym.make(name, **kwargs)
# print task properties
print(env)
# wrap task with pass-reward wrapper
env = pass_reward.PassReward(env)
# plot example trials with random agent
data = ngym.utils.plot_env(
    env,
    fig_kwargs={'figsize': (12, 12)},
    num_steps=100,
    ob_traces=['Fixation cue', 'Stim 1', 'Stim 2', 'Previous reward'],
)
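To see what the PassReward wrapper does to the inputs, it can help to compare the observation space before and after wrapping. The snippet below is a small illustrative check (the exact shape depends on the task configuration); the wrapper should add one extra entry, the previous reward, to each observation.
# Compare observation shapes with and without the PassReward wrapper
raw_env = gym.make(name, **kwargs)
print('observation shape without wrapper:', raw_env.observation_space.shape)
wrapped_env = pass_reward.PassReward(raw_env)
print('observation shape with wrapper:', wrapped_env.observation_space.shape)  # one extra entry: previous reward
ob, _ = wrapped_env.reset()
print('example observation:', ob)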
Train a network¶
import warnings
from stable_baselines3 import A2C  # other options in stable-baselines3: PPO, DQN, ...
from stable_baselines3.common.vec_env import DummyVecEnv
warnings.filterwarnings('default')
# Note: the recurrent LstmPolicy of the original stable-baselines is not available in
# stable-baselines3, so here we train a feedforward (MLP) policy with A2C; an LSTM policy
# can be trained with RecurrentPPO from sb3-contrib (see the sketch after this cell).
# stable-baselines3 would wrap the environment automatically when passing it to the
# constructor, but we wrap it explicitly in a DummyVecEnv here.
env = DummyVecEnv([lambda: env])
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100000, log_interval=1000)
env.close()
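If you want the policy itself to be recurrent, as in the original version of this example, one option is RecurrentPPO from the separate sb3-contrib package. The sketch below is an alternative to the A2C cell above, not part of it; it assumes sb3-contrib is installed (pip install sb3-contrib) and uses default network sizes.
# Sketch: training an LSTM policy with RecurrentPPO (requires the sb3-contrib package)
from sb3_contrib import RecurrentPPO

lstm_task = gym.make(name, **kwargs)
lstm_task = pass_reward.PassReward(lstm_task)
lstm_vec_env = DummyVecEnv([lambda: lstm_task])
lstm_model = RecurrentPPO("MlpLstmPolicy", lstm_vec_env, verbose=1)
lstm_model.learn(total_timesteps=100000, log_interval=100)
lstm_vec_env.close()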
Visualize results¶
env = gym.make(name, **kwargs)
# print task properties
print(env)
# wrap task with pass-reward wrapper
env = pass_reward.PassReward(env)
env = DummyVecEnv([lambda: env])
# plot example trials with the trained agent
data = ngym.utils.plot_env(
    env,
    fig_kwargs={'figsize': (12, 12)},
    num_steps=100,
    ob_traces=['Fixation cue', 'Stim 1', 'Stim 2', 'Previous reward'],
    model=model,
)
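Beyond the example trials above, a rough quantitative check is to roll the trained policy out for a fixed number of steps and look at the average reward per step. This is only a sketch; the rollout length is an arbitrary choice and a proper evaluation would average over many trials.
# Roll out the trained policy and report the average reward per step
obs = env.reset()
n_steps = 1000
total_reward = 0.0
for _ in range(n_steps):
    action, _states = model.predict(obs)
    obs, rewards, dones, infos = env.step(action)
    total_reward += rewards[0]  # DummyVecEnv returns one reward per sub-environment
print('average reward per step:', total_reward / n_steps)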