Source code for gym_loop.gym_loop

import logging
import gym
import importlib
import os
import sys
import re

"""Main module."""


def train_agent(run_params):
    """Build every component from the run parameters and start training.

    Args:
        run_params (dict): Dictionary from parsed configuration file.
    """
    environment = build_env(run_params)
    policy_obj = build_policy(run_params, environment)
    learner = build_agent(run_params, policy_obj, environment)
    # The loop object owns the actual training procedure.
    build_loop(run_params, learner, policy_obj, environment).train()
def eval_agent(run_params):
    """Build every component from the run parameters and run evaluation.

    The loop's ``evaluate`` method is invoked instead of ``train``, i.e. the
    loop is run without updating the agent.

    Args:
        run_params (dict): Dictionary from parsed configuration file.
    """
    environment = build_env(run_params)
    policy_obj = build_policy(run_params, environment)
    learner = build_agent(run_params, policy_obj, environment)
    build_loop(run_params, learner, policy_obj, environment).evaluate()
def build_env(params):
    """Create a gym environment from run parameters.

    Args:
        params (dict): Run parameters. Reads ``params["env"]["name"]`` (the
            id passed to ``gym.make``), ``params["env"]["parameters"]``
            (keyword arguments for ``gym.make``; an optional ``"seed"`` entry
            is extracted and applied through ``env.seed``) and
            ``params["env"]["imports"]`` (module names imported before the
            environment is created).

    Returns:
        The environment object returned by ``gym.make``.
    """
    env_section = params["env"]
    # Work on a copy: popping "seed" from the caller's dict would strip it
    # from the configuration, so a second build from the same run
    # parameters (e.g. train followed by eval) would silently be unseeded.
    make_kwargs = dict(env_section["parameters"])
    import_names = env_section["imports"]
    env_id = env_section["name"]
    seed = make_kwargs.pop("seed", None)

    # Import the listed modules first (presumably so that custom
    # environments register themselves with gym before gym.make is called).
    for import_name in import_names:
        importlib.import_module(import_name)

    env = gym.make(env_id, **make_kwargs)
    if seed is not None:
        env.seed(seed)
    return env
def build_policy(params, env):
    """Create the policy described in the run parameters.

    Args:
        params (dict): Run parameters. If present, ``params["policy"]`` must
            provide ``"class"`` (a module class string) and ``"parameters"``
            (keyword arguments for the policy constructor).
        env: Environment exposing ``observation_space`` and ``action_space``;
            both are injected into the policy parameters.

    Returns:
        The constructed policy, or ``None`` when ``params`` has no
        ``"policy"`` section.
    """
    if "policy" not in params:
        # logging.warn is a deprecated alias that was removed in Python 3.13;
        # logging.warning is the supported spelling.
        logging.warning(
            "No policy found in parameters, using default policy for the agent"
        )
        return None

    policy_class_string = params["policy"]["class"]
    policy_params = params["policy"]["parameters"]
    policy_params["observation_space"] = env.observation_space
    policy_params["action_space"] = env.action_space
    Policy = module_str_to_class(policy_class_string)
    return Policy(**policy_params)
def build_agent(params, policy, env):
    """Instantiate the agent described in the run parameters.

    The environment's observation/action spaces and the policy are written
    into ``params["agent"]["parameters"]`` before the agent class is called
    with those parameters as keyword arguments.

    Args:
        params (dict): Run parameters; reads ``params["agent"]["class"]`` and
            ``params["agent"]["parameters"]``.
        policy: Policy instance, or ``None`` to build the one described by
            ``Agent.get_default_policy()``.
        env: Environment exposing ``observation_space`` and ``action_space``.

    Returns:
        The constructed agent.
    """
    agent_section = params["agent"]
    agent_kwargs = agent_section["parameters"]
    Agent = module_str_to_class(agent_section["class"])

    obs_space = env.observation_space
    act_space = env.action_space
    agent_kwargs.update(observation_space=obs_space, action_space=act_space)

    if policy is None:
        # Fall back to the policy spec advertised by the agent class itself.
        default_spec = Agent.get_default_policy()
        default_class_string = default_spec["class"]
        default_kwargs = default_spec["parameters"]
        Policy = module_str_to_class(default_class_string)
        default_kwargs.update(observation_space=obs_space, action_space=act_space)
        policy = Policy(**default_kwargs)

    agent_kwargs["policy"] = policy
    return Agent(**agent_kwargs)
