{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Understanding Neurogym Task\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neurogym/neurogym/blob/master/examples/understanding_neurogym_task.ipynb)\n", "\n", "This is a tutorial for understanding Neurogym task structure. Here we will go through:\n", "1. Defining a basic OpenAI gym task\n", "2. Defining a basic trial-based neurogym task\n", "3. Adding observation and ground truth in neurogym tasks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Installation\n", "\n", "Only needed if running in Google Colab. Uncomment to run." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# # Install gym\n", "# ! pip install gym\n", "\n", "# # Install neurogym\n", "# ! git clone https://github.com/gyyang/neurogym.git\n", "# %cd neurogym/\n", "# ! pip install -e ." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### OpenAI gym tasks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Neurogym tasks follow the basic [OpenAI gym](https://gym.openai.com/) task format. Each task is defined as a Python class, inheriting from the ```gym.Env``` class.\n", "\n", "In this section we describe the basic structure of an OpenAI gym task.\n", "\n", "In the ```__init__``` method, it is necessary to define two attributes, ```self.observation_space``` and ```self.action_space```, which describe the kinds of spaces used by observations (network inputs) and actions (network outputs)."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sample random observation value\n", "[0.28708524 0.2543813 ]\n", "Sample random action value\n", "1\n" ] } ], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\")  # to suppress warnings\n", "\n", "import numpy as np\n", "import gym\n", "\n", "class MyEnv(gym.Env):\n", "    def __init__(self):\n", "        super().__init__()  # Python boilerplate to initialize base class\n", "\n", "        # A two-dimensional box with minimum and maximum value set by low and high\n", "        self.observation_space = gym.spaces.Box(low=0., high=1., shape=(2,))\n", "\n", "        # A discrete space with 3 possible values (0, 1, 2)\n", "        self.action_space = gym.spaces.Discrete(3)\n", "\n", "# Instantiate an environment\n", "env = MyEnv()\n", "print('Sample random observation value')\n", "print(env.observation_space.sample())\n", "print('Sample random action value')\n", "print(env.action_space.sample())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another key method that needs to be defined is the ```step``` method, which updates the environment and outputs observations and rewards after receiving the agent's action.\n", "\n", "The ```step``` method takes ```action``` as input and outputs:\n", "1. the agent's next observation ```observation```,\n", "2. a scalar reward received by the agent ```reward```,\n", "3. a boolean describing whether the environment needs to be reset ```done```, and\n", "4. a dictionary holding any additional information ```info```.\n", "\n", "If the environment is described by internal states, the ```reset``` method resets these internal states. This method returns an initial observation ```observation```."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class MyEnv(gym.Env):\n", "    def __init__(self):\n", "        super().__init__()  # Python boilerplate to initialize base class\n", "        self.observation_space = gym.spaces.Box(low=-10., high=10., shape=(1,))\n", "        self.action_space = gym.spaces.Discrete(3)\n", "\n", "    def step(self, action):\n", "        ob = self.observation_space.sample()  # random sampling\n", "        reward = 1.  # reward\n", "        done = False  # never ending\n", "        info = {}  # empty dictionary\n", "        return ob, reward, done, info\n", "\n", "    def reset(self):\n", "        ob = self.observation_space.sample()\n", "        return ob" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below we define a simple task where actions move an agent along a one-dimensional line. The reward is determined by the agent's location on this line." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Reward')" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "def get_reward(x):\n", "    return np.sin(x) * np.exp(-np.abs(x)/3)\n", "\n", "xs = np.linspace(-10, 10, 100)\n", "plt.plot(xs, get_reward(xs))\n", "plt.xlabel('State value (observation)')\n", "plt.ylabel('Reward')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class MyEnv(gym.Env):\n", "    def __init__(self):\n", "        # A one-dimensional box with minimum and maximum value set by low and high\n", "        self.observation_space = gym.spaces.Box(low=-10., high=10., shape=(1,))\n", "\n", "        # A discrete space with 3 possible values (0, 1, 2)\n", "        self.action_space = gym.spaces.Discrete(3)\n", "\n", "        self.state = 0.\n", "\n", "    def step(self, action):\n", "        # Actions 0, 1, 2 correspond to state change of -0.1, 0, +0.1\n", "        self.state += (action - 1.) * 0.1\n", "        self.state = np.clip(self.state, -10, 10)\n", "\n", "        ob = self.state  # observation\n", "        reward = get_reward(self.state)  # reward\n", "        done = False  # never ending\n", "        info = {}  # empty dictionary\n", "        return ob, reward, done, info\n", "\n", "    def reset(self):\n", "        # Re-initialize state\n", "        self.state = self.observation_space.sample()\n", "        return self.state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An agent can interact with the environment iteratively." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "env = MyEnv()\n", "ob = env.reset()\n", "ob_log = list()\n", "reward_log = list()\n", "for i in range(1000):\n", "    action = env.action_space.sample()  # A random agent\n", "    ob, reward, done, info = env.step(action)\n", "    ob_log.append(ob)\n", "    reward_log.append(reward)\n", "\n", "plt.plot(ob_log, reward_log)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Trial-based Neurogym Tasks\n", "\n", "Many neuroscience and cognitive science tasks have a trial structure. ```neurogym.TrialEnv``` provides a class for common trial-based tasks. Its main difference from ```gym.Env``` is the ```_new_trial()``` method, which generates abstract information about a new trial and, optionally, the observation and ground-truth output. Additionally, users provide a ```_step()``` method instead of ```step()```.\n", "\n", "The ```_new_trial()``` method takes any keyword arguments (```**kwargs```) and outputs a dictionary ```trial``` containing relevant information about this trial. This dictionary is accessible during ```_step()``` as ```self.trial```.\n", "\n", "Here we define a simple task where the agent needs to make a binary decision on every trial based on its observation. Each trial is only one time step."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import neurogym as ngym\n", "from neurogym import TrialEnv\n", "\n", "class MyTrialEnv(TrialEnv):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.observation_space = gym.spaces.Box(low=-1., high=1., shape=(1,))\n", "        self.action_space = gym.spaces.Discrete(2)\n", "\n", "        self.next_ob = np.random.uniform(-1, 1, size=(1,))\n", "\n", "    def _new_trial(self):\n", "        ob = self.next_ob  # observation previously computed\n", "        # Sample observation for the next trial\n", "        self.next_ob = np.random.uniform(-1, 1, size=(1,))\n", "\n", "        trial = dict()\n", "        # Ground-truth is 1 if ob > 0, else 0\n", "        trial['ground_truth'] = (ob > 0) * 1.0\n", "\n", "        return trial\n", "\n", "    def _step(self, action):\n", "        ob = self.next_ob\n", "        # If action equals ground_truth, reward=1, otherwise 0\n", "        reward = (action == self.trial['ground_truth']) * 1.0\n", "        done = False\n", "        info = {'new_trial': True}\n", "        return ob, reward, done, info" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trial 0\n", "Received observation [0.45315209]\n", "Selected action 0\n", "Received reward [0.]\n", "Trial 1\n", "Received observation [0.418608]\n", "Selected action 0\n", "Received reward [0.]\n", "Trial 2\n", "Received observation [-0.30473682]\n", "Selected action 1\n", "Received reward [0.]\n", "Trial 3\n", "Received observation [0.94499442]\n", "Selected action 1\n", "Received reward [1.]\n", "Trial 4\n", "Received observation [-0.90813549]\n", "Selected action 1\n", "Received reward [0.]\n", "Trial 5\n", "Received observation [0.51512945]\n" ] } ], "source": [ "env = MyTrialEnv()\n", "ob = env.reset()\n", "\n", "print('Trial', 0)\n", "print('Received observation', ob)\n", "\n", "for i in range(5):\n", "    action = env.action_space.sample()  # A random agent\n", "    print('Selected action', action)\n", "    ob, reward, done, info = env.step(action)\n", "    print('Received reward', reward)\n", "    print('Trial', i+1)\n", "    print('Received observation', ob)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Including time, period, and observation in trial-based tasks\n", "\n", "Most neuroscience and cognitive science tasks follow additional temporal structures that are incorporated into ```neurogym.TrialEnv```. These tasks typically\n", "1. Are described in real time instead of discrete time steps. For example, the task can last 3 seconds.\n", "2. Contain multiple time periods in each trial, such as a stimulus period and a response period.\n", "\n", "To include these features, neurogym tasks typically support setting the time length of each step in ```dt``` (in ms), and the time length of each time period in ```timing```.\n", "\n", "For example, consider the following binary decision-making task with a 500ms stimulus period, followed by a 500ms decision period. The periods are added to each trial through ```self.add_period()``` in ```self._new_trial()```. During ```_step()```, you can check which period the task is currently in with ```self.in_period(period_name)```."
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class MyDecisionEnv(TrialEnv):\n", "    def __init__(self, dt=100, timing=None):\n", "        super().__init__(dt=dt)  # dt is passed to base task\n", "\n", "        # Setting default task timing\n", "        self.timing = {'stimulus': 500, 'decision': 500}\n", "        # Update timing if provided externally\n", "        if timing:\n", "            self.timing.update(timing)\n", "\n", "        self.observation_space = gym.spaces.Box(low=-1., high=1., shape=(1,))\n", "        self.action_space = gym.spaces.Discrete(2)\n", "\n", "    def _new_trial(self):\n", "        # Setting time periods for this trial\n", "        periods = ['stimulus', 'decision']\n", "        # Will add stimulus and decision periods sequentially using self.timing info\n", "        self.add_period(periods)\n", "\n", "        # Sample the stimulus for this trial\n", "        stimulus = np.random.uniform(-1, 1, size=(1,))\n", "\n", "        trial = dict()\n", "        trial['stimulus'] = stimulus\n", "        # Ground-truth is 1 if stimulus > 0, else 0\n", "        trial['ground_truth'] = (stimulus > 0) * 1.0\n", "\n", "        return trial\n", "\n", "    def _step(self, action):\n", "        # Check if the current time step is in stimulus period\n", "        if self.in_period('stimulus'):\n", "            ob = self.trial['stimulus']  # already has shape (1,)\n", "            reward = 0.  # no reward\n", "        else:\n", "            ob = np.array([0.])  # no observation\n", "            # If action equals ground_truth, reward=1, otherwise 0\n", "            reward = (action == self.trial['ground_truth']) * 1.0\n", "\n", "        done = False\n", "        # By default, the trial is not ended\n", "        info = {'new_trial': False}\n", "        return ob, reward, done, info" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running the environment with a random agent and plotting the agent's observation, action, and rewards." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Logging\n", "log = {'ob': [], 'action': [], 'reward': []}\n", "\n", "env = MyDecisionEnv(dt=100)\n", "ob = env.reset()\n", "log['ob'].append(ob)\n", "for i in range(30):\n", "    action = env.action_space.sample()  # A random agent\n", "    ob, reward, done, info = env.step(action)\n", "\n", "    log['action'].append(action)\n", "    log['ob'].append(ob)\n", "    log['reward'].append(reward)\n", "\n", "log['ob'] = log['ob'][:-1]  # exclude last observation\n", "# Visualize\n", "f, axes = plt.subplots(3, 1, sharex=True)\n", "for ax, key in zip(axes, ['ob', 'action', 'reward']):\n", "    ax.plot(log[key], 'o-')\n", "    ax.set_ylabel(key)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setting observation and ground-truth at the beginning of each trial\n", "\n", "In many tasks, the observation and ground-truth are pre-determined for each trial, and can be set in ```self._new_trial()```. The generated observation and ground-truth can then be used as inputs and targets for supervised learning.\n", "\n", "The observation and ground-truth can be set in ```self._new_trial()``` with the ```self.add_ob()``` and ```self.set_groundtruth()``` methods. Users can specify the period and location of the observation using their names. For example, ```self.add_ob(1, period='stimulus', where='fixation')```.\n", "\n", "This allows users to access the observation and ground-truth of the entire trial with ```self.ob``` and ```self.gt```, and their values at the current time step with ```self.ob_now``` and ```self.gt_now```.\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class MyDecisionEnv(TrialEnv):\n", "    def __init__(self, dt=100, timing=None):\n", "        super().__init__(dt=dt)  # dt is passed to base task\n", "\n", "        # Setting default task timing\n", "        self.timing = {'stimulus': 500, 'decision': 500}\n", "        # Update timing if provided externally\n", "        if timing:\n", "            self.timing.update(timing)\n", "\n", "        # Here we use ngym.spaces, which allows setting name of each dimension\n", "        name = {'fixation': 0, 'stimulus': 1}\n", "        self.observation_space = ngym.spaces.Box(\n", "            low=-1., high=1., shape=(2,), name=name)\n", "        name = {'fixation': 0, 'choice': [1, 2]}\n", "        self.action_space = ngym.spaces.Discrete(3, name=name)\n", "\n", "    def _new_trial(self):\n", "        # Setting time periods for this trial\n", "        periods = ['stimulus', 'decision']\n", "        # Will add stimulus and decision periods sequentially using self.timing info\n", "        self.add_period(periods)\n", "\n", "        # Sample the stimulus for this trial\n", "        stimulus = np.random.uniform(-1, 1, size=(1,))\n", "\n", "        # Add value 1 to stimulus period at fixation location\n", "        self.add_ob(1, period='stimulus', where='fixation')\n", "        # Add value stimulus to stimulus period at stimulus location\n", "        self.add_ob(stimulus, period='stimulus', where='stimulus')\n", "\n", "        # Set ground_truth (index the single element before converting to int)\n", "        groundtruth = int(stimulus[0] > 0)\n", "        self.set_groundtruth(groundtruth, period='decision', where='choice')\n", "\n", "        trial = dict()\n", "        trial['stimulus'] = stimulus\n", "        trial['ground_truth'] = groundtruth\n", "\n", "        return trial\n", "\n", "    def _step(self, action):\n", "        # self.ob_now and self.gt_now correspond to\n", "        # current step observation and groundtruth\n", "\n", "        # If action equals ground_truth, reward=1, otherwise 0\n", "        reward = (action == self.gt_now) * 1.0\n", "\n", "        done = False\n", "        # By default, the trial is not ended\n", "        info = {'new_trial': False}\n", "        return self.ob_now, reward, done, info" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sampling one trial. The trial observation and ground-truth can be used for supervised learning." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trial information {'stimulus': array([-0.77779679]), 'ground_truth': 0}\n", "Observation shape is (N_time, N_unit) = (10, 2)\n", "Groundtruth shape is (N_time,) = (10,)\n" ] } ], "source": [ "env = MyDecisionEnv()\n", "_ = env.reset()\n", "\n", "trial = env.new_trial()\n", "ob, gt = env.ob, env.gt\n", "\n", "print('Trial information', trial)\n", "print('Observation shape is (N_time, N_unit) =', ob.shape)\n", "print('Groundtruth shape is (N_time,) =', gt.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Visualizing the environment with a helper function." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Run the environment for 2 trials using a random agent.\n", "fig = ngym.utils.plot_env(env, num_trials=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### An example perceptual decision-making task\n", "\n", "Using the above style, we can define a simple perceptual decision-making task (the PerceptualDecisionMaking task from neurogym)." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "class PerceptualDecisionMaking(ngym.TrialEnv):\n", " \"\"\"Two-alternative forced choice task in which the subject has to\n", " integrate two stimuli to decide which one is higher on average.\n", "\n", " Args:\n", " stim_scale: Controls the difficulty of the experiment. (def: 1., float)\n", " sigma: float, input noise level\n", " dim_ring: int, dimension of ring input and output\n", " \"\"\"\n", " metadata = {\n", " 'paper_link': 'https://www.jneurosci.org/content/12/12/4745',\n", " 'paper_name': '''The analysis of visual motion: a comparison of\n", " neuronal and psychophysical performance''',\n", " 'tags': ['perceptual', 'two-alternative', 'supervised']\n", " }\n", "\n", " def __init__(self, dt=100, rewards=None, timing=None, stim_scale=1.,\n", " sigma=1.0, dim_ring=2):\n", " super().__init__(dt=dt)\n", " # The strength of evidence, modulated by stim_scale\n", " self.cohs = np.array([0, 6.4, 12.8, 25.6, 51.2]) * stim_scale\n", " self.sigma = sigma / np.sqrt(self.dt) # Input noise\n", "\n", " # Rewards\n", " self.rewards = {'abort': -0.1, 'correct': +1., 'fail': 0.}\n", " if rewards:\n", " self.rewards.update(rewards)\n", "\n", " self.timing = {\n", " 'fixation': 100,\n", " 'stimulus': 2000,\n", " 'delay': 0,\n", " 'decision': 100}\n", " if timing:\n", " self.timing.update(timing)\n", "\n", " self.abort = False\n", "\n", " self.theta = np.linspace(0, 2*np.pi, dim_ring+1)[:-1]\n", " self.choices = np.arange(dim_ring)\n", 
"\n", " name = {'fixation': 0, 'stimulus': range(1, dim_ring+1)}\n", " self.observation_space = ngym.spaces.Box(\n", " -np.inf, np.inf, shape=(1+dim_ring,), dtype=np.float32, name=name)\n", " name = {'fixation': 0, 'choice': range(1, dim_ring+1)}\n", " self.action_space = ngym.spaces.Discrete(1+dim_ring, name=name)\n", "\n", " def _new_trial(self, **kwargs):\n", " # Trial info\n", " trial = {\n", " 'ground_truth': self.rng.choice(self.choices),\n", " 'coh': self.rng.choice(self.cohs),\n", " }\n", " trial.update(kwargs)\n", "\n", " coh = trial['coh']\n", " ground_truth = trial['ground_truth']\n", " stim_theta = self.theta[ground_truth]\n", "\n", " # Periods\n", " self.add_period(['fixation', 'stimulus', 'delay', 'decision'])\n", "\n", " # Observations\n", " self.add_ob(1, period=['fixation', 'stimulus', 'delay'], where='fixation')\n", " stim = np.cos(self.theta - stim_theta) * (coh/200) + 0.5\n", " self.add_ob(stim, 'stimulus', where='stimulus')\n", " self.add_randn(0, self.sigma, 'stimulus', where='stimulus')\n", "\n", " # Ground truth\n", " self.set_groundtruth(ground_truth, period='decision', where='choice')\n", "\n", " return trial\n", "\n", " def _step(self, action):\n", " new_trial = False\n", " # rewards\n", " reward = 0\n", " gt = self.gt_now\n", " # observations\n", " if self.in_period('fixation'):\n", " if action != 0: # action = 0 means fixating\n", " new_trial = self.abort\n", " reward += self.rewards['abort']\n", " elif self.in_period('decision'):\n", " if action != 0:\n", " new_trial = True\n", " if action == gt:\n", " reward += self.rewards['correct']\n", " self.performance = 1\n", " else:\n", " reward += self.rewards['fail']\n", "\n", " return self.ob_now, reward, False, {'new_trial': new_trial, 'gt': gt}" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "env = PerceptualDecisionMaking(dt=20)\n", "fig = ngym.utils.plot_env(env, num_trials=2)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 4 }