Simple demo notebook
Exploring NeuroGym tasks¶
NeuroGym is a comprehensive toolkit for training network models on many established neuroscience tasks using reinforcement learning techniques. It includes working-memory tasks, value-based decision tasks, and context-dependent perceptual categorization tasks.
In this notebook we first show how to install the relevant toolbox.
We then show how to access the available tasks and their relevant information.
Finally, we train a small feedforward network on the Go-Nogo task using the A2C algorithm (Mnih et al., 2016) implemented in the stable-baselines3 toolbox, and plot the results.
You can easily change the code to train a network on any other available task, or with a different algorithm (e.g. PPO); a sketch of such a swap appears after the training section below.
Installation¶
Google Colab: Uncomment and execute the cell below when running this notebook on Google Colab.
Local: Follow these instructions when running this notebook locally.
# ! pip install neurogym[rl]
Explore tasks¶
import warnings
warnings.filterwarnings('ignore')
import gymnasium as gym
import neurogym as ngym
from neurogym.utils import info, plotting
info.all_tasks()
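Each task can also be inspected individually. The call below is a sketch that assumes the info.info helper (the counterpart of the info.info_wrapper call used further down); it prints the task's description, timing, and, optionally, its source code. If the helper name differs in your NeuroGym version, check the documentation.
# Sketch: inspect a single task in detail (assumes the `info.info` helper;
# the exact name may differ between NeuroGym versions).
info.info('GoNogo-v0', show_code=True)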
Visualize a single task¶
task = 'GoNogo-v0'
env = gym.make(task)
print(env)
fig = plotting.plot_env(
    env,
    num_steps=100,
    # def_act=0,
    ob_traces=['Fixation cue', 'NoGo', 'Go'],
    # fig_kwargs={'figsize': (12, 12)}
)
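Beyond the plotting helper, you can interact with the task directly through the standard Gymnasium interface that NeuroGym environments expose. The loop below is a minimal sketch: it resets the environment, takes a few random actions, and prints the reward at each step (the variable names are illustrative).
# Minimal sketch: step through the task with random actions using the
# standard Gymnasium reset/step interface.
obs, _ = env.reset()
print('Observation space:', env.observation_space)
print('Action space:', env.action_space)
for t in range(5):
    action = env.action_space.sample()  # random action
    obs, reward, terminated, truncated, step_info = env.step(action)
    print(f'step {t}: action={action}, reward={reward}')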
Explore wrappers¶
info.all_wrappers()
info.info_wrapper('TrialHistoryV2-v0', show_code=True)
Train a network¶
Here, we train a simple neural network on the task at hand. We use a configuration file to load the parameters for the monitor. You can refer to the documentation for more information about how to use the configuration system.
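The exact schema of config.toml is defined by NeuroGym's configuration system, so consult the documentation for the authoritative list of keys. Purely as a hypothetical illustration, the cell below writes a minimal file whose single entry mirrors the env.config.agent.training.value attribute read by the training code further down; the key names and the timestep count are assumptions, not the documented schema.
# Hypothetical sketch only: write a minimal config file for the monitor.
# The [agent.training] key and its value are inferred from how the config
# is read below (env.config.agent.training.value); check the NeuroGym
# documentation for the actual schema.
config_text = """
[agent.training]
value = 100000
"""
with open("config.toml", "w") as f:
    f.write(config_text)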
import warnings
warnings.filterwarnings('ignore')
import numpy as np
from neurogym.wrappers import monitor, TrialHistoryV2
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import A2C  # or e.g. PPO
# task parameters
timing = {'fixation': ('constant', 300),
          'stimulus': ('constant', 700),
          'decision': ('constant', 300)}
kwargs = {'dt': 100, 'timing': timing}
# wrapper parameters
n_ch = 2
p = 0.8
num_blocks = 2
probs = np.array([[p, 1-p], [1-p, p]]) # repeating block
# Build the task
env = gym.make(task, **kwargs)
# Apply the wrapper.
env = TrialHistoryV2(env, probs=probs)
env = monitor.Monitor(env, config="config.toml")
# The environment is automatically wrapped in a vectorized env when passed to the constructor
model = A2C("MlpPolicy", env, verbose=1, policy_kwargs={'net_arch': [64, 64]})
model.learn(total_timesteps=env.config.agent.training.value)
env.close()
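As noted in the introduction, the same setup can be trained with a different stable-baselines3 algorithm. The cell below is a minimal sketch of swapping A2C for PPO on a freshly built environment; the variable names and the total_timesteps value are illustrative.
# Sketch: train the same task with PPO instead of A2C.
from stable_baselines3 import PPO

env_ppo = gym.make(task, **kwargs)
env_ppo = TrialHistoryV2(env_ppo, probs=probs)
model_ppo = PPO("MlpPolicy", env_ppo, verbose=0, policy_kwargs={'net_arch': [64, 64]})
model_ppo.learn(total_timesteps=20_000)  # illustrative training budget
env_ppo.close()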
Visualize the results¶
import numpy as np
import matplotlib.pyplot as plt
# Create task
env = gym.make(task, **kwargs)
# Apply the wrapper
env = TrialHistoryV2(env, probs=probs)
env = DummyVecEnv([lambda: env])
fig = plotting.plot_env(
    env,
    num_steps=100,
    # def_act=0,
    ob_traces=['Fixation cue', 'NoGo', 'Go'],
    # fig_kwargs={'figsize': (12, 12)},
    model=model,
)
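To quantify performance rather than only visualize it, you can roll the trained model out on the vectorized environment and accumulate the reward it collects. The cell below is a minimal sketch using stable-baselines3's predict/step interface on the DummyVecEnv created above; the number of steps is arbitrary.
# Sketch: accumulate reward obtained by the trained model over a fixed
# number of steps on the vectorized environment (it auto-resets between trials).
obs = env.reset()
total_reward = 0.0
n_steps = 1000
for _ in range(n_steps):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
    total_reward += rewards[0]  # single environment inside the DummyVecEnv
print(f'Average reward per step over {n_steps} steps: {total_reward / n_steps:.3f}')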