# Source code for neurogym.wrappers.monitor

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from gym import Wrapper
import os
import numpy as np
from neurogym.utils.plotting import fig_


class Monitor(Wrapper):
    """Monitor task.

    Saves relevant behavioral information: rewards, actions, observations,
    new trial, ground truth.

    Data is recorded only on steps where ``info['new_trial']`` is truthy
    (one entry per trial): the action, the reward and every entry of
    ``info`` at that step. It is written with ``np.savez`` to
    ``<folder>/<EnvClass>_bhvr_data_<name>_<num_trials>.npz``.

    Args:
        folder: Folder where the data will be saved. (def: None, str;
            None means "/tmp/")
        sv_per and sv_stp: Data will be saved every sv_per sv_stp's.
            (def: 100000, int) and (def: 'trial', str). If sv_stp is
            'timestep', timesteps are counted; any other value counts
            trials. In 'timestep' mode the check is only made at the end
            of a trial, so the save happens at the first trial end after
            at least sv_per steps.
        verbose: Whether to print information about average reward and
            number of trials on every save. (def: False, bool)
        sv_fig: Whether to save a figure of the experiment structure. If
            True, a figure will be updated every sv_per. (def: False, bool)
        num_stps_sv_fig: Number of trial steps to include in the figure.
            (def: 100, int)
        name: Extra string inserted into every saved file name.
            (def: '', str)
        fig_type: Extension appended to the figure file name (presumably
            also selects the image format used by ``fig_``).
            (def: 'png', str)
    """
    metadata = {
        'description': 'Saves relevant behavioral information: rewards,' +
        ' actions, observations, new trial, ground truth.',
        'paper_link': None,
        'paper_name': None,
    }

    # TODO: use names similar to Tensorboard
    def __init__(self, env, folder=None, sv_per=100000, sv_stp='trial',
                 verbose=False, sv_fig=False, num_stps_sv_fig=100, name='',
                 fig_type='png'):
        super().__init__(env)
        self.env = env
        self.num_tr = 0
        # Per-trial data to save; keys from `info` are added lazily in step().
        self.data = {'action': [], 'reward': []}
        self.sv_per = sv_per
        self.sv_stp = sv_stp
        self.fig_type = fig_type
        if self.sv_stp == 'timestep':
            # Steps since the last save; only used in 'timestep' mode.
            self.t = 0
        self.verbose = verbose
        # The trailing '/' matters: file names are built by concatenation.
        self.folder = "/tmp/" if folder is None else folder + '/'
        # exist_ok=True avoids the race between an existence check and
        # the directory creation.
        os.makedirs(self.folder, exist_ok=True)
        # Common prefix of every file (data and figures) this monitor writes.
        self.sv_name = (self.folder + self.env.__class__.__name__ +
                        '_bhvr_data_' + name + '_')
        # figure
        self.sv_fig = sv_fig
        if self.sv_fig:
            self.num_stps_sv_fig = num_stps_sv_fig
            # Steps buffered for the figure since the last data save.
            self.stp_counter = 0
            self._clear_fig_buffers()

    def _clear_fig_buffers(self):
        """Empty the per-step buffers used to draw the figure."""
        self.ob_mat = []
        self.act_mat = []
        self.rew_mat = []
        self.gt_mat = []
        self.perf_mat = []

    def reset(self, step_fn=None):
        """Reset the wrapped env.

        Args:
            step_fn: Step function forwarded to the wrapped env's ``reset``.
                Defaults to this wrapper's ``step``. Presumably this routes
                any steps the env takes while resetting through the
                monitor; confirm against the env's ``reset``.

        Returns:
            Whatever the wrapped env's ``reset`` returns.
        """
        if step_fn is None:
            step_fn = self.step
        return self.env.reset(step_fn=step_fn)

    def step(self, action):
        """Step the wrapped env, record trial data and save periodically.

        Args:
            action: Action forwarded unchanged to the wrapped env.

        Returns:
            The wrapped env's ``(obs, rew, done, info)`` tuple, unchanged.

        Raises:
            KeyError: if ``info`` has no ``'new_trial'`` entry.
        """
        obs, rew, done, info = self.env.step(action)
        if self.sv_fig:
            self.store_data(obs, action, rew, info)
        if self.sv_stp == 'timestep':
            self.t += 1
        if info['new_trial']:
            self.num_tr += 1
            self.data['action'].append(action)
            self.data['reward'].append(rew)
            for key, value in info.items():
                self.data.setdefault(key, []).append(value)

            # save data
            if self.sv_stp == 'timestep':
                save = self.t >= self.sv_per
            else:
                save = self.num_tr % self.sv_per == 0
            if save:
                np.savez(self.sv_name + str(self.num_tr) + '.npz',
                         **self.data)
                if self.verbose:
                    print('--------------------')
                    # num_tr counts trials, not steps.
                    print('Number of trials: ', self.num_tr)
                    print('Average reward: ', np.mean(self.data['reward']))
                    print('--------------------')
                self.reset_data()
                if self.sv_fig:
                    # Allow a fresh figure to be collected after each save.
                    self.stp_counter = 0
                if self.sv_stp == 'timestep':
                    self.t = 0
        return obs, rew, done, info

    def reset_data(self):
        """Empty every recorded data list, keeping the keys."""
        for key in self.data:
            self.data[key] = []

    def store_data(self, obs, action, rew, info):
        """Buffer one step for the figure, or draw it once the buffer is full.

        The first ``num_stps_sv_fig`` steps after a data save are buffered.
        On the next call the buffered steps are plotted with ``fig_`` and the
        buffers are emptied. Nothing more is buffered until ``step`` resets
        ``stp_counter`` at the next data save.
        """
        # Strict '<' so exactly num_stps_sv_fig steps are buffered.
        if self.stp_counter < self.num_stps_sv_fig:
            self.ob_mat.append(obs)
            self.act_mat.append(action)
            self.rew_mat.append(rew)
            # -1 is a placeholder when info lacks the entry.
            self.gt_mat.append(info.get('gt', -1))
            self.perf_mat.append(info.get('performance', -1))
            self.stp_counter += 1
        elif len(self.rew_mat) > 0:
            fname = (self.sv_name + 'task_{0:06d}.'.format(self.num_tr) +
                     self.fig_type)
            fig_(ob=np.array(self.ob_mat), actions=np.array(self.act_mat),
                 gt=self.gt_mat, rewards=self.rew_mat,
                 performance=self.perf_mat, fname=fname)
            self._clear_fig_buffers()