{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch supervised learning of perceptual decision making task\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neurogym/neurogym/blob/master/examples/example_neurogym_pytorch.ipynb)\n",
    "\n",
    "PyTorch-based example code for training an RNN on a perceptual decision-making task."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installation when used on Google Colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install gym (%pip installs into the environment of the running kernel)\n",
    "%pip install gym\n",
    "# Install neurogym\n",
    "! git clone https://github.com/gyyang/neurogym.git\n",
    "%cd neurogym/\n",
    "%pip install -e ."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import neurogym as ngym\n",
    "\n",
    "# Environment\n",
    "task = 'PerceptualDecisionMaking-v0'\n",
    "# NOTE(review): 'dt' is presumably the simulation time step (ms) - confirm in the neurogym docs\n",
    "kwargs = {'dt': 100}\n",
    "seq_len = 100\n",
    "\n",
    "# Make supervised dataset\n",
    "dataset = ngym.Dataset(task, env_kwargs=kwargs, batch_size=16,\n",
    "                       seq_len=seq_len)\n",
    "env = dataset.env\n",
    "# Network input / output widths are read from the environment's spaces\n",
    "ob_size = env.observation_space.shape[0]\n",
    "act_size = env.action_space.n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Network and Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "200 loss: 0.11658\n",
      "400 loss: 0.03191\n",
      "600 loss: 0.01894\n",
      "800 loss: 0.01393\n",
      "1000 loss: 0.01351\n",
      "1200 loss: 0.01275\n",
      "1400 loss: 0.01280\n",
      "1600 loss: 0.01279\n",
      "1800 loss: 0.01223\n",
      "2000 loss: 0.01239\n",
      "Finished Training\n"
     ]
    }
   ],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"LSTM followed by a linear readout applied at every time step.\n",
    "\n",
    "    The input and output widths come from the module-level ``ob_size`` and\n",
    "    ``act_size`` defined in the Dataset cell.\n",
    "\n",
    "    Args:\n",
    "        num_h: number of hidden units in the LSTM.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_h):\n",
    "        super().__init__()\n",
    "        self.lstm = nn.LSTM(ob_size, num_h)\n",
    "        self.linear = nn.Linear(num_h, act_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-time-step action logits for the input sequence ``x``.\n",
    "\n",
    "        ``nn.LSTM`` is built with its default ``batch_first=False``, so ``x``\n",
    "        is laid out as (seq_len, batch, ob_size).\n",
    "        \"\"\"\n",
    "        out, _ = self.lstm(x)  # the final (h, c) state is not needed\n",
    "        return self.linear(out)\n",
    "\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "net = Net(num_h=64).to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)\n",
    "\n",
    "num_steps = 2000    # number of training batches\n",
    "print_every = 200   # report the mean loss over this many batches\n",
    "\n",
    "running_loss = 0.0\n",
    "for i in range(num_steps):\n",
    "    inputs, labels = dataset()\n",
    "    inputs = torch.from_numpy(inputs).type(torch.float).to(device)\n",
    "    # Flatten the labels so they line up with outputs.view(-1, act_size) below\n",
    "    labels = torch.from_numpy(labels.flatten()).type(torch.long).to(device)\n",
    "\n",
    "    # zero the parameter gradients\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # forward + backward + optimize\n",
    "    outputs = net(inputs)\n",
    "\n",
    "    loss = criterion(outputs.view(-1, act_size), labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # print statistics\n",
    "    running_loss += loss.item()\n",
    "    if i % print_every == print_every - 1:\n",
    "        print('{:d} loss: {:0.5f}'.format(i + 1, running_loss / print_every))\n",
    "        running_loss = 0.0\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average performance in 200 trials\n",
      "0.83\n"
     ]
    }
   ],
   "source": [
    "perf = 0\n",
    "num_trial = 200\n",
    "with torch.no_grad():  # inference only: no autograd graph is needed\n",
    "    for _ in range(num_trial):\n",
    "        env.new_trial()\n",
    "        ob, gt = env.ob, env.gt\n",
    "        ob = ob[:, np.newaxis, :]  # Add batch axis\n",
    "        inputs = torch.from_numpy(ob).type(torch.float).to(device)\n",
    "\n",
    "        action_pred = net(inputs)\n",
    "        # Copy to host memory first: .numpy() raises on CUDA tensors\n",
    "        action_pred = action_pred.cpu().numpy()\n",
    "        action_pred = np.argmax(action_pred, axis=-1)\n",
    "        # A trial counts as correct when the choice at its last time step\n",
    "        # matches the ground truth at that step\n",
    "        perf += gt[-1] == action_pred[-1, 0]\n",
    "\n",
    "perf /= num_trial\n",
    "print('Average performance in {:d} trials'.format(num_trial))\n",
    "print(perf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}