NeuroGym with Keras
NeuroGym with Supervised Learning (Keras)
NeuroGym is a comprehensive toolkit that allows training any network model on many established neuroscience tasks using Reinforcement Learning techniques. It includes working memory tasks, value-based decision tasks and context-dependent perceptual categorization tasks.
In this notebook we first show how to install the relevant toolbox.
We then show how to load a task and access its relevant information, such as the sizes of its observation and action spaces.
Finally, we train an LSTM network on the Random Dots Motion task (PerceptualDecisionMaking-v0 in NeuroGym) using standard supervised learning techniques with Keras, and plot the results.
Installation
Google Colab: Uncomment and execute the cell below when running this notebook on Google Colab.
Local: Follow these instructions and then run
pip install tensorflow
when running this notebook locally.
NOTE: TensorFlow is pre-installed in Google Colab, but it is not installed together with the neurogym library by default.
# ! pip install neurogym
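Optionally, you can confirm that the required packages are importable in the current kernel before continuing. The sketch below uses only Python's standard-library importlib.util.find_spec, which looks up a top-level module without importing it. The helper name missing_packages is hypothetical and not part of NeuroGym.

```python
import importlib.util


def missing_packages(names):
    """Return the subset of `names` that the import system cannot find."""
    # find_spec returns None when a top-level module is not installed
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(["neurogym", "tensorflow", "keras"])
print("Missing packages:", missing or "none")
```

If anything is listed, install it (for example with pip) and restart the kernel before running the remaining cells.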
Import libraries
import warnings
from IPython.display import clear_output
import numpy as np
import neurogym as ngym
from neurogym.utils import plotting
# Note: some editors may warn that the Keras imports below cannot be resolved (missing imports).
# If you have installed the packages as instructed above, these imports work nonetheless.
from keras.models import Model
from keras.layers import Dense, LSTM, TimeDistributed, Input
clear_output()
warnings.filterwarnings('ignore')
Task, network, and training
# This setting is low to speed up testing; we recommend setting it to at least 2000
steps_per_epoch = 100
# Environment
task = 'PerceptualDecisionMaking-v0'
kwargs = {'dt': 100}  # simulation time step in ms
seq_len = 100  # number of time steps per training sequence
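# Make supervised dataset; batch_first=True yields arrays shaped (batch, time, ...),
# which is the layout Keras recurrent layers expect
dataset = ngym.Dataset(task, env_kwargs=kwargs, batch_size=16, seq_len=seq_len, batch_first=True)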
env = dataset.env
obs_size = env.observation_space.shape[0]
act_size = env.action_space.n
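By default (batch_first=False), NeuroGym's Dataset returns time-major arrays: inputs shaped (seq_len, batch_size, obs_size) and targets shaped (seq_len, batch_size). Keras recurrent layers instead expect batch-major input, (batch_size, seq_len, features). Either set batch_first=True when building the Dataset or convert the arrays yourself. The NumPy sketch below illustrates the conversion on random placeholder arrays; the sizes are illustrative and not read from the task. np.swapaxes returns a view, so no data is copied.

```python
import numpy as np

seq_len, batch_size, obs_size = 100, 16, 3  # illustrative sizes only

# Time-major batch, the layout of a Dataset built with batch_first=False
inputs_tm = np.random.rand(seq_len, batch_size, obs_size).astype("float32")
target_tm = np.random.randint(0, 3, size=(seq_len, batch_size))

# Batch-major view expected by keras.layers.LSTM: (batch, time, features)
inputs_bm = np.swapaxes(inputs_tm, 0, 1)
target_bm = np.swapaxes(target_tm, 0, 1)

print(inputs_bm.shape, target_bm.shape)  # (16, 100, 3) (16, 100)

# Sequence b at time t is the same element in both layouts
assert inputs_bm[5, 42, 1] == inputs_tm[42, 5, 1]
```

Feeding time-major arrays to a batch-major model does not raise an error. Keras silently treats the time axis as the batch axis, which is why the layout is worth checking.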
# Model
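num_h = 64  # number of LSTM hidden units
# from https://www.tensorflow.org/guide/keras/rnn
# Input shape is (batch, time, obs_size); batch size and sequence length are left unspecified
xin = Input(batch_shape=(None, None, obs_size), dtype='float32')
# return_sequences=True makes the LSTM output its hidden state at every time step
seq = LSTM(num_h, return_sequences=True)(xin)
# Apply the same softmax readout over actions at every time step
mlp = TimeDistributed(Dense(act_size, activation='softmax'))(seq)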
model = Model(inputs=xin, outputs=mlp)
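TimeDistributed(Dense(act_size, activation='softmax')) applies one shared dense readout independently at every time step of the LSTM output. The NumPy sketch below spells this out with random placeholder hidden states and hypothetical readout weights W and b (not the trained model's values). Looping over time steps and applying the readout to the last axis in a single matrix product give the same class probabilities.

```python
import numpy as np


def softmax(z, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
batch, time, num_h, act_size = 4, 10, 64, 3
h = rng.standard_normal((batch, time, num_h))  # stand-in for LSTM outputs
W = rng.standard_normal((num_h, act_size))     # hypothetical readout weights
b = rng.standard_normal(act_size)              # hypothetical readout bias

# Explicit loop: the same Dense readout applied at each time step
probs_loop = np.stack([softmax(h[:, t] @ W + b) for t in range(time)], axis=1)

# Vectorized: matmul over the last axis handles all time steps at once
probs_vec = softmax(h @ W + b)

print(probs_vec.shape)                     # (4, 10, 3)
print(np.allclose(probs_loop, probs_vec))  # True
```

In Keras, a plain Dense layer applied to a 3-D tensor also acts on the last axis. TimeDistributed mainly makes the per-step intent explicit.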
model.summary()
model.compile(optimizer='Adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
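With loss='sparse_categorical_crossentropy', the targets are integer action labels, one per time step, rather than one-hot vectors. The loss at each (batch, time) entry is the negative log-probability the network assigns to the correct label. The minimal NumPy sketch below averages that quantity over all batch and time entries. The labels and probabilities are hypothetical, and the function is an illustration, not Keras' implementation.

```python
import numpy as np


def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean negative log-likelihood of integer `labels` under `probs`.

    labels: int array of shape (batch, time)
    probs:  float array of shape (batch, time, n_classes), rows summing to 1
    """
    probs = np.clip(probs, eps, 1.0)  # avoid log(0)
    # Pick the probability of the correct class at every (batch, time) entry
    p_correct = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]
    return float(-np.log(p_correct).mean())


labels = np.array([[0, 1], [2, 2]])  # (batch=2, time=2)
probs = np.array([[[0.5, 0.25, 0.25], [0.1, 0.8, 0.1]],
                  [[0.2, 0.2, 0.6], [0.0, 0.0, 1.0]]])
print(round(sparse_categorical_crossentropy(labels, probs), 4))  # 0.3568
```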
# Train network
data_generator = (dataset() for _ in range(steps_per_epoch))
history = model.fit(data_generator, steps_per_epoch=steps_per_epoch)
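The generator expression above yields exactly steps_per_epoch batches. That is enough for the single epoch model.fit runs by default (epochs=1). If you pass epochs greater than 1, a finite generator runs out of data after the first epoch. One workaround is an endless generator that keeps calling the dataset. The sketch below uses a hypothetical stand-in, fake_dataset, in place of ngym.Dataset so that it runs without NeuroGym.

```python
import itertools


def endless_batches(dataset):
    """Yield batches forever by repeatedly calling a Dataset-like callable."""
    while True:
        yield dataset()


def fake_dataset():
    # Hypothetical stand-in for ngym.Dataset: returns an (inputs, targets) pair
    return ("inputs", "targets")


steps_per_epoch, epochs = 3, 2

finite = (fake_dataset() for _ in range(steps_per_epoch))
print(len(list(finite)), len(list(finite)))  # 3 0  (exhausted after one pass)

endless = endless_batches(fake_dataset)
print(len(list(itertools.islice(endless, steps_per_epoch * epochs))))  # 6
```

With the real dataset, this would correspond to something like model.fit(endless_batches(dataset), steps_per_epoch=steps_per_epoch, epochs=...), where steps_per_epoch tells Keras where each epoch ends.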
Analysis
# n_trials is set to a low number to speed up testing; we recommend setting it to at least 200
n_trials = 50
perf = 0
for i in range(n_trials):
    env.new_trial()
    obs, gt = env.ob, env.gt  # obs: (time, obs_size), gt: (time,)
    obs = obs[np.newaxis, :, :]  # add a batch dimension: (1, time, obs_size)
    action_pred = model.predict(obs, verbose=0)  # (1, time, act_size)
    action_pred = np.argmax(action_pred, axis=-1)  # (1, time)
    perf += gt[-1] == action_pred[0, -1]  # score the choice at the last time step
    if (i + 1) % 10 == 0:
        print(f"Completed trial {i+1}/{n_trials}")
perf /= n_trials
print(f"Performance: {perf} after {n_trials} trials")
obs = np.squeeze(obs, axis=0)  # remove the batch dimension for plotting
action_pred = np.squeeze(action_pred, axis=0)  # remove the batch dimension for plotting
_ = plotting.fig_(obs, action_pred, gt)
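The evaluation loop above counts a trial as correct when the network's choice at the last time step matches the ground truth at that step. Once the last-step labels and probabilities are collected, the same metric can be computed without a loop. The sketch below assumes hypothetical arrays final_gt and final_probs with one row per trial. It uses NumPy only and does not call the trained model.

```python
import numpy as np


def final_step_accuracy(final_gt, final_probs):
    """Fraction of trials whose argmax choice at the last step equals the label.

    final_gt:    int array, shape (n_trials,)
    final_probs: float array, shape (n_trials, n_actions)
    """
    choices = np.argmax(final_probs, axis=-1)
    return float(np.mean(choices == final_gt))


final_gt = np.array([1, 2, 1, 2])
final_probs = np.array([[0.1, 0.7, 0.2],   # choice 1, correct
                        [0.2, 0.5, 0.3],   # choice 1, wrong
                        [0.0, 0.4, 0.6],   # choice 2, wrong
                        [0.1, 0.1, 0.8]])  # choice 2, correct
print(final_step_accuracy(final_gt, final_probs))  # 0.5
```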