# Source code for neurogym.wrappers.pass_action

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from gym import Wrapper
from gym import spaces


class PassAction(Wrapper):
    """Modifies observation by adding the previous action.

    Every observation returned by :meth:`step` is the wrapped environment's
    observation with the action that was just taken appended as one extra
    element. The action is appended as a single raw value; it is NOT
    one-hot encoded. The result is cast to ``float32`` so that it agrees
    with the declared ``observation_space``.

    :meth:`reset` forwards :meth:`step` as ``step_fn`` to the wrapped
    environment's ``reset``.

    Args:
        env: Environment to wrap. Only ``env.observation_space.shape[0]`` is
            read, so its observation space is assumed to be one-dimensional.
            Actions are assumed to be scalars (or single-element arrays);
            the extended observation space reserves exactly one slot for
            the action.
    """

    metadata = {
        'description': 'Modifies observation by adding the previous action.',
        'paper_link': None,
        'paper_name': None,
    }

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        # TODO: the action is appended as one scalar value, not one-hot.
        env_oss = env.observation_space.shape[0]
        # One extra slot at the end of the observation for the action.
        self.observation_space = spaces.Box(
            -np.inf, np.inf, shape=(env_oss + 1,), dtype=np.float32)

    def reset(self, step_fn=None):
        """Reset the wrapped environment, routing its step through this wrapper.

        Args:
            step_fn: Callable forwarded to ``env.reset(step_fn=...)``.
                Defaults to this wrapper's :meth:`step`, so that any step the
                wrapped ``reset`` performs goes through the action-appending
                logic.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        # NOTE(review): this assumes the wrapped env's reset accepts a
        # ``step_fn`` keyword and builds its first observation by calling it
        # (as neurogym's TrialEnv presumably does). If it does not, the
        # observation returned here will lack the appended action element.
        if step_fn is None:
            step_fn = self.step
        return self.env.reset(step_fn=step_fn)

    def step(self, action):
        """Step the wrapped environment and append ``action`` to the observation.

        Args:
            action: Action passed unchanged to ``env.step``. It must flatten
                to a single element to match ``observation_space``.

        Returns:
            ``(obs, reward, done, info)`` from the wrapped environment. ``obs``
            is extended by the flattened action and cast to
            ``observation_space.dtype``.
        """
        obs, reward, done, info = self.env.step(action)
        # ravel() maps a Python scalar, a 0-d array or a shape-(1,) array to
        # a length-1 vector, so all of them concatenate with a 1-D obs.
        action_vec = np.asarray(action).ravel()
        # Concatenating float32 observations with an integer action would
        # promote to float64, which a float32 Box rejects in contains();
        # cast back to the declared dtype.
        obs = np.concatenate((obs, action_vec)).astype(
            self.observation_space.dtype, copy=False)
        return obs, reward, done, info