{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4zW6CU8F69Zp" }, "source": [ "## Keras example of supervised learning on a NeuroGym task" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xIE7qx_a7D0e" }, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neurogym/neurogym/blob/master/examples/example_neurogym_keras.ipynb)\n", "\n", "\n", "NeuroGym is a comprehensive toolkit that allows training any network model on many established neuroscience tasks using Reinforcement Learning techniques. It includes working memory tasks, value-based decision tasks, and context-dependent perceptual categorization tasks. The same tasks can also be used for supervised learning, as in this notebook.\n", "\n", "In this notebook, we first show how to install the toolbox on Google Colab.\n", "\n", "We then wrap a task in a supervised dataset and read the observation and action sizes from its environment.\n", "\n", "Finally, we train an LSTM network on the Random Dots Motion task (`PerceptualDecisionMaking-v0`) using standard supervised learning with Keras, evaluate it on new trials, and plot an example trial." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DS1FhYPG38WO" }, "source": [ "### Installation when used on Google Colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 561 }, "colab_type": "code", "id": "Q-4MfhnP4AgL", "outputId": "609a4b19-e5c5-4ff9-ca31-52b1b40f230d" }, "outputs": [], "source": [ "%tensorflow_version 1.x\n", "# Install gym\n", "! pip install gym\n", "# Install neurogym\n", "! git clone https://github.com/gyyang/neurogym.git\n", "%cd neurogym/\n", "! pip install -e ." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "OCFMPbzX38Wj" }, "source": [ "### Task, network, and training" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "code_folding": [], "colab": { "base_uri": "https://localhost:8080/", "height": 496 }, "colab_type": "code", "id": "jAxTPbzL38Wl", "outputId": "a7a65c14-5e6b-40ab-9f85-2ef8db4355e1" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/gryang/tf1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint32 = 
np.dtype([(\"qint32\", np.int32, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /Users/gryang/tf1/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/gryang/tf1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a 
future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", "/Users/gryang/tf1/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", "/Users/gryang/Dropbox/Code/MyPython/neurogym/neurogym/core.py:253: UserWarning: Warning: Time for period fixation 100.000000 lasts only one timestep. Agents will not have time to respond (e.g. make a choice) on time.\n", " ' time to respond (e.g. make a choice) on time.')\n", "/Users/gryang/Dropbox/Code/MyPython/neurogym/neurogym/core.py:253: UserWarning: Warning: Time for period decision 100.000000 lasts only one timestep. Agents will not have time to respond (e.g. make a choice) on time.\n", " ' time to respond (e.g. 
make a choice) on time.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "input_1 (InputLayer) [(None, None, 3)] 0 \n", "_________________________________________________________________\n", "lstm (LSTM) (None, None, 64) 17408 \n", "_________________________________________________________________\n", "time_distributed (TimeDistri (None, None, 3) 195 \n", "=================================================================\n", "Total params: 17,603\n", "Trainable params: 17,603\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "WARNING:tensorflow:From /Users/gryang/tf1/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", " 7/2000 [..............................] - ETA: 5:35 - loss: 1.0581 - acc: 0.4196" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/gryang/Dropbox/Code/MyPython/neurogym/neurogym/core.py:253: UserWarning: Warning: Time for period fixation 100.000000 lasts only one timestep. Agents will not have time to respond (e.g. make a choice) on time.\n", " ' time to respond (e.g. make a choice) on time.')\n", "/Users/gryang/Dropbox/Code/MyPython/neurogym/neurogym/core.py:253: UserWarning: Warning: Time for period decision 100.000000 lasts only one timestep. Agents will not have time to respond (e.g. make a choice) on time.\n", " ' time to respond (e.g. 
make a choice) on time.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2000/2000 [==============================] - 106s 53ms/step - loss: 0.0569 - acc: 0.9835\n" ] } ], "source": [ "import warnings\n", "\n", "import numpy as np\n", "import neurogym as ngym\n", "\n", "from tensorflow.keras.models import Model\n", "from tensorflow.keras.layers import Dense, LSTM, TimeDistributed, Input\n", "\n", "# Show each distinct warning once per location\n", "warnings.filterwarnings('default')\n", "\n", "# Environment\n", "task = 'PerceptualDecisionMaking-v0'\n", "kwargs = {'dt': 100}\n", "seq_len = 100\n", "\n", "# Make supervised dataset\n", "dataset = ngym.Dataset(task, env_kwargs=kwargs, batch_size=16,\n", "                       seq_len=seq_len)\n", "env = dataset.env\n", "obs_size = env.observation_space.shape[0]\n", "act_size = env.action_space.n\n", "\n", "# Model\n", "num_h = 64\n", "# from https://www.tensorflow.org/guide/keras/rnn\n", "# Inputs are time-major: (seq_len, batch, obs_size)\n", "xin = Input(batch_shape=(None, None, obs_size), dtype='float32')\n", "seq = LSTM(num_h, return_sequences=True, time_major=True)(xin)\n", "# Softmax over actions at every time step\n", "mlp = TimeDistributed(Dense(act_size, activation='softmax'))(seq)\n", "model = Model(inputs=xin, outputs=mlp)\n", "model.summary()\n", "model.compile(optimizer='Adam', loss='sparse_categorical_crossentropy',\n", "              metrics=['accuracy'])\n", "\n", "# Train network\n", "steps_per_epoch = 2000\n", "# Finite generator yielding one (inputs, targets) batch per training step\n", "data_generator = (dataset() for i in range(steps_per_epoch))\n", "history = model.fit(data_generator, steps_per_epoch=steps_per_epoch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Analysis" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 122 }, "colab_type": "code", "id": "iaeFA3oklR99", "outputId": "f9f027df-f67e-41a6-d5d3-dd2d98555d6e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.895\n" ] } ], "source": [ "# Score each new trial by the prediction at its last time step\n", "perf = 0\n", "num_trial = 200\n", "for i in range(num_trial):\n", " env.new_trial()\n", " obs, 
gt = env.obs, env.gt\n", " obs = obs[:, np.newaxis, :]\n", "\n", " action_pred = model.predict(obs)\n", " action_pred = np.argmax(action_pred, axis=-1)\n", " perf += gt[-1] == action_pred[-1, 0]\n", "\n", "perf /= num_trial\n", "print(perf)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 225 }, "colab_type": "code", "id": "TdpkVOLhuRAK", "outputId": "e1abdc87-eaff-4e32-d395-1af3513a80ff" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAADQCAYAAAA53LuNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxU9bn48c+TnSQsYV8CBBRQtiA7iKDggkpxp1iqovXyokqv9r4ulV5bvfK79uLS2t66iwq1NqKooAVXUKgbEhBQFtkaJOxhh5CQ5fn9cU7SyUzWYWbOJDzv12temTnnfM/3mTPfeXLme875HlFVjDHGRF6M1wEYY8zZyhKwMcZ4xBKwMcZ4xBKwMcZ4xBKwMcZ4JM7rAOoqQRI1iZQ6lytuVfcycSdL6lwGoDg1NqhypcEVI+5U3c9kKY2T4CoL8l92SVKQZ9uU1j3OhOPB1RXsNpEgqtNgN39xcO+tuFFwFQbTJhOPlgZVV2FacI0rpiioYsQfDyLOk6eCqus4h/NUtZX/9HqXgJNIYYiMqXO5vBuG1blMq+xjdS4DsPfCpkGVOx1cMVp+W1znMvmtgsv2wX6Rj/ase4wAMfl1/1KmfxJcAjjVIrhtEldQ96RYFOR2TNkX3HY8kBkfVLnTzer+3jLeDS5Jbb8hKahySfuDS9zpS47XvdDX3wZV18c6f0dl060LwhhjPGIJ2BhjPGIJ2BhjPGIJ2BhjPGIJ2BhjPOJ5AhaRFBGJcZ93F5HxIhLcIVtjjKlHPE/AwHIgSUQ6AB8CtwBzPI3IGGMiIBoSsKhqPnA98LSq3gT08jgmY4wJu6hIwCIyDJgELHKnBXlNmDHG1B/RkIDvAX4NvK2q60WkK/CJxzEZY0zYeX4psqoux+kHLnu9Hfh37yIyxpjI8DwBi0h34D+BDHziUdXRXsVkjDGR4HkCBt4AngVmA8ENP2aMMfVQNCTgYlV9xusgjDEm0qLhINy7InKXiLQTkeZlD6+DMsaYcIuGPeDb3L/TfaYp0NWDWIwxJmI8T8Cq2sXrGIwxxgueJ2B33IefAyPdSZ8Cz6lqkDcaMcaY+sHzBAw8A8QDT7uvb3Gn3elZRMYYEwHRkIAHqWqmz+ulIrLWs2iMMSZCouEsiBIROafshXspsp0PbIxp8KJhD3g68ImIbAcE6Azc7m1IxhgTfp4nYFVdIiLdgB7upO9VtdDLmIwxJhI8S8AiMlpVl4rI9X6zzhURVPWtUNYXW1D3MltuaRxcZaWlQRWLOyVBlds5tu5l2i/VoOoquvpIUOUSvmsWVLnT7et+MszOK4IczTTIz63lOYfqXObggSZB1SWfJgRVLr9zcVDlUnLqniJyRycHVZcEuf2LU4Nry1snptS5zLlfB1VVlbzcAx4FLAV+VMk8BUKagI0xJtp4loBV9UH36UxV/afvPBGxizOMMQ1eNJwF8WYl0+ZHPApjjIkwL/uAz8
O591tTv37gJkCSN1EZY0zkeNkH3AMYBzSjYj/wceDfPInIGGMiyMs+4IXAQhEZpqpfehWHMcZ4xfPzgIFvRORunO6I8q4HVb3Du5CMMSb8ouEg3CtAW+AKYBmQjtMNYYwxDVo0JOBzVfW3wElVnQtcDQzxOCZjjAm7aEjAZZc6HRGR3kBToLWH8RhjTEREQx/w8yKSBvwWeAdIdZ8bY0yDFg0J+GVVLcHp/7X7wBljzhrR0AXxTxF5XkTGiEhwo9EYY0w9FA0J+DzgY+BuIEdEnhSRER7HZIwxYed5AlbVfFV9XVWvB/rhXIq8zOOwjDEm7DxPwAAiMkpEngZW4VyMMcHjkIwxJuw8PwgnIjnAN8DrwHRVPeltRMYYExmeJmARiQVeUtWZXsZhjDFe8LQLwj39bJyXMRhjjFc874IAPheRJ4F5QHn3g6qu9i4kY4wJv2hIwP3cv77dEAqM9iAWY4yJGM8TsKpe4nUMxhjjBc9PQxORNiLyooi8577uKSI/8zouY4wJN88TMDAH+ABo777eDNzrWTTGGBMhnndBAC1V9XUR+TWAqhaLSElVCzfuWcpF8wrqXMmIlOfqXGb6xhvrXAbgnLS8oMo1jisMqlz3lL11LnP9lWuDquv9k+cHVe6viYODKte5yeE6l3mty9Kg6vqfvPOCKjepaXady/xu7xVB1bWhbZugyv3pnPeDKvfUD3U/FBMbUxpUXX/u+npQ5S5dek9Q5S7ssa3OZQ4EVVPVomEP+KSItMA58IaIDAWOehuSMcaEXzTsAf8HzjjA54jI50ArILhdT2OMqUc8T8CqulpERuHcpl6A71W1qIZixhhT73neBSEiNwGNVHU9cC0wT0T6exyWMcaEnecJGPitqh53xwAeA7wIPONxTMYYE3bRkIDLzni4GnhBVRcBCR7GY4wxERENCXiXiDwH/BhYLCKJREdcxhgTVtGQ6CbgXIhxhaoeAZoD070NyRhjws/zBKyq+UAOcKWI/AJop6ofehuVMcaEn+cJWEQeAOYCLYCWwMsi8htvozLGmPDz/DxgYBKQqaoFACIyC1gD/I+nURljTJh5vgcM7Ma5EWeZRGCXR7EYY0zEeLYHLCJ/xhn/4SiwXkQ+cmddCnztVVzGGBMpXnZBlA0htQFYgpOMi4FPPIvIGGMiyMsE/DfgYeAOYAfOOBCdgJeB//IwLmOMiQgv+4AfBdKALqo6QFX7A12BpsBjHsZljDER4WUCHgdMUdXjZRNU9Rjwc5zLko0xpkHzMgGrqmolE0twB2c3xpiGzMsEvEFEbvWfKCI/BTZ5EI8xxkSUlwfh7gbeEpE7gFXutIFAI+A6z6IyxpgI8SwBq+ouYIiIjAZ6uZMXq+oSr2IyxphI8vxSZFVdCgR3G1tjjKnHouFSZGOMOStZAjbGGI9IJWeCRTUROYBz5VxlWgJ5EQynKhZHoGiJxeIIFC2xNOQ4OqtqK/+J9S4BV0dEslV1oMURXXFA9MRicQSKlljOxjisC8IYYzxiCdgYYzzS0BLw814H4LI4AkVLLBZHoGiJ5ayLo0H1ARtjTH3S0PaAjTGm3rAEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHrEEbIwxHgnbHTFE5CWcW8/vV9XelcwX4E/AVUA+MFlVV9e03pYtW2pGRkaIozXGmDOz71gBbZokVTpv1apVeZUNRxnOWxLNAZ4E/lLF/CuBbu5jCPCM+7daGRkZZGdnhyhEY4wJjYwZi8iedXWl80Sk0jHMw9YFoarLgUPVLHIN8Bd1fAU0E5F24YrHGGPCobS0lFF/mI9S93F1vLwpZwdgp8/rXHfaHv8FRWQKMAWgU6dOEQnOGGNq8sRHm/n90i/YlXgHLWL/nYwZAsA9Y7rxy8u611je87si14aqPo87RNzAgQNt+DZjTFT45WXdObd1KndnzQMgp4ouiKp4mYB3AR19Xqe70+qsqKiI3NxcCgoKQh
KYOTNJSUmkp6cTHx/vdSjGhN23u46SFJvK6ZLSOpf1MgG/A0wTkddwDr4dVdWA7ofayM3NpXHjxmRkZOCcXGG8oqocPHiQ3NxcunTp4nU4xoTdKxseplFae36eeWedy4btIJyIZAFfAj1EJFdEfiYiU0VkqrvIYmA7sBV4Abgr2LoKCgpo0aKFJd8oICK0aNHCfo2Ys0JxSQmbT7xNXNKOWvX5+gvbHrCq3lzDfAXuDlV9lnyjh30W5mzx6fZ1lEo+g9sPCqq8XQkXQgsWLEBE2LRpU7XLzZkzh927d5e/vvPOO9mwYUO4wzPGhNjiTZ8DcEX34UGVP6sT8BMfbQ7p+rKyshgxYgRZWVnVLuefgGfPnk3Pnj1DGosxJvy+yl2JaAJjewwMqvxZnYD/tGRLyNZ14sQJPvvsM1588UVee+218umPPPIIffr0ITMzkxkzZjB//nyys7OZNGkS/fr149SpU1x88cXlV/dlZWXRp08fevfuzX333Ve+ntTUVO6//34yMzMZOnQo+/btA+CNN96gd+/eZGZmMnLkyJC9H2NMzfJOFNAqoR9J8QlBla8X5wHXxUPvrmfD7mO1Xv7Hz31Z4zI92zfhwR/1qnaZhQsXMnbsWLp3706LFi1YtWoV+/fvZ+HChaxYsYLk5GQOHTpE8+bNefLJJ3n88ccZOLDif83du3dz3333sWrVKtLS0rj88stZsGAB1157LSdPnmTo0KE8/PDD/OpXv+KFF17gN7/5DTNnzuSDDz6gQ4cOHDlypNbv2xhzZkpKlbjjt/PzgR1rXrgKZ90ecO7hfFb88xAr/ulcJV32PPdw/hmtNysri4kTJwIwceJEsrKy+Pjjj7n99ttJTk4GoHnz5tWuY+XKlVx88cW0atWKuLg4Jk2axPLlywFISEhg3LhxAAwYMICcnBwALrzwQiZPnswLL7xASUnJGb0HY0ztbT9wgvzTJfTp0DTodTS4PeCa9lR9ZcxYVOcrVypz6NAhli5dyrfffouIUFJSgohw0003nfG6y8THx5efXRAbG0txcTEAzz77LCtWrGDRokUMGDCAVatW0aJFi5DVa4yp3FMr5rA78Y+0a/5B0Os46/aAw2H+/Pnccsst7Nixg5ycHHbu3EmXLl1o2rQpL7/8Mvn5zt71oUPOXnfjxo05fvx4wHoGDx7MsmXLyMvLo6SkhKysLEaNGlVt3du2bWPIkCHMnDmTVq1asXPnzmqXN8aExpe5X1EsuxnYMSPodZzVCfieMd1Csp6srCyuu+66CtNuuOEG9uzZw/jx4xk4cCD9+vXj8ccfB2Dy5MlMnTq1/CBcmXbt2jFr1iwuueQSMjMzGTBgANdcc021dU+fPr38oN3w4cPJzMwMyXsyxlRv6+F1tEjoQUJc8B0J4lwPUX8MHDhQ/ccD3rhxI+eff75HEZnK2GdiGrKCotMkP9yYoW0m8sXP59a4vIisUtWAc9XO6j1gY4wJxvvfZ6NymqEdg7sCrowlYGOMqaOt+/NJLr6I8edVf4ymJpaAjTGmjo4ca0PXmPsZ2TXgdpd1UqcELCJpItL3jGo0xph6LvuHHfTu0JSYmDMbeKrGBCwin4pIExFpDqwGXhCRP5xRrcYYU0+dKCzgw0PXsV9eOeN11WYPuKmqHgOux7mJ5hDg0jOu2Rhj6qHFm1agUszADn3OeF21ScBx7t2KJwB/P+MaG6h9+/bxk5/8hK5duzJgwACGDRvG22+/HfE4MjIyyMvLC5j+u9/9Lqj1LViwoMJQmb4DBxlzNvpwizN+zI/Ou+iM11WbBDwT+ADYqqorRaQrELphxBoAVeXaa69l5MiRbN++nVWrVvHaa6+Rm5sbsGzZJcSRVlUCVlVKS6u+l5V/AjbmbJe9O5tYUhmecebnudeYgFX1DVXtq6p3ua+3q+oNZ1xzA7J06VISEhKYOnVq+bTOnTvzi1/8AnDG/x0/fjyjR49mzJgxqCrTp0+nd+/e9OnTh3nznDuqfv
rpp+UD7gBMmzaNOXPmAM6e7YMPPkj//v3p06dP+aDvBw8e5PLLL6dXr17ceeedVHZhzYwZMzh16hT9+vVj0qRJ5OTk0KNHD2699VZ69+7Nzp07SU1NLV9+/vz5TJ48mS+++IJ33nmH6dOn069fP7Zt2wY4Q2AOHjyY7t27849//CO0G9OYKPfPY+tondiTmJgzP4msxmvoRKQV8G9Ahu/yqnrHGdceJhfPuThg2oReE7hr0F3kF+Vz1atXBcyf3G8yk/tNJi8/jxtfv7HCvE8nf1ptfevXr6d///7VLrN69WrWrVtH8+bNefPNN1mzZg1r164lLy+PQYMG1Wos35YtW7J69WqefvppHn/8cWbPns1DDz3EiBEjeOCBB1i0aBEvvvhiQLlZs2bx5JNPsmbNGgBycnLYsmULc+fOZejQoVXWN3z4cMaPH8+4ceO48cZ/bZPi4mK+/vprFi9ezEMPPcTHH39cY+zGNAQFRSUkFVzHFb0yQrK+2qTwhUBT4GNgkc/DVOHuu+8mMzOTQYP+dZXMZZddVj4c5WeffcbNN99MbGwsbdq0YdSoUaxcubLG9V5//fVAxeEoly9fzk9/+lMArr76atLS0moVY+fOnatNvnWNw5izwaa9x2lUPJIf97k2JOurzSgSyap6X82LRY/q9liT45Ornd8yuWWNe7z+evXqxZtvvln++qmnniIvL6/CgOspKSk1ricuLq5Cf6z/nYUTExOBisNRBss/Ht8badZ0R+NQxmFMffL+ppWclm307nBJSNZXmz3gv4tI4G92U2706NEUFBTwzDPPlE8rG4KyMhdddBHz5s2jpKSEAwcOsHz5cgYPHkznzp3ZsGEDhYWFHDlyhCVLltRY98iRI/nb3/4GwHvvvcfhw4crXS4+Pp6ioqIq19OmTRs2btxIaWlphbM3qho605iz0asbnuJA0gO0b5oUkvXVJgHfg5OEC0TkuPuo/T1/zgIiwoIFC1i2bBldunRh8ODB3HbbbTzyyCOVLn/dddfRt29fMjMzGT16NI8++iht27alY8eOTJgwgd69ezNhwgQuuOCCGut+8MEHWb58Ob169eKtt96iU6dOlS43ZcoU+vbty6RJkyqdP2vWLMaNG8fw4cNp165d+fSJEyfy2GOPccEFF5QfhDPmbJVz7FvaJIXmABzYcJQmTOwzMQ3NwZPHaflYMy5Ln8qHdz5Vp7JnNByliIwXkcfdx7iaS5SXGysi34vIVhGZUcn8ySJyQETWuI87a7tuY4yJpIUbPgcpZUTGkJCtszanoc0CBgGvupPuEZELVfXXNZSLBZ4CLgNygZUi8o6q+p/VP09Vp9U9dGOMiZyPtzpXwF3Tc0TI1lmbsyCuAvqpaimAiMwFvgGqTcDAYJyr57a75V4DrgHssipjTL3TTC+lR0wqme27hmydte1JbubzvLb3YO4A+N4hMted5u8GEVknIvNFpGMt1x2gvvVlN2T2WZiGaPPeUkZ0ujik66xNAv5f4BsRmePu/a4CHg5R/e8CGaraF/gIqPTmSiIyRUSyRST7wIEDAfOTkpI4ePCgffGjgKpy8OBBkpJCc5qOMdFgz7HDrDo0m7bNj4R0vTV2Qahqloh8itMPDHCfqu6txbp3Ab57tOnuNN91H/R5ORt4tIoYngeeB+csCP/56enp5ObmUllyNpGXlJREenq612EYEzIL13/Gkfi/kpJyZUjXW2UCFpHzVHWTiJQNclA2tFd7EWmvqqtrWPdKoJuIdMFJvBOBn/jV0U5V97gvxwMb6/wOcC4y6NKlSzBFjTGmRku3fwXA+BAegIPq94D/A5gC/L6SeQqMrm7FqlosItNwhrKMBV5S1fUiMhPIVtV3gH8XkfFAMXAImFz3t2CMMeG1dt9qEmhFr7aVX+gUrCoTsKpOcZ9eqaoVBgcQkVp18KnqYmCx37QHfJ7/mprPpjDGGE/tPPEd7ZN7hXy9tTkI90UtpxljTIOz++gRCkoP0Kd19U
POBqO6PuC2OKeNNRKRC4Cy4bKaAMkhj8QYY6JQzoFiOha8zt2DQn9D+Or6gK/A6ZNNx+kHLkvAx4D/CnkkxhgThdbtOooQy+DO7UO+7ur6gOcCc0XkBlV9s6rljDGmIZu99n/RxiWkpVwd8nXXpg94gIiUXwknImki8j8hj8QYY6LQd0cWE58UeIPdUKhNAr5SVcsv/1DVwzjjQxhjTIO2LW8PhbqHvmE4AAe1S8CxIpJY9kJEGgGJ1SxvjDENwoL1zl2/L+kS3P0Ta1Kb0dBeBZaIyMs4B+ImU8WYDcYY05Asy1kBwPheob0CrkxtxoJ4RETWApfiXAH3AdA5LNEYY0wU2XfsNM1i+tE5rVVY1l/b4Sj34STfm3AuQQ5qzAZjjKlPYk9cyy3nvhS29Vd3IUZ34Gb3kQfMw7mHXGjux2yMMVFs/7ECdh8toG96bYdAr7vq9oA34eztjlPVEar6Z6AkbJEYY0wUeTH7LXITf0Zq6r6w1VFdAr4e2AN8IiIviMgY/nU1nDHGNGjLc1ZQIvsZdc55YaujygSsqgtUdSJwHvAJcC/QWkSeEZHLwxaRMcZEgQ1535Ac24l2TdLCVkeNB+FU9aSq/k1Vf4QzLsQ3wH1hi8gYY6LAnlPr6ZzaO6x11PYsCMC5Ck5Vn1fVMeEKyBhjvLZ293aKOES/tgPCWk+dErAxxpwNvtt9mNTiK/nReeHd17QEbIwxfvYdSqVV8d1c2+vCsNZjCdgYY/x8uWMz3Von0yghNqz1WAI2xhgfpaWlLNh9C3tjngx7XZaAjTHGR3buVoo5Sr+2/cJelyVgY4zx8e5GZwjKy7oND3tdloCNMcbH5z98DRrLuPPDMwawL0vAxhjjY9PBNTSJ60qzRilhr6s2A7IbY8xZQVVpVHgdwzsnR6S+sO4Bi8hYEfleRLaKyIxK5ieKyDx3/goRyTiT+p74aHPEykWyrmDL1YcYgy1nMXpbrqHGmHv4FCWn+nD9+dfWuWwwwpaARSQWeAq4EugJ3CwiPf0W+xlwWFXPBZ4AHjmTOv+0ZEvEykWyrmDL1YcYgy1nMXpbrqHGuGhjNgUx6+jZPvzdDxDeLojBwFZV3Q4gIq8B1wAbfJa5Bvhv9/l84EkREVXVula2dNMuTsZ8zv3v76kwvXOT7nRs3I2C4nyy930SUK5rU+d/wvzV3/PN/n8EzO/WrC9tUjpyrPAQ6/K+LJ9+MmY797+/h/PSLqBlcnsOFxxg/cGvA8r3bDGI5kmtyTu1h02HVpeXK9O35TCaJDZnX34uWw6vDSh/QeuLSIlvQpHs5v73ZwfMH9jmEpLiktl5fAs7jlX8j38yZjvvrOlFfGwiOUc3kXtiW0D5Ye3HEiuxbDvyHXtO7qjw3gThwg7ODbA3H17L/vyKt+aOkziGtr8CgE2HVnEy5osK7y0xthGD2o4GYH3eCg4X5lUonxyXSv82owCYtXQ+x08fqTC/cUIzMls5VyKt3reM/OITFd7bY58U0bvlEABW7l1KYcmpCuVbJLXh/BYDAVix50OKSosqbP/WjTrQvblzqtEXu96jlNIK5duldOacZr0p1dJK21Z66jlkND2PotLTrNjzUcC2LZJi3vt2T7Vtr31qF04WHavQ9spirKrtlfFve/5ty7/t+Stre8Wyv9K2Vdb2dp/4J9uPbqgw72TMdt7+5rwq2x7AkLaXBrQ93xgra3tl/Nue//b3b3t5p/ZWKJ8Y2whoxnvf7qmx7a098Hl525v33WvsS1jGua1/GfB+wkGCyHW1W7HIjcBYVb3TfX0LMERVp/ks8527TK77epu7TJ7fuqYAUwA6deo0YMeOf31YT3y0mT8t2UIJx8ltdHNAHM2KbqVp8QSKZT+7ku4ImJ92egpNSsZzWnLYkzQtYH6L0/eQWnIZhTEb2Zs4PWB+y8IZpJSO4FTMavYnPhAwv3XhQzQqHUB+zBccSPxdwPw2hY+RVHo+J2KXcDDhiYD57Qr+TIJ24Vjsux
xOeC5gfvuC2cRrW47Gvc6R+L8EzE8/9SqxNOVw3FyOxb8RML/TqbcR4jkU/xzH496tOFPj6FywAIC8+D9yMu7jCrNjNJWOBa8BcCDhd+THflFhfmxpK9ILXwZgX8JvKYj9psL8+NJOtC98GoC9Cf9JYeymCvMTSnvQrvD3AOxOnEZRTE6F+UklF9Dm9P8DIDfxZ5TEVBw4u1HJMFqfvh+AnUk/oVSOVZifUjyalkX/AcCOpGtBiivMb1x8Nc2Lfo5SxA+NrsNfk6IbSSuebG2vgbW9xJJetD3t/Bi/Z0w3fnlZ94DY60pEVqnqwIDp9SEB+xo4cKBmZ2cHTN+y/yijnvgrz/60f4XpLZJb06JRK06XnCbnSOBPklbJbbn5uQ0snDaIHUcD9xBbp7SjWVJz8otOknssp3z61L+u5tmf9qddajqNE5ty8vRxdh3/IaB8h8adSElozPHCo+w5kVterkx6kwyS41M4WniYfSd2B5Tv1LQrSXGNuOyPi3jqloyA+RnNupEQm8DBUwc4mL+/wrypf13Nh7+4mbiYOA7k7+PwqcDNem7z84mRGPaf3MORgkMV3hsidG/u/ELYe2IXxwor7qHGxMRybpozWPWu4z9wy0vLKry3uJh4uqY5jTf3WA75RScrlI+PTaRLs3MZ+8d/8PztHSgorrgHmxTXiE5NuwKQc3Qbp4sLKry3v9x+EelNnG2y/fBmikuLKpRPSWhMh8adANh6eBOlpSUVtn/jxKa0S00HYPOhDeD3XWialEablPaUaimX/t+rAW0rrVFLWiW3obi0mO2Hvw/Ytne/sp2P7v1RtW0vrVELCosLKrS9shirantl/Nuef9vyb3v+ytreZX98j6du6RQwv6ztHS44xIGTFff+p/51Ne9P+3GVbQ+ga1qPgLbnG2Nlba+cX9v7yYtLK7w3/7Z38vTxCsXjYuK5a+4+3r/3omrbHsAPR7dXaHvTXsnlh1k3BbyfM1FVAkZVw/IAhgEf+Lz+NfBrv2U+AIa5z+Nw7j0n1a13wIABWpXO9/29ynnVCaZcJOsKtlx9iDHYchajt+UsxroBsrWSfBbOsyBWAt1EpIuIJAATgXf8lnkHuM19fiOw1A02KPeM6RaxcpGsK9hy9SHGYMtZjN6WsxhDI2xdEAAichXwRyAWeElVHxaRmTj/Dd4RkSTgFeAC4BAwUd2DdtWs8wCwo4rZLXH2or1mcQSKllgsjkDREktDjqOzqrbynxjWBBxpIpKtlfWzWByei5ZYLI5A0RLL2RiHXYpsjDEesQRsjDEeaWgJ+HmvA3BZHIGiJRaLI1C0xHLWxdGg+oCNMaY+aWh7wMYYU29YAjbGGI/UywQc6WEuq4iho4h8IiIbRGS9iNxTyTIXi8hREVnjPgIv2A9NLDki8q1bR8B12uL4P3d7rBOR/pWtJwRx9PB5r2tE5JiI3Ou3TFi2iYi8JCL73cvby6Y1F5GPRGSL+zetirK3uctsEZHbKlvmDON4TEQ2udv+bRFpVkXZaj/HEMXy3yKyy2f7X1VF2Wq/YyGIY55PDDkisqaKsiHbJlV9Z71oJ+Uquzwumh84F3VsA7oCCcBaoKffMncBz7rPJwLzwhBHO6C/+7wxsLmSOK4HMr4AAAWMSURBVC4G/h6BbZIDtKxm/lXAe4AAQ4EVEfqc9uKcgB72bQKMBPoD3/lMexSY4T6fATxSSbnmwHb3b5r7PC3EcVwOxLnPH6ksjtp8jiGK5b+B/6zFZ1ftd+xM4/Cb/3vggXBvk6q+s160k7JHfdwDLh/mUlVPA2XDXPq6BpjrPp8PjBERCWUQqrpHVVe7z48DG4EOoawjhK4B/qKOr4BmItIuzHWOAbapalVXLYaUqi7HuZrSl287mAtUNsr2FcBHqnpIVQ8DHwFjQxmHqn6oqmVDrX0FpAe7/jONpZZq8x0LSRzu93ICkBXs+usQR1Xf2Yi3kzL1MQF3AHb6vM4lMPGVL+M2/KNAi3AF5HZxXA
CsqGT2MBFZKyLviUivMIWgwIciskqcoTv91WabhdpEqv5SRWKbALRR1bJhvPYCbSpZJtLb5g6cXyOVqelzDJVpbnfIS1X83I7kNrkI2KeqVY2eHpZt4ved9ayd1McEHFVEJBV4E7hXVY/5zV6N8xM8E/gzsCBMYYxQ1f44dx+5W0RGhqmeWhFn8KXxQOAgsJHbJhWo8zvS03MuReR+oBh4tYpFIvE5PgOcA/QD9uD8/PfSzVS/9xvybVLddzbS7aQ+JuBdQEef1+nutEqXEZE4oClwMNSBiEg8zgf5qqq+5T9fVY+p6gn3+WIgXkRahjoOVd3l/t0PvI3zE9JXbbZZKF0JrFbVff4zIrVNXPvKulrcv4GD1kZo24jIZGAcMMn9kgeoxed4xlR1n6qWqGop8EIVdURqm8QB1wPzqlom1Nukiu+sZ+2kPibgiA9zWRm37+pFYKOq/qGKZdqW9T2LyGCc7R3SfwQikiIijcue4xzw+c5vsXeAW8UxFDjq85MrHKrcq4nENvHh2w5uAxZWsswHwOUikub+HL/cnRYyIjIW+BUwXlXzq1imNp9jKGLx7fu/roo6avMdC4VLgU3q3pDBX6i3STXfWe/aSSiOLkb6gXNUfzPOkdr73WkzcRo4QBLOz9+twNdA1zDEMALnp8o6YI37uAqYCkx1l5kGrMc5ivwVMDwMcXR117/Wratse/jGITg3SN0GfAsMDONnk4KTUJv6TAv7NsFJ+HuAIpz+uZ/h9PsvAbYAHwPN3WUHArN9yt7htpWtwO1hiGMrTv9hWTspO0OnPbC4us8xDLG84raBdTiJp51/LFV9x0IZhzt9Tlm78Fk2bNukmu9sxNtJ2cMuRTbGGI/Uxy4IY4xpECwBG2OMRywBG2OMRywBG2OMRywBG2OMRywBmwZLRO53R71a546mNURE7hWRZK9jMwbsjhimgRKRYcAfgItVtdC92i4B+ALnPOhouP25OcvZHrBpqNoBeapaCOAm3BtxTvT/REQ+ARCRy0XkSxFZLSJvuOMElI1D+6g7Fu3XInKuO/0mEfnOHUxouTdvzTQUtgdsGiQ3kX4GJONc3TRPVZeJSA7uHrC7V/wWcKWqnhSR+4BEVZ3pLveCqj4sIrcCE1R1nIh8C4xV1V0i0kxVj3jyBk2DYHvApkFSZ8CfAcAU4AAwzx0Qx9dQnAG5Pxfnjgy3AZ195mf5/B3mPv8cmCMi/4YzcLkxQYvzOgBjwkVVS4BPgU/dPVf/28gIziDbN1e1Cv/nqjpVRIYAVwOrRGSAqoZrMCHTwNkesGmQxLk/XTefSf2AHcBxnNvRgDMY0IU+/bspItLdp8yPff5+6S5zjqquUNUHcPasfYcoNKZObA/YNFSpwJ/FuQFmMc4IVlNwhsp8X0R2q+olbrdElogkuuV+gzMKGECaiKwDCt1yAI+5iV1wRtBaG5F3YxokOwhnTCV8D9Z5HYtpuKwLwhhjPGJ7wMYY4xHbAzbGGI9YAjbGGI9YAjbGGI9YAjbGGI9YAjbGGI/8f3kkhjfgm2xdAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the last evaluated trial; arrays are time-major, so select batch index 0\n", "_ = ngym.utils.plotting.fig_(obs[:, 0], action_pred[:, 0], gt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "OfGvOBO-84l-" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "Copy of example_NeuroGym_keras.ipynb", "provenance": [] }, "hide_input": false, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 1 }