# Source code for gym_loop.gym_loop
import logging
import gym
import importlib
import os
import sys
import re
"""Main module."""
def train_agent(run_params):
    """Assemble the run components from the configuration and train the agent.

    Args:
        run_params (dict): Dictionary from parsed configuration file.
    """
    environment = build_env(run_params)
    # build_policy returns None when the configuration has no "policy" entry;
    # that same value is handed on to both the agent and the loop builders.
    chosen_policy = build_policy(run_params, environment)
    built_agent = build_agent(run_params, chosen_policy, environment)
    runner = build_loop(run_params, built_agent, chosen_policy, environment)
    runner.train()
def eval_agent(run_params):
    """Assemble the run components from the configuration and run evaluation.

    The loop's ``evaluate`` method is invoked instead of ``train``.

    Args:
        run_params (dict): Dictionary from parsed configuration file.
    """
    environment = build_env(run_params)
    # build_policy returns None when the configuration has no "policy" entry;
    # that same value is handed on to both the agent and the loop builders.
    chosen_policy = build_policy(run_params, environment)
    built_agent = build_agent(run_params, chosen_policy, environment)
    runner = build_loop(run_params, built_agent, chosen_policy, environment)
    runner.evaluate()
def build_env(params):
    """Create gym env from run parameters.

    Args:
        params (dict): Run specification. ``params["env"]`` must contain
            ``"name"`` (the id passed to ``gym.make``) and may contain
            ``"parameters"`` (keyword arguments for ``gym.make``; an optional
            ``"seed"`` entry is split off and applied via ``env.seed``) and
            ``"imports"`` (module names imported before the env is created).

    Returns:
        The environment object returned by ``gym.make``.
    """
    env_spec = params["env"]
    # Work on a copy: popping "seed" from the caller's dict would make a
    # second build from the same run parameters silently lose the seed.
    env_params = dict(env_spec.get("parameters") or {})
    seed = env_params.pop("seed", None)
    # Import the listed modules before gym.make is called.
    for import_module in env_spec.get("imports") or ():
        importlib.import_module(import_module)
    env = gym.make(env_spec["name"], **env_params)
    if seed is not None:
        env.seed(seed)
    return env
def build_policy(params, env):
    """Create policy from run parameters.

    Args:
        params (dict): Run specification. When it has a ``"policy"`` entry,
            that entry must provide ``"class"`` (a module class string) and
            ``"parameters"`` (constructor keyword arguments).
        env: Environment exposing ``observation_space`` and ``action_space``;
            both are passed to the policy constructor.

    Returns:
        The constructed policy, or ``None`` when ``params`` has no
        ``"policy"`` entry.
    """
    if "policy" not in params:
        # logging.warn is a deprecated alias; logging.warning is the
        # supported spelling.
        logging.warning(
            "No policy found in parameters, using default policy for the agent"
        )
        return None
    policy_class_string = params["policy"]["class"]
    # Copy so the space objects are not written back into the caller's config.
    policy_params = dict(params["policy"]["parameters"])
    policy_params["observation_space"] = env.observation_space
    policy_params["action_space"] = env.action_space
    Policy = module_str_to_class(policy_class_string)
    return Policy(**policy_params)
def build_agent(params, policy, env):
    """Create agent from run parameters.

    Args:
        params (dict): Run specification. ``params["agent"]`` must provide
            ``"class"`` (a module class string) and ``"parameters"``
            (constructor keyword arguments).
        policy: Policy instance to hand to the agent, or ``None`` to build
            the one described by ``Agent.get_default_policy()``.
        env: Environment exposing ``observation_space`` and ``action_space``.

    Returns:
        The constructed agent.
    """
    # Copy so the spaces and the policy object are not written back into the
    # caller's configuration dict.
    agent_params = dict(params["agent"]["parameters"])
    agent_class_string = params["agent"]["class"]
    Agent = module_str_to_class(agent_class_string)
    agent_params["observation_space"] = env.observation_space
    agent_params["action_space"] = env.action_space
    if policy is None:
        default_policy = Agent.get_default_policy()
        Policy = module_str_to_class(default_policy["class"])
        # Copy as well: the dict returned by get_default_policy may be shared
        # between calls, so it must not be modified in place.
        policy_params = dict(default_policy["parameters"])
        policy_params["observation_space"] = env.observation_space
        policy_params["action_space"] = env.action_space
        policy = Policy(**policy_params)
    agent_params["policy"] = policy
    return Agent(**agent_params)
