Source code for neurogym.wrappers.pass_reward

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from gym import Wrapper
from gym import spaces


class PassReward(Wrapper):
    """Wrapper that appends the reward of each step to the observation.

    The wrapped environment must expose an ``observation_space`` whose
    ``shape[0]`` is the observation length; this wrapper advertises a
    ``Box`` space that is one element longer to make room for the reward.
    """

    metadata = {
        'description': 'Modifies observation by adding the previous reward.',
        'paper_link': None,
        'paper_name': None,
    }

    def __init__(self, env):
        """
        Modifies observation by adding the previous reward.

        Args:
            env: environment to wrap. Its ``observation_space.shape[0]`` is
                read to size the extended observation space.
        """
        super().__init__(env)
        env_oss = env.observation_space.shape[0]
        # One extra slot at the end of the observation holds the reward.
        self.observation_space = spaces.Box(
            -np.inf, np.inf, shape=(env_oss + 1,), dtype=np.float32
        )

    def reset(self, step_fn=None):
        """Reset the wrapped environment.

        Args:
            step_fn: callable forwarded to the wrapped environment's
                ``reset``. Defaults to this wrapper's own ``step``.
                NOTE(review): presumably the wrapped env uses it to produce
                the initial observation so that it already includes the
                reward slot -- confirm against the env's ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        if step_fn is None:
            step_fn = self.step
        return self.env.reset(step_fn=step_fn)

    def step(self, action):
        """Step the wrapped environment and append the reward to the observation.

        Args:
            action: action forwarded unchanged to the wrapped environment.

        Returns:
            Tuple ``(obs, reward, done, info)`` where ``obs`` is the wrapped
            environment's observation with ``reward`` appended as its last
            element; ``reward``, ``done`` and ``info`` are passed through.
        """
        obs, reward, done, info = self.env.step(action)
        obs = np.concatenate((obs, np.array([reward])))
        # Concatenating with a Python-float reward promotes the result to
        # float64; cast back so the observation matches the dtype declared
        # in ``self.observation_space``.
        obs = obs.astype(self.observation_space.dtype, copy=False)
        return obs, reward, done, info