src package#
Subpackages#
- src.decision_transformer package
  - Submodules
    - src.decision_transformer.calibration module
    - src.decision_transformer.model module
    - src.decision_transformer.offline_dataset module
      - TrajectoryDataset
        - TrajectoryDataset.add_padding()
        - TrajectoryDataset.discount_cumsum()
        - TrajectoryDataset.get_batch()
        - TrajectoryDataset.get_indices_of_top_p_trajectories()
        - TrajectoryDataset.get_sampling_probabilities()
        - TrajectoryDataset.get_state_mean_std()
        - TrajectoryDataset.get_traj()
        - TrajectoryDataset.load_trajectories()
        - TrajectoryDataset.return_tensors()
      - TrajectoryReader
      - TrajectoryVisualizer
      - one_hot_encode_observation()
    - src.decision_transformer.runner module
    - src.decision_transformer.train module
    - src.decision_transformer.trainer module
    - src.decision_transformer.utils module
  - Module contents
- src.environments package
- src.ppo package
  - Submodules
    - src.ppo.agent module
    - src.ppo.compute_adv_vectorized module
    - src.ppo.memory module
    - src.ppo.my_probe_envs module
    - src.ppo.runner module
    - src.ppo.train module
    - src.ppo.utils module
  - Module contents
Submodules#
src.config module#
This module contains the configuration classes for the project.
- class src.config.ConfigJsonEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)#
Bases: JSONEncoder
- default(config: dataclass)#
Implement this method in a subclass such that it returns a serializable object for o, or calls the base implementation (to raise a TypeError). For example, to support arbitrary iterators, you could implement default like this:

    def default(self, o):
        try:
            iterable = iter(o)
        except TypeError:
            pass
        else:
            return list(iterable)
        # Let the base class default method raise the TypeError
        return JSONEncoder.default(self, o)
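For illustration, a minimal serialization sketch, assuming the encoder's default() hook returns a serializable view of the config dataclasses defined in this module (the field override below is just an example):

    import json

    from src.config import ConfigJsonEncoder, EnvironmentConfig

    # Sketch: serialize a config dataclass to JSON, letting
    # ConfigJsonEncoder.default() handle the dataclass object.
    env_config = EnvironmentConfig(max_steps=300)
    print(json.dumps(env_config, cls=ConfigJsonEncoder, indent=2))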
- class src.config.EnvironmentConfig(env_id: str = 'MiniGrid-Dynamic-Obstacles-8x8-v0', one_hot_obs: bool = False, img_obs: bool = False, fully_observed: bool = False, max_steps: int = 1000, seed: int = 1, view_size: int = 7, capture_video: bool = False, video_dir: str = 'videos', video_frequency: int = 50, render_mode: str = 'rgb_array', action_space: None = None, observation_space: None = None, device: str = 'cpu')#
Bases: object
Configuration class for the environment.
- action_space: None = None#
- capture_video: bool = False#
- device: str = 'cpu'#
- env_id: str = 'MiniGrid-Dynamic-Obstacles-8x8-v0'#
- fully_observed: bool = False#
- img_obs: bool = False#
- max_steps: int = 1000#
- observation_space: None = None#
- one_hot_obs: bool = False#
- render_mode: str = 'rgb_array'#
- seed: int = 1#
- video_dir: str = 'videos'#
- video_frequency: int = 50#
- view_size: int = 7#
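A minimal construction sketch; the overridden values below are illustrative, and every other field keeps the default shown in the signature above:

    from src.config import EnvironmentConfig

    # Illustrative overrides; remaining fields keep their defaults.
    env_config = EnvironmentConfig(
        env_id="MiniGrid-Dynamic-Obstacles-8x8-v0",
        one_hot_obs=True,
        max_steps=300,
        capture_video=False,
        device="cpu",
    )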
- class src.config.LSTMModelConfig(environment_config: EnvironmentConfig, image_dim: int = 128, memory_dim: int = 128, instr_dim: int = 128, use_instr: bool = False, lang_model: str = 'gru', use_memory: bool = False, recurrence: int = 4, arch: str = 'bow_endpool_res', aux_info: bool = False, device: str = 'cpu')#
Bases: object
Configuration class for the LSTM model.
- arch: str = 'bow_endpool_res'#
- aux_info: bool = False#
- device: str = 'cpu'#
- environment_config: EnvironmentConfig#
- image_dim: int = 128#
- instr_dim: int = 128#
- lang_model: str = 'gru'#
- memory_dim: int = 128#
- recurrence: int = 4#
- use_instr: bool = False#
- use_memory: bool = False#
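A hedged construction sketch; environment_config is the only required argument, and the other values shown are illustrative overrides of the defaults above:

    from src.config import EnvironmentConfig, LSTMModelConfig

    # Sketch: wrap an environment config; enabling the LSTM memory is an
    # illustrative override of the defaults listed above.
    lstm_config = LSTMModelConfig(
        environment_config=EnvironmentConfig(env_id="MiniGrid-Dynamic-Obstacles-8x8-v0"),
        use_memory=True,
        recurrence=4,
    )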
- class src.config.OfflineTrainConfig(trajectory_path: str, batch_size: int = 128, convert_to_one_hot: bool = False, optimizer: str = 'AdamW', scheduler: str = 'ConstantWithWarmUp', lr: float = 0.0001, lr_end: float = 1e-07, weight_decay: float = 0.0, warm_up_steps: int = 1000, num_cycles: int = 3, pct_traj: float = 1.0, prob_go_from_end: float = 0.0, train_epochs: int = 100, test_epochs: int = 10, test_frequency: int = 10, eval_frequency: int = 10, eval_episodes: int = 10, eval_max_time_steps: int = 100, eval_num_envs: int = 8, initial_rtg: list[float] = (0.0, 1.0), model_type: str = 'decision_transformer', track: bool = False, device: str = 'cpu')#
Bases: object
Configuration class for offline training.
- batch_size: int = 128#
- convert_to_one_hot: bool = False#
- device: str = 'cpu'#
- eval_episodes: int = 10#
- eval_frequency: int = 10#
- eval_max_time_steps: int = 100#
- eval_num_envs: int = 8#
- initial_rtg: list[float] = (0.0, 1.0)#
- lr: float = 0.0001#
- lr_end: float = 1e-07#
- model_type: str = 'decision_transformer'#
- num_cycles: int = 3#
- optimizer: str = 'AdamW'#
- pct_traj: float = 1.0#
- prob_go_from_end: float = 0.0#
- scheduler: str = 'ConstantWithWarmUp'#
- test_epochs: int = 10#
- test_frequency: int = 10#
- track: bool = False#
- train_epochs: int = 100#
- trajectory_path: str#
- warm_up_steps: int = 1000#
- weight_decay: float = 0.0#
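A hedged sketch; trajectory_path is the only required field, and the path below is a placeholder rather than a file shipped with the project:

    from src.config import OfflineTrainConfig

    # Placeholder path: point this at a real stored trajectory file.
    offline_config = OfflineTrainConfig(
        trajectory_path="trajectories/dynamic_obstacles.pkl",
        batch_size=128,
        pct_traj=1.0,
        initial_rtg=[0.0, 1.0],
        model_type="decision_transformer",
    )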
- class src.config.OnlineTrainConfig(use_trajectory_model: bool = False, hidden_size: int = 64, total_timesteps: int = 180000, learning_rate: float = 0.00025, decay_lr: bool = (False,), num_envs: int = 4, num_steps: int = 128, gamma: float = 0.99, gae_lambda: float = 0.95, num_minibatches: int = 4, update_epochs: int = 4, clip_coef: float = 0.4, ent_coef: float = 0.2, vf_coef: float = 0.5, max_grad_norm: float = 2, trajectory_path: Optional[str] = None, fully_observed: bool = False, prob_go_from_end: float = 0.0, num_checkpoints: int = 10, device: str = 'cpu')#
Bases: object
Configuration class for online training.
- clip_coef: float = 0.4#
- decay_lr: bool = (False,)#
- device: str = 'cpu'#
- ent_coef: float = 0.2#
- fully_observed: bool = False#
- gae_lambda: float = 0.95#
- gamma: float = 0.99#
- learning_rate: float = 0.00025#
- max_grad_norm: float = 2#
- num_checkpoints: int = 10#
- num_envs: int = 4#
- num_minibatches: int = 4#
- num_steps: int = 128#
- prob_go_from_end: float = 0.0#
- total_timesteps: int = 180000#
- trajectory_path: Optional[str] = None#
- update_epochs: int = 4#
- use_trajectory_model: bool = False#
- vf_coef: float = 0.5#
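Under the usual PPO convention (assumed here, since it is not spelled out in the signature), the rollout batch is num_envs * num_steps transitions and each update splits it into num_minibatches minibatches:

    from src.config import OnlineTrainConfig

    online_config = OnlineTrainConfig(num_envs=4, num_steps=128, num_minibatches=4)

    # Assumed PPO bookkeeping, not fields of the config itself:
    batch_size = online_config.num_envs * online_config.num_steps    # 4 * 128 = 512
    minibatch_size = batch_size // online_config.num_minibatches     # 512 // 4 = 128
    num_rollouts = online_config.total_timesteps // batch_size       # 180000 // 512 = 351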
- class src.config.RunConfig(exp_name: str = 'MiniGrid-Dynamic-Obstacles-8x8-v0', seed: int = 1, device: str = 'cpu', track: bool = True, wandb_project_name: str = 'PPO-MiniGrid', wandb_entity: Optional[str] = None)#
Bases: object
Configuration class for running the model.
- device: str = 'cpu'#
- exp_name: str = 'MiniGrid-Dynamic-Obstacles-8x8-v0'#
- seed: int = 1#
- track: bool = True#
- wandb_entity: Optional[str] = None#
- wandb_project_name: str = 'PPO-MiniGrid'#
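A sketch of how the tracking fields might feed a Weights & Biases run; the wandb.init call below is illustrative, and the project's runner modules do their own wiring:

    import wandb

    from src.config import RunConfig

    run_config = RunConfig(track=True, wandb_project_name="PPO-MiniGrid")

    # Illustrative only: mirror the config fields into a wandb run.
    if run_config.track:
        wandb.init(
            project=run_config.wandb_project_name,
            entity=run_config.wandb_entity,
            name=run_config.exp_name,
        )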
- class src.config.TransformerModelConfig(d_model: int = 128, n_heads: int = 4, d_mlp: int = 256, n_layers: int = 2, n_ctx: int = 2, layer_norm: Optional[str] = None, gated_mlp: bool = False, activation_fn: str = 'relu', state_embedding_type: str = 'grid', time_embedding_type: str = 'embedding', seed: int = 1, device: str = 'cpu')#
Bases: object
Configuration class for the transformer model.
- activation_fn: str = 'relu'#
- d_mlp: int = 256#
- d_model: int = 128#
- device: str = 'cpu'#
- gated_mlp: bool = False#
- layer_norm: Optional[str] = None#
- n_ctx: int = 2#
- n_heads: int = 4#
- n_layers: int = 2#
- seed: int = 1#
- state_embedding_type: str = 'grid'#
- time_embedding_type: str = 'embedding'#
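A construction sketch with illustrative values; the comment on head size reflects the usual transformer constraint rather than anything documented here:

    from src.config import TransformerModelConfig

    transformer_config = TransformerModelConfig(
        d_model=128,
        n_heads=4,     # d_model is conventionally divisible by n_heads (head size 32)
        d_mlp=256,
        n_layers=2,
        n_ctx=2,
        state_embedding_type="grid",
    )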
- src.config.parse_metadata_to_environment_config(metadata: dict)#
Parses the metadata dictionary from a loaded trajectory into an EnvironmentConfig object.
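A hedged call sketch; the metadata keys below are hypothetical and would need to match whatever the trajectory writer actually stored:

    from src.config import parse_metadata_to_environment_config

    # Hypothetical metadata dict standing in for the one loaded
    # alongside a trajectory file.
    metadata = {
        "env_id": "MiniGrid-Dynamic-Obstacles-8x8-v0",
        "one_hot_obs": False,
        "view_size": 7,
    }
    env_config = parse_metadata_to_environment_config(metadata)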
src.dummy_agent module#
src.run_calibration module#
- src.run_calibration.runner(args)#
src.run_decision_transformer module#
This file is the entry point for running the decision transformer.
src.run_ppo module#
src.utils module#
src.visualization module#
- src.visualization.find_agent(observation)#
- src.visualization.get_cosine_sim_df(tensor, column_labels=None, row_labels=None)#
- src.visualization.get_param_stats(model)#
- src.visualization.get_rendered_obs(env: Env, obs: Tensor)#
- src.visualization.get_rendered_obss(env: Env, obs: Tensor)#
- src.visualization.plot_param_stats(df)#
Use get_param_stats() followed by this function to inspect summary statistics of the model weights.
- src.visualization.render_minigrid_observation(env, observation)#
- src.visualization.render_minigrid_observations(env, observations)#
- src.visualization.tensor_2d_embedding_similarity(tensor, x, y, mode='heatmap')#
- src.visualization.tensor_cosine_similarity_heatmap(tensor, labels=None, index_labels=None)#
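A hedged sketch of the cosine-similarity helpers applied to a stand-in embedding matrix; the tensor and labels below are synthetic, and the return types are not documented above:

    import torch

    from src.visualization import get_cosine_sim_df, tensor_cosine_similarity_heatmap

    # Synthetic stand-in for an embedding weight matrix of shape
    # (n_embeddings, d_model); real usage would pull this from a model.
    embeddings = torch.randn(7, 128)
    labels = [f"action_{i}" for i in range(7)]

    cos_df = get_cosine_sim_df(embeddings, column_labels=labels, row_labels=labels)
    tensor_cosine_similarity_heatmap(embeddings, labels=labels, index_labels=labels)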