def build_loop(params, agent, policy, env):
    """Instantiate the loop described in the run parameters.

    The environment, agent and policy are written into
    ``params["loop"]["parameters"]`` and the loop class is called with those
    parameters as keyword arguments.

    Args:
        params (dict): Run parameters; reads ``params["loop"]["class"]`` and
            ``params["loop"]["parameters"]``.
        agent: Agent instance handed to the loop.
        policy: Policy instance handed to the loop (may be ``None``).
        env: Environment instance handed to the loop.

    Returns:
        The constructed loop.
    """
    loop_section = params["loop"]
    loop_kwargs = loop_section["parameters"]
    Loop = module_str_to_class(loop_section["class"])
    loop_kwargs.update(env=env, agent=agent, policy=policy)
    return Loop(**loop_kwargs)
def get_default_params(agent_str, loop_str):
    """Assemble a default run spec for the given agent and loop classes.

    Args:
        agent_str (str): agent module class string of format
            'package.module:class' or 'agent_filepath:class'
        loop_str (str): loop module class string of format
            'package.module:class' or 'loop_filepath:class'

    Returns:
        dict: default run spec with "env", "agent" and "loop" sections
    """
    Agent = module_str_to_class(agent_str)
    Loop = module_str_to_class(loop_str)

    env_section = {"name": "Pendulum-v0", "parameters": {}, "imports": []}
    agent_section = {
        "class": agent_str,
        "policy": Agent.get_default_policy(),
        "parameters": Agent.get_default_parameters(),
    }
    loop_section = {
        "class": loop_str,
        "parameters": Loop.get_default_parameters(),
    }
    return {"env": env_section, "agent": agent_section, "loop": loop_section}
# Pattern for a single identifier: a letter or underscore followed by word
# characters.
_IDENTIFIER_RE = re.compile(r"^[^\d\W]\w*\Z", re.UNICODE)


def module_str_to_class(module_str):
    """Resolve a module class string to the object it names.

    Args:
        module_str (str): String of format 'package.module:class' or
            'filepath:class'.

    Returns:
        type: The attribute named ``class`` looked up on the loaded module.

    Raises:
        ValueError: If ``module_str`` does not pass ``validate_module_str``.
        ImportError: If the module part is an existing file for which no
            import spec or loader can be created.
    """
    # importlib.util is a submodule; a bare ``import importlib`` does not
    # guarantee that it has been loaded.
    import importlib.util

    if not validate_module_str(module_str):
        raise ValueError("Module string is in wrong format")

    # Split on the LAST colon so file paths that themselves contain ':'
    # (e.g. Windows drive letters) are kept intact.
    module_path, _, class_name = module_str.rpartition(":")

    if os.path.isfile(module_path):
        # Drop only the file extension; str.replace(".py", "") would also
        # remove ".py" occurring in the middle of the file name.
        module_name = os.path.splitext(os.path.basename(module_path))[0]
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        if spec is None or spec.loader is None:
            raise ImportError(
                "Cannot load module from file: {}".format(module_path)
            )
        module = importlib.util.module_from_spec(spec)
        # Register before executing so the module is importable by name
        # while (and after) its body runs.
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    else:
        module = importlib.import_module(module_path)

    return getattr(module, class_name)


def validate_module_str(module_str):
    """Check if string is a module class string.

    A valid string is ``<module>:<class>`` where ``<class>`` is an identifier
    and ``<module>`` is either the path of an existing file or a dotted
    sequence of identifiers.

    Args:
        module_str (str): A string to test.

    Returns:
        bool: Is module class string valid.
    """
    module_path, separator, class_name = module_str.rpartition(":")
    if not separator:
        # No ':' at all, so there is no class part to look up.
        return False
    if not _IDENTIFIER_RE.match(class_name):
        return False
    if os.path.isfile(module_path):
        return True
    return all(_IDENTIFIER_RE.match(part) for part in module_path.split("."))