def build_loop(params, agent, policy, env):
    """Create loop from run parameters.

    Args:
        params (dict): Run specification. ``params["loop"]`` must provide
            ``"class"`` (a module class string) and ``"parameters"``
            (constructor keyword arguments).
        agent: Agent passed to the loop as ``agent``.
        policy: Policy passed to the loop as ``policy``.
        env: Environment passed to the loop as ``env``.

    Returns:
        The constructed loop.
    """
    loop_class_string = params["loop"]["class"]
    Loop = module_str_to_class(loop_class_string)
    # Merge into a fresh dict so the live env/agent/policy objects are not
    # injected into the caller's configuration dict.
    loop_params = dict(params["loop"]["parameters"])
    loop_params["env"] = env
    loop_params["agent"] = agent
    loop_params["policy"] = policy
    return Loop(**loop_params)
def get_default_params(agent_str, loop_str):
    """Assemble the default run specification for an agent/loop pair.

    Args:
        agent_str (str): agent module class string of format
            'package.module:class' or 'agent_filepath:class'
        loop_str (str): loop module class string of format
            'package.module:class' or 'loop_filepath:class'

    Returns:
        dict: default run spec dict with "env", "agent" and "loop" sections
    """
    agent_cls = module_str_to_class(agent_str)
    loop_cls = module_str_to_class(loop_str)

    env_section = {"name": "Pendulum-v0", "parameters": {}, "imports": []}
    # NOTE(review): the default policy is stored under "agent", while
    # build_policy looks for a top-level "policy" key -- confirm this is
    # the intended layout.
    agent_section = {
        "class": agent_str,
        "policy": agent_cls.get_default_policy(),
        "parameters": agent_cls.get_default_parameters(),
    }
    loop_section = {
        "class": loop_str,
        "parameters": loop_cls.get_default_parameters(),
    }
    return {"env": env_section, "agent": agent_section, "loop": loop_section}
def module_str_to_class(module_str):
    """Parse module class string to a class.

    Args:
        module_str (str): Module class string of format
            'package.module:class' or 'filepath:class'.

    Returns:
        type: The attribute named ``class`` looked up on the loaded module.

    Raises:
        ValueError: If ``module_str`` is not a valid module class string.
        ImportError: If the file path cannot be turned into an import spec.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly before use.
    import importlib.util

    if not validate_module_str(module_str):
        raise ValueError("Module string is in wrong format")
    # Split on the last colon only, so file paths that themselves contain a
    # colon (e.g. a Windows drive letter) keep their full path.
    module_path, class_name = module_str.rsplit(":", 1)
    if os.path.isfile(module_path):
        # splitext removes only the extension; replace(".py", "") would also
        # strip ".py" occurring in the middle of the file name.
        module_name = os.path.splitext(os.path.basename(module_path))[0]
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        if spec is None or spec.loader is None:
            raise ImportError(
                "Cannot load module from file {!r}".format(module_path)
            )
        module = importlib.util.module_from_spec(spec)
        # Register the module before executing it, as the original code did.
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    else:
        module = importlib.import_module(module_path)
    return getattr(module, class_name)
def validate_module_str(module_str):
    """Check if string is module class string.

    A valid string has the form ``<module>:<class>`` where ``<class>`` is an
    identifier and ``<module>`` is either the path of an existing file or a
    dotted sequence of identifiers.

    Args:
        module_str (str): A string to test.

    Returns:
        bool: Is module class string valid.
    """
    # rpartition never raises: a string without ":" yields an empty separator
    # and is reported as invalid instead of failing on tuple unpacking.
    module_path, separator, class_name = module_str.rpartition(":")
    if not separator:
        return False
    identifier = re.compile(r"^[^\d\W]\w*\Z", re.UNICODE)
    if identifier.match(class_name) is None:
        return False
    if os.path.isfile(module_path):
        return True
    # Every dot-separated component must itself be an identifier; this also
    # rejects empty components such as in "a..b" or an empty module part.
    return all(
        identifier.match(name) is not None for name in module_path.split(".")
    )