Model

class safepo.common.model.Actor(obs_dim: int, act_dim: int, hidden_sizes: list = [64, 64])

Bases: Module

Actor network for policy-based reinforcement learning.

This class represents an actor network that outputs a distribution over actions given observations.

Parameters:
  • obs_dim (int) – Dimensionality of the observation space.

  • act_dim (int) – Dimensionality of the action space.

  • hidden_sizes (list) – Sizes of the hidden layers of the mean MLP (default: [64, 64]).

Variables:
  • mean (nn.Sequential) – MLP network representing the mean of the action distribution.

  • log_std (nn.Parameter) – Learnable parameter representing the log standard deviation of the action distribution.

Example

    import torch
    from safepo.common.model import Actor

    obs_dim = 10
    act_dim = 2
    actor = Actor(obs_dim, act_dim)
    observation = torch.randn(1, obs_dim)
    action_distribution = actor(observation)
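
To make the mean/log_std split described under Variables concrete, here is a minimal sketch of a diagonal Gaussian actor in plain PyTorch. It is not the safepo implementation: the class name SketchActor, the Tanh activations and the zero initialization of log_std are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


class SketchActor(nn.Module):
    """Hypothetical Gaussian actor: MLP mean plus a state-independent log_std."""

    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
        super().__init__()
        layers = []
        in_dim = obs_dim
        for size in hidden_sizes:
            # Assumption: Tanh activations between hidden layers.
            layers += [nn.Linear(in_dim, size), nn.Tanh()]
            in_dim = size
        layers.append(nn.Linear(in_dim, act_dim))
        self.mean = nn.Sequential(*layers)
        # Assumption: log_std starts at 0 (std = 1) and is shared across states.
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs: torch.Tensor) -> Normal:
        mean = self.mean(obs)
        std = torch.exp(self.log_std)  # broadcast over the batch dimension
        return Normal(mean, std)


actor = SketchActor(obs_dim=10, act_dim=2)
dist = actor(torch.randn(1, 10))
action = dist.sample()                    # shape (1, 2)
log_prob = dist.log_prob(action).sum(-1)  # joint log-probability, shape (1,)
```

The sketch uses a tuple default for hidden_sizes, which avoids the shared-mutable-default pitfall that a list default such as [64, 64] can cause in Python.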

forward(obs: Tensor)

Return the distribution over actions for the input observations.

Note

Although the forward pass is defined in this method, call the module instance (for example, actor(obs)) rather than forward() directly: calling the instance runs any registered hooks, while calling forward() silently skips them.
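
The behavior described in the note can be checked with any module. This sketch uses a plain nn.Linear as a stand-in for Actor and registers a forward hook that records each time it fires.

```python
import torch
import torch.nn as nn

calls = []
layer = nn.Linear(4, 2)  # stand-in for any nn.Module subclass, e.g. Actor
# Forward hooks run only when the module instance itself is called.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 4)
layer(x)          # runs forward() and the registered hook
layer.forward(x)  # runs forward() only; the hook is skipped
print(calls)      # ['hook']
```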

training: bool

class safepo.common.model.VCritic(obs_dim, hidden_sizes: list = [64, 64])

Bases: Module

State-value critic network for reinforcement learning.

This class represents a critic network that estimates the value function for input observations.

Parameters:
  • obs_dim (int) – Dimensionality of the observation space.

  • hidden_sizes (list) – Sizes of the hidden layers of the critic MLP (default: [64, 64]).

Variables:
  • critic (nn.Sequential) – MLP network representing the critic function.

Example

    import torch
    from safepo.common.model import VCritic

    obs_dim = 10
    critic = VCritic(obs_dim)
    observation = torch.randn(1, obs_dim)
    value_estimate = critic(observation)
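
A minimal sketch of a state-value critic with the same contract as the example: one scalar estimate per observation. The helper name build_v_critic, the Tanh activations and the final squeeze are assumptions; the actual VCritic may return values with a trailing dimension of size 1.

```python
import torch
import torch.nn as nn


def build_v_critic(obs_dim: int, hidden_sizes=(64, 64)) -> nn.Sequential:
    """Hypothetical MLP mapping each observation to one scalar value V(s)."""
    layers = []
    in_dim = obs_dim
    for size in hidden_sizes:
        layers += [nn.Linear(in_dim, size), nn.Tanh()]  # Tanh is an assumption
        in_dim = size
    layers.append(nn.Linear(in_dim, 1))  # single output: V(s)
    return nn.Sequential(*layers)


critic = build_v_critic(obs_dim=10)
obs = torch.randn(5, 10)
# Squeeze the trailing dim so a batch of 5 observations yields 5 values.
values = torch.squeeze(critic(obs), -1)
print(values.shape)  # torch.Size([5])
```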

forward(obs)

Return the value estimate for the input observations.

Note

Although the forward pass is defined in this method, call the module instance (for example, critic(obs)) rather than forward() directly: calling the instance runs any registered hooks, while calling forward() silently skips them.

training: bool

class safepo.common.model.ActorVCritic(obs_dim, act_dim, hidden_sizes: list = [64, 64])

Bases: Module

Actor-critic policy for reinforcement learning.

This class represents an actor-critic policy that includes an actor network, two critic networks for reward and cost estimation, and provides methods for taking policy steps and estimating values.

Parameters:
  • obs_dim (int) – Dimensionality of the observation space.

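  • act_dim (int) – Dimensionality of the action space.

  • hidden_sizes (list) – Sizes of the hidden layers of the actor and critic networks (default: [64, 64]).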

Example

    import torch
    from safepo.common.model import ActorVCritic

    obs_dim = 10
    act_dim = 2
    actor_critic = ActorVCritic(obs_dim, act_dim)
    observation = torch.randn(1, obs_dim)
    action, log_prob, reward_value, cost_value = actor_critic.step(observation)
    value_estimate = actor_critic.get_value(observation)

get_value(obs)

Estimate the value of observations using the critic network.

Parameters:
  • obs (torch.Tensor) – Input observation tensor.

Returns:

torch.Tensor – Estimated value for the input observation.

step(obs, deterministic=False)

Take a policy step based on observations.

Parameters:
  • obs (torch.Tensor) – Input observation tensor.

  • deterministic (bool) – Flag indicating whether to take a deterministic action.

Returns:

tuple – Tuple containing action tensor, log probabilities of the action, reward value estimate, and cost value estimate.
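
The sketch below illustrates how step and get_value could fit together for an actor with separate reward and cost critics. Everything in it is hypothetical: the class name, the Gaussian policy, the @torch.no_grad() decorator on step (on the assumption that step is used for rollouts), and the choice to have get_value query only the reward critic are assumptions, not the safepo code.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


def mlp(in_dim, out_dim, hidden=(64, 64)):
    """Hypothetical helper: Tanh MLP with a linear output layer."""
    layers = []
    for size in hidden:
        layers += [nn.Linear(in_dim, size), nn.Tanh()]
        in_dim = size
    return nn.Sequential(*layers, nn.Linear(in_dim, out_dim))


class SketchActorVCritic(nn.Module):
    """Hypothetical actor with separate reward and cost value heads."""

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.actor_mean = mlp(obs_dim, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))
        self.reward_critic = mlp(obs_dim, 1)
        self.cost_critic = mlp(obs_dim, 1)

    def get_value(self, obs):
        # Assumption: get_value queries the reward critic only.
        return self.reward_critic(obs).squeeze(-1)

    @torch.no_grad()  # assumption: step is only used to collect rollouts
    def step(self, obs, deterministic=False):
        dist = Normal(self.actor_mean(obs), self.log_std.exp())
        # Deterministic mode returns the distribution mean instead of a sample.
        action = dist.mean if deterministic else dist.sample()
        log_prob = dist.log_prob(action).sum(-1)
        value_r = self.reward_critic(obs).squeeze(-1)
        value_c = self.cost_critic(obs).squeeze(-1)
        return action, log_prob, value_r, value_c


ac = SketchActorVCritic(obs_dim=10, act_dim=2)
obs = torch.randn(3, 10)
action, log_prob, value_r, value_c = ac.step(obs, deterministic=True)
```

Keeping the reward and cost critics as separate networks lets each be trained against its own return signal, which is the split the class description above refers to.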

training: bool

class safepo.common.model.MultiAgentActor(config, obs_space, action_space, device=torch.device('cpu'))

Bases: Module

Multi-agent actor network for reinforcement learning.

This class represents a multi-agent actor network that takes observations as input and produces actions and action probabilities as outputs. It includes options for using recurrent layers and policy active masks.

Parameters:
  • config (dict) – Configuration parameters for the actor network.

  • obs_space – Observation space of the environment.

  • action_space – Action space of the environment.

  • device (torch.device) – Device to run the network on (default: torch.device('cpu')).

Variables:
  • hidden_size (int) – Size of the hidden layers.

  • config (dict) – Configuration parameters for the actor network.

  • _gain (float) – Gain used when initializing the weights of the action output layer.

  • _use_orthogonal (bool) – Flag indicating whether to use orthogonal initialization.

  • _use_policy_active_masks (bool) – Flag indicating whether to use policy active masks.

  • _use_naive_recurrent_policy (bool) – Flag indicating whether to use naive recurrent policy.

  • _use_recurrent_policy (bool) – Flag indicating whether to use recurrent policy.

  • _recurrent_N (int) – Number of recurrent layers.

  • tpdv (dict) – Dictionary with data type and device for tensor conversion.
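
A dictionary like tpdv is typically used to move incoming data to the network's dtype and device in a single call. A minimal sketch, assuming tpdv holds exactly a float32 dtype and the configured device:

```python
import numpy as np
import torch

device = torch.device("cpu")
# Assumption: tpdv bundles dtype and device so inputs can be converted in one call.
tpdv = dict(dtype=torch.float32, device=device)

obs_np = np.zeros((2, 4), dtype=np.float64)  # e.g. a NumPy batch from the environment
obs = torch.as_tensor(obs_np).to(**tpdv)     # float32 tensor on `device`
print(obs.dtype, obs.device)                 # torch.float32 cpu
```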

Example

    config = {"hidden_size": 256, "gain": 0.1, …}
    obs_space = gym.spaces.Box(low=0, high=1, shape=(4,))
    action_space = gym.spaces.Discrete(2)
    actor = MultiAgentActor(config, obs_space, action_space)
    observation = torch.randn(1, 4)
    rnn_states = torch.zeros(1, 256)
    masks = torch.ones(1, 1)
    actions, action_log_probs, new_rnn_states = actor(observation, rnn_states, masks)
    action = torch.tensor([0])
    action_log_probs, dist_entropy = actor.evaluate_actions(observation, rnn_states, action, masks)

evaluate_actions(obs, rnn_states, action, masks, available_actions=None, active_masks=None)

Compute the log probabilities and entropy of the given actions under the current policy.

Parameters:
  • obs (torch.Tensor) – Input observation tensor.

  • rnn_states (torch.Tensor) – Recurrent states tensor.

  • action (torch.Tensor) – Action tensor.

  • masks (torch.Tensor) – Mask tensor.

  • available_actions (torch.Tensor, optional) – Available actions tensor (default: None).

  • active_masks (torch.Tensor, optional) – Active masks tensor (default: None).

Returns:

tuple – Tuple containing action log probabilities tensor, distribution entropy tensor, action mean tensor, action standard deviation tensor, and other optional tensors.
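
When active_masks is given, terms for inactive agents (for example, agents that have died) are usually excluded from averaged quantities. The sketch below shows that masked averaging for the entropy of a Categorical policy over a Discrete(2) action space, matching the example above. The 1 = active, 0 = inactive convention and the tensor shapes are assumptions.

```python
import torch
from torch.distributions import Categorical

# Hypothetical batch: 4 agent-steps, Discrete(2) actions.
logits = torch.tensor([[0.0, 0.0], [1.0, -1.0], [0.5, 0.5], [2.0, 0.0]])
actions = torch.tensor([0, 1, 0, 0])
# Assumption: active_masks is 1 for agents still acting, 0 for inactive ones.
active_masks = torch.tensor([[1.0], [1.0], [0.0], [1.0]])

dist = Categorical(logits=logits)
action_log_probs = dist.log_prob(actions).unsqueeze(-1)  # shape (4, 1)
entropy = dist.entropy().unsqueeze(-1)                   # shape (4, 1)

# Average entropy over active entries only, so inactive agents do not count.
dist_entropy = (entropy * active_masks).sum() / active_masks.sum()
```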

forward(obs, rnn_states, masks, available_actions=None, deterministic=False)

Perform a forward pass through the network to generate actions and log probabilities.

Parameters:
  • obs (torch.Tensor) – Input observation tensor.

  • rnn_states (torch.Tensor) – Recurrent states tensor.

  • masks (torch.Tensor) – Mask tensor.

  • available_actions (torch.Tensor, optional) – Available actions tensor (default: None).

  • deterministic (bool, optional) – Flag indicating whether to take deterministic actions (default: False).

Returns:

tuple – Tuple containing action tensor, log probability tensor, and new recurrent states tensor.
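
A common way to honor available_actions for discrete actions is to mask the logits of unavailable actions before building the distribution, and to pick the most probable action when deterministic=True. This is a sketch under those assumptions (1 marks an available action, and the policy is Categorical), not the actual MultiAgentActor code.

```python
import torch
from torch.distributions import Categorical

logits = torch.tensor([[2.0, 1.0, 0.5]])
# Assumption: available_actions uses 1 for allowed and 0 for forbidden actions.
available_actions = torch.tensor([[0.0, 1.0, 1.0]])

# Push forbidden logits to a very negative value so their probability is ~0.
masked_logits = logits.masked_fill(available_actions == 0, -1e10)
dist = Categorical(logits=masked_logits)

deterministic = True
action = dist.probs.argmax(dim=-1) if deterministic else dist.sample()
log_prob = dist.log_prob(action)
print(action)  # tensor([1])
```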

training: bool

class safepo.common.model.MultiAgentCritic(config, cent_obs_space, device=torch.device('cuda:0'))

Bases: Module

Multi-agent critic network.

This class represents a multi-agent critic network used in reinforcement learning algorithms. It consists of a base network (CNN or MLP), recurrent layers (if applicable), and a value output layer.

Parameters:
  • config (dict) – Configuration dictionary.

  • cent_obs_space (gym.spaces.Space) – Centralized observation space.

  • device (torch.device) – Device to use for computations (default: cuda:0).

Variables:
  • hidden_size (int) – Size of the hidden layer.

  • _use_orthogonal (bool) – Flag indicating whether to use orthogonal initialization.

  • _use_naive_recurrent_policy (bool) – Flag indicating whether to use naive recurrent policy.

  • _use_recurrent_policy (bool) – Flag indicating whether to use recurrent policy.

  • _recurrent_N (int) – Number of recurrent layers.

  • tpdv (dict) – Dictionary with data type and device for tensor conversion.
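
Flags such as _use_orthogonal usually choose between orthogonal and Xavier-style weight initialization, with a gain scaling the weights. The helper below is a hypothetical sketch of that pattern; the name init_linear, the Xavier fallback and the zero bias are assumptions.

```python
import torch.nn as nn


def init_linear(layer: nn.Linear, use_orthogonal: bool = True, gain: float = 0.01) -> nn.Linear:
    """Hypothetical helper: initialize weights with the chosen scheme, zero the bias."""
    if use_orthogonal:
        nn.init.orthogonal_(layer.weight, gain=gain)
    else:
        nn.init.xavier_uniform_(layer.weight, gain=gain)
    nn.init.constant_(layer.bias, 0.0)
    return layer


# e.g. a value output layer mapping a 64-dim hidden feature to one value
v_out = init_linear(nn.Linear(64, 1), use_orthogonal=True, gain=1.0)
```

With orthogonal initialization the rows of the weight matrix are orthonormal before scaling, so for this 1×64 layer the weight vector has norm equal to the gain.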

forward(cent_obs, rnn_states, masks)

Perform a forward pass through the network to compute value estimates.

Parameters:
  • cent_obs (torch.Tensor) – Centralized observation tensor.

  • rnn_states (torch.Tensor) – Recurrent states tensor.

  • masks (torch.Tensor) – Mask tensor.

Returns:

tuple – Tuple containing value estimates tensor and new recurrent states tensor.
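
The masks argument usually resets the recurrent state at episode boundaries. A minimal sketch, assuming masks is 0 where an episode has just ended, rnn_states has shape (batch, recurrent_N, hidden_size) with recurrent_N = 1, and the recurrent layer is an nn.GRU:

```python
import torch
import torch.nn as nn

hidden_size, batch = 8, 3
gru = nn.GRU(hidden_size, hidden_size)           # single recurrent layer
features = torch.randn(batch, hidden_size)       # output of the MLP/CNN base
rnn_states = torch.randn(batch, 1, hidden_size)  # assumed (batch, recurrent_N, hidden)
# Assumption: masks is 0 where an episode just ended, 1 otherwise.
masks = torch.tensor([[1.0], [0.0], [1.0]])

# Zero the stored state of finished episodes, then reorder to (layers, batch, hidden).
h0 = (rnn_states * masks.unsqueeze(-1)).transpose(0, 1).contiguous()
out, h1 = gru(features.unsqueeze(0), h0)         # sequence length 1
new_rnn_states = h1.transpose(0, 1)              # back to (batch, 1, hidden)
```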

training: bool