"""Custom model class for Cambrian. This class is a subclass of the PPO model fromStable Baselines 3. It overrides the save and load methods to only save the policyweights. It also adds a method to load rollout data from a previous training run. Thepredict method is then overwritten to return the next action in the rollout if therollout data is loaded. This is useful for testing the evolutionary loop withouthaving to train the agent each time."""importpicklefrompathlibimportPathfromtypingimportAny,Dict,Listimporttorchfromstable_baselines3importPPOfromcambrian.utils.loggerimportget_logger
    def save_policy(self, path: Path | str):
        """Overwrite the save method. Instead of saving the entire state, we'll just
        save the policy weights."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        torch.save(self.policy.state_dict(), path / "policy.pt")
    def load_policy(self, path: Path | str):
        """Overwrite the load method. Instead of loading the entire state, we'll just
        load the policy weights.

        There are four cases to consider:

        - A layer in the saved policy is identical in shape to the current policy:
          do nothing for this layer.
        - A layer is present in both the saved policy and the current policy, but
          the shapes are different: delete the layer from the saved policy.
        - A layer is present in the saved policy but not the current policy: delete
          the layer from the saved policy.
        - A layer is present in the current policy but not the saved policy: do
          nothing for this layer. By setting `strict=False` in the call to
          `load_state_dict`, we can ignore this layer.
        """
        policy_path = Path(path) / "policy.pt"
        if not policy_path.exists():
            raise FileNotFoundError(f"Could not find policy.pt file at {policy_path}.")

        # Loop through the loaded state_dict and remove any layers that don't match
        # the current policy in shape
        saved_state_dict = torch.load(policy_path)
        policy_state_dict = self.policy.state_dict()
        for saved_state_dict_key in list(saved_state_dict.keys()):
            if saved_state_dict_key not in policy_state_dict:
                get_logger().warning(
                    f"Key '{saved_state_dict_key}' not found in policy "
                    "state_dict. Deleting from saved state dict."
                )
                del saved_state_dict[saved_state_dict_key]
                continue

            saved_state_dict_var = saved_state_dict[saved_state_dict_key]
            policy_state_dict_var = policy_state_dict[saved_state_dict_key]
            if saved_state_dict_var.shape != policy_state_dict_var.shape:
                get_logger().warning(f"Shape mismatch for key '{saved_state_dict_key}'")
                del saved_state_dict[saved_state_dict_key]

        self.policy.load_state_dict(saved_state_dict, strict=False)
    def load_rollout(self, path: Path | str):
        """Load the rollout data from a previous training run. The rollout is a list
        of actions indexed by the current step. The model.predict call will then be
        overwritten to return the next action. This loader is "dumb" in the sense
        that it doesn't actually process the observations when it's using a rollout;
        it simply keeps track of the current step and returns the next action in the
        rollout."""
        with open(path, "rb") as f:
            self._rollout = pickle.load(f)["actions"]
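
    # A minimal sketch of the predict override described in the module docstring:
    # when rollout data has been loaded, ignore the observation and replay the next
    # recorded action. The actual implementation is not shown in this excerpt, and
    # the `_rollout_step` counter name is an assumption of this sketch.
    def predict(self, observation, *args, **kwargs):
        if getattr(self, "_rollout", None) is not None:
            step = getattr(self, "_rollout_step", 0)
            self._rollout_step = step + 1
            # Mirror the (action, state) tuple returned by PPO.predict
            return self._rollout[step], None
        return super().predict(observation, *args, **kwargs)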
    @classmethod
    def load_weights(cls, weights: Dict[str, List[float]], **kwargs):
        """Load the weights for the policy. This is useful for testing the
        evolutionary loop without having to train the agent each time."""
        model = cls(**kwargs)

        # Iteratively load the weights into the model
        state_dict = model.policy.state_dict()
        for name, weight in weights.items():
            name = name.replace("__", ".")
            weight = torch.tensor(weight)
            assert name in state_dict, f"Layer {name} not found in model"
            assert state_dict[name].shape == weight.shape, (
                f"Shape mismatch for layer {name}: {state_dict[name].shape} != "
                f"{weight.shape}"
            )
            state_dict[name] = weight

        # Push the updated weights back into the policy. Reassigning entries in the
        # dict returned by state_dict() does not modify the module itself, so
        # load_state_dict is required for the new weights to take effect.
        model.policy.load_state_dict(state_dict)

        return model
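

# Example usage, guarded so it only runs when this file is executed directly. This is
# an illustrative sketch: the "CartPole-v1" environment, the "MlpPolicy" policy, and
# the "action_net__bias" layer name are assumptions, not part of this module.
if __name__ == "__main__":
    model = MjCambrianModel("MlpPolicy", "CartPole-v1")

    # Round-trip the policy weights: layers whose shapes no longer match are simply
    # dropped on load (with a warning) rather than raising an error.
    model.save_policy("example_ckpt")
    model.load_policy("example_ckpt")

    # load_weights expects "__" in place of "." in layer names; the key below maps
    # to "action_net.bias" in the policy state_dict.
    bias = model.policy.state_dict()["action_net.bias"]
    model = MjCambrianModel.load_weights(
        {"action_net__bias": bias.tolist()},
        policy="MlpPolicy",
        env="CartPole-v1",
    )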