Source code for gym_loop.agents.base_agent
from typing import Any, Dict

import numpy as np
class BaseAgent:
    """Base interface for agents driven by the gym_loop training loop.

    Subclasses are expected to override ``act``, ``memorize``, ``update``,
    ``metrics`` and ``get_default_parameters``; the base implementations of
    these raise ``NotImplementedError``.

    Configuration is stored directly as instance attributes: the dict returned
    by ``get_default_parameters()`` is applied first, then any keyword
    arguments passed to the constructor override those defaults.
    """

    def __init__(self, **params: Any):
        """Populate instance attributes from defaults and overrides.

        Args:
            **params: Arbitrary configuration values. Each keyword becomes an
                instance attribute, overriding the same-named default from
                ``get_default_parameters()``.

        Raises:
            NotImplementedError: If the (sub)class does not override
                ``get_default_parameters`` -- so ``BaseAgent`` itself cannot
                be instantiated directly.
        """
        super().__init__()
        # Defaults first, then user overrides, so explicit kwargs win.
        self.__dict__.update(self.get_default_parameters())
        self.__dict__.update(params)

    def act(self, state: np.ndarray, episode_num: int):
        """Return the agent's action for ``state``.

        Args:
            state: Current observation.
            episode_num: Index of the current episode (presumably supplied by
                the training loop -- confirm against caller).

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def memorize(
        self,
        last_ob: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        ob: np.ndarray,
    ):
        """Record one environment transition.

        Called after the environment steps on ``action``. The arguments form a
        transition ``(last_ob, action, reward, done, ob)``: the observation
        before the step, the action taken, the reward received, the episode
        termination flag, and the observation after the step.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def update(self, episode_num: int):
        """Perform a learning step; called immediately after ``memorize``.

        Args:
            episode_num: Index of the current episode.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    def metrics(self, episode_num: int) -> Dict[str, Any]:
        """Return a dict of metrics to log in TensorBoard.

        Args:
            episode_num: Index of the current episode.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    @staticmethod
    def get_default_parameters() -> Dict:
        """Specify the tweakable parameters for the agent.

        ``__init__`` applies this dict to the instance before any keyword
        overrides, so every subclass must override this method.

        Returns:
            dict: Default parameters for the agent.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()

    @staticmethod
    def get_default_policy() -> Dict:
        """Specify the default policy to use with the agent.

        Returns:
            dict: ``{"class": "<module path>:<ClassName>", "parameters": {...}}``
            -- a class string and its constructor parameters. A fresh dict is
            built on every call, so callers may mutate the result.
        """
        return {
            "class": "gym_loop.policies.base_policy:BasePolicy",
            "parameters": {},
        }