src.decision_transformer package#
Submodules#
src.decision_transformer.calibration module#
- src.decision_transformer.calibration.calibration_statistics(dt, env_id, env_func, initial_rtg_range=array([-1., -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.]), trajectories=100, num_envs=8)#
- src.decision_transformer.calibration.plot_calibration_statistics(statistics, show_spread=False, CI=0.95)#
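The two calibration functions above are intended to be used together: compute statistics across a sweep of initial RTG targets, then plot them. A minimal, hedged usage sketch follows; `dt`, `make_env`, and the env id are placeholders, not values documented here.

```python
import numpy as np
from src.decision_transformer.calibration import (
    calibration_statistics,
    plot_calibration_statistics,
)

# `dt` is assumed to be a trained decision transformer and `make_env` an environment
# factory compatible with env_func; the env_id string is a placeholder.
statistics = calibration_statistics(
    dt,
    env_id="MiniGrid-Empty-8x8-v0",
    env_func=make_env,
    initial_rtg_range=np.linspace(-1, 1, 21),  # the default sweep shown above
    trajectories=100,
    num_envs=8,
)
plot_calibration_statistics(statistics, show_spread=True, CI=0.95)
```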
src.decision_transformer.model module#
src.decision_transformer.offline_dataset module#
- class src.decision_transformer.offline_dataset.TrajectoryDataset(trajectory_path, max_len=1, prob_go_from_end=0, pct_traj=1.0, rtg_scale=1, normalize_state=False, preprocess_observations: Optional[Callable] = None, device='cpu')#
Bases: Dataset
- add_padding(tokens, padding_token, padding_required)#
- discount_cumsum(x, gamma)#
- get_batch(batch_size=256, max_len=100, prob_go_from_end=None)#
- get_indices_of_top_p_trajectories(pct_traj)#
- get_sampling_probabilities()#
- get_state_mean_std()#
- get_traj(traj_index, max_len=100, prob_go_from_end=None)#
- load_trajectories() → None#
- return_tensors(s, a, r, rtg, d, timesteps, mask)#
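The discount_cumsum(x, gamma) method listed above is the standard discounted cumulative sum used to compute returns-to-go from a reward sequence. A minimal sketch of that computation (not necessarily the exact implementation in this module) is:

```python
import numpy as np

def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    """Discounted cumulative sum: out[t] = x[t] + gamma * out[t + 1].

    For rewards x and gamma = 1 this yields the return-to-go at each timestep.
    """
    out = np.zeros_like(x)
    out[-1] = x[-1]
    for t in reversed(range(len(x) - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

# Example: rewards [1, 0, 2] with gamma = 1 give returns-to-go [3, 2, 2].
assert discount_cumsum(np.array([1.0, 0.0, 2.0]), gamma=1.0).tolist() == [3.0, 2.0, 2.0]
```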
- class src.decision_transformer.offline_dataset.TrajectoryReader(path)#
Bases: object
The trajectory reader is responsible for reading trajectories from a file.
- read()#
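A minimal usage sketch, assuming `path` points at a trajectory file produced by this project (the file name below is a placeholder):

```python
from src.decision_transformer.offline_dataset import TrajectoryReader

# Placeholder path; use whatever trajectory file your run produced.
reader = TrajectoryReader("trajectories/my_run.pkl")
trajectories = reader.read()  # read the stored trajectories
```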
- class src.decision_transformer.offline_dataset.TrajectoryVisualizer(trajectory_dataset: TrajectoryDataset)#
Bases: object
- plot_base_action_frequencies()#
- plot_reward_over_time()#
- src.decision_transformer.offline_dataset.one_hot_encode_observation(img: Tensor) → Tensor#
Converts a batch of observations into one-hot encoded tensors.
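A hedged sketch of what one-hot encoding an integer-valued observation image can look like; the function name, vocabulary size, and input shape below are illustrative assumptions, not the exact scheme used by this module.

```python
import torch
import torch.nn.functional as F

def one_hot_encode_observation_sketch(img: torch.Tensor, num_classes: int = 20) -> torch.Tensor:
    """Turn integer cell values of shape (batch, height, width) into
    one-hot vectors of shape (batch, height, width, num_classes)."""
    return F.one_hot(img.long(), num_classes=num_classes).float()

# Example: a 2x2 "image" whose cells take values in {0, ..., 19}.
obs = torch.tensor([[[0, 3], [7, 1]]])
encoded = one_hot_encode_observation_sketch(obs)
print(encoded.shape)  # torch.Size([1, 2, 2, 20])
```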
src.decision_transformer.runner module#
- src.decision_transformer.runner.run_decision_transformer(run_config: RunConfig, transformer_config: TransformerModelConfig, offline_config: OfflineTrainConfig, make_env: Callable)#
- src.decision_transformer.runner.set_device(run_config)#
- src.decision_transformer.runner.store_transformer_model(path, model, offline_config)#
src.decision_transformer.train module#
- src.decision_transformer.train.get_dataloaders(trajectory_data_set, offline_config)#
- src.decision_transformer.train.test(model: TrajectoryTransformer, dataloader: DataLoader, env, epochs=10, track=False, batch_number=0)#
- src.decision_transformer.train.train(model: TrajectoryTransformer, trajectory_data_set: TrajectoryDataset, env, make_env, offline_config: OfflineTrainConfig, device='cpu')#
src.decision_transformer.trainer module#
src.decision_transformer.utils module#
- src.decision_transformer.utils.configure_optimizers(model, offline_config)#
This long function is unfortunately doing something very simple and is being very defensive: we separate all parameters of the model into two buckets, those that will experience weight decay for regularization and those that won't (biases, and layernorm/embedding weights), and then return the PyTorch optimizer object.
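A minimal sketch of that bucketing; the module-type checks and the optimizer choice here are illustrative assumptions, not necessarily the exact rules used by configure_optimizers.

```python
import torch

def make_optim_groups_sketch(model: torch.nn.Module, weight_decay: float):
    decay, no_decay, seen = [], [], set()
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad or id(param) in seen:
                continue  # skip frozen and already-bucketed (shared) parameters
            seen.add(id(param))
            if name.endswith("bias") or isinstance(
                module, (torch.nn.LayerNorm, torch.nn.Embedding)
            ):
                no_decay.append(param)  # biases and layernorm/embedding weights
            elif isinstance(module, torch.nn.Linear):
                decay.append(param)  # linear weights are regularized
            else:
                no_decay.append(param)  # anything else: default to no decay
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# The groups are then handed to a PyTorch optimizer, e.g.:
# optimizer = torch.optim.AdamW(make_optim_groups_sketch(model, 0.01), lr=3e-4)
```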
- src.decision_transformer.utils.get_max_len_from_model_type(model_type: str, n_ctx: int)#
The max len is measured in timesteps. Decision transformers use 3 tokens per timestep and clone transformers use 2, so this function maps the token context length n_ctx to a number of timesteps: we start with one timestep for the most recent state/action and then add another timestep for every 3 tokens (decision transformer) or every 2 tokens (clone transformer).
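Read literally, the rule above amounts to the following mapping (a sketch of the stated rule, not necessarily the exact implementation; the model_type strings are assumed for illustration):

```python
def max_len_from_n_ctx_sketch(model_type: str, n_ctx: int) -> int:
    # Decision transformers spend 3 tokens per timestep (return-to-go, state, action);
    # clone transformers spend 2 (state, action).
    tokens_per_timestep = 3 if model_type == "decision_transformer" else 2
    # One timestep for the most recent state/action, plus one per full group of tokens.
    return 1 + n_ctx // tokens_per_timestep

# Example: with n_ctx = 26, a decision transformer covers 1 + 26 // 3 = 9 timesteps,
# while a clone transformer covers 1 + 26 // 2 = 14.
```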
- src.decision_transformer.utils.get_optim_groups(model, offline_config)#
- src.decision_transformer.utils.get_optimizer(optimizer_name: str, optim_groups: list[dict[str, Any]], lr: float, **kwargs)#
- src.decision_transformer.utils.get_scheduler(scheduler_name: Optional[str], optimizer: Optimizer, **kwargs)#
Loosely based on the Hugging Face optimizer schedules (it seemed simpler to write this than to import transformers): https://huggingface.co/docs/transformers/main_classes/optimizer_schedules
- Parameters:
scheduler_name (Optional[str]) – Name of the scheduler to use. If None, returns a constant scheduler
optimizer (optim.Optimizer) – Optimizer to use
**kwargs – Additional arguments to pass to the scheduler, including warm_up_steps, training_steps, num_cycles, and lr_end.
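A hedged usage sketch tying get_optim_groups, get_optimizer, and get_scheduler together; `model` and `offline_config` are assumed to come from the surrounding training setup, and the optimizer/scheduler names and step counts are placeholders rather than a documented list.

```python
from src.decision_transformer.utils import get_optim_groups, get_optimizer, get_scheduler

# Names and hyperparameters below are assumptions for illustration only.
optim_groups = get_optim_groups(model, offline_config)
optimizer = get_optimizer("adamw", optim_groups, lr=3e-4)
scheduler = get_scheduler(
    "cosine",                 # passing None instead returns a constant scheduler
    optimizer,
    warm_up_steps=1_000,
    training_steps=50_000,
)

for _ in range(50_000):
    ...  # forward/backward, optimizer.step()
    scheduler.step()
```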
- src.decision_transformer.utils.initialize_padding_inputs(max_len: int, initial_obs: dict, initial_rtg: float, action_pad_token: int, batch_size=1, device='cpu')#
Initializes input tensors for a decision transformer based on the given maximum length of the sequence, initial observation, initial return-to-go (rtg) value, and padding token for actions.
The padding value for rtg is the initial RTG, repeated at every position; this is important. The padding value for the initial observation is 0 (it could be -1, and we may parameterize this in the future). The mask is initialized to 0.0 and then set to 1.0 at every position that is not padding (currently a single position).
Args:
- max_len (int): maximum length of the sequence
- initial_obs (Dict[str, Union[torch.Tensor, np.ndarray]]): initial observation dictionary, containing an "image" tensor with shape (batch_size, channels, height, width)
- initial_rtg (float): initial return-to-go value used to initialize the reward-to-go tensor
- action_pad_token (int): padding token used to initialize the actions tensor
- batch_size (int): batch size of the sequences (default: 1)
Returns:
- obs (torch.Tensor): tensor of shape (batch_size, max_len, channels, height, width), initialized with zeros and the initial observation in the last (most recent) position
- actions (torch.Tensor): tensor of shape (batch_size, max_len - 1, 1), initialized with the padding token
- reward (torch.Tensor): tensor of shape (batch_size, max_len, 1), initialized with zeros
- rtg (torch.Tensor): tensor of shape (1, max_len, 1), initialized with the initial rtg value and broadcast over the batch dimension
- timesteps (torch.Tensor): tensor of shape (batch_size, max_len, 1), initialized with zeros
- mask (torch.Tensor): tensor of shape (batch_size, max_len), initialized with zeros and a one at the last position to mark the end of the sequence
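The Returns description above pins down the initialization closely; a minimal standalone sketch of those tensors follows, with shapes and fill values taken from the description. Device handling is omitted and the dtype choices for actions and timesteps are assumptions.

```python
import torch

def initialize_padding_inputs_sketch(
    max_len: int,
    initial_obs_image: torch.Tensor,   # (batch_size, channels, height, width)
    initial_rtg: float,
    action_pad_token: int,
    batch_size: int = 1,
):
    c, h, w = initial_obs_image.shape[1:]
    obs = torch.zeros(batch_size, max_len, c, h, w)
    obs[:, -1] = initial_obs_image            # most recent slot holds the real observation
    actions = torch.full((batch_size, max_len - 1, 1), action_pad_token, dtype=torch.long)
    reward = torch.zeros(batch_size, max_len, 1)
    rtg = torch.full((1, max_len, 1), initial_rtg)  # broadcast over the batch later
    timesteps = torch.zeros(batch_size, max_len, 1, dtype=torch.long)
    mask = torch.zeros(batch_size, max_len)
    mask[:, -1] = 1.0                          # only the real (non-padding) position
    return obs, actions, reward, rtg, timesteps, mask
```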
- src.decision_transformer.utils.load_decision_transformer(model_path, env=None, tlens_weight_processing=False) → DecisionTransformer#
- src.decision_transformer.utils.parse_args()#