diff --git a/notebooks/mdp_policy_gradient.ipynb b/notebooks/mdp_policy_gradient.ipynb
index 22dce58..b842bbb 100644
--- a/notebooks/mdp_policy_gradient.ipynb
+++ b/notebooks/mdp_policy_gradient.ipynb
@@ -8,7 +8,7 @@
    "source": [
     "import os\n",
     "\n",
-    "from behavior_generation_lecture_python.mdp.policy import CategorialPolicy\n",
+    "from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy\n",
     "from behavior_generation_lecture_python.utils.grid_plotting import (\n",
     "    make_plot_policy_step_function,\n",
     ")\n",
@@ -47,7 +47,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "policy = CategorialPolicy(\n",
+    "policy = CategoricalPolicy(\n",
     "    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],\n",
     "    actions=list(grid_mdp.actions),\n",
     ")"
@@ -146,7 +146,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "policy = CategorialPolicy(\n",
+    "policy = CategoricalPolicy(\n",
     "    sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],\n",
     "    actions=list(highway_mdp.actions),\n",
     ")"
diff --git a/src/behavior_generation_lecture_python/mdp/mdp.py b/src/behavior_generation_lecture_python/mdp/mdp.py
index a79a9d6..26fe0cc 100644
--- a/src/behavior_generation_lecture_python/mdp/mdp.py
+++ b/src/behavior_generation_lecture_python/mdp/mdp.py
@@ -8,7 +8,7 @@
 import numpy as np
 import torch
 
-from behavior_generation_lecture_python.mdp.policy import CategorialPolicy
+from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
 
 SIMPLE_MDP_DICT = {
     "states": [1, 2],
@@ -147,9 +147,9 @@ def get_transitions_with_probabilities(
     def sample_next_state(self, state, action) -> Any:
         """Randomly sample the next state given the current state and taken action."""
         if self.is_terminal(state):
-            return ValueError("No next state for terminal states.")
+            raise ValueError("No next state for terminal states.")
         if action is None:
-            return ValueError("Action must not be None.")
+            raise ValueError("Action must not be None.")
         prob_per_transition = self.get_transitions_with_probabilities(state, action)
         num_actions = len(prob_per_transition)
         choice = np.random.choice(
@@ -431,6 +431,7 @@ def q_learning(
     alpha: float,
     epsilon: float,
     iterations: int,
+    seed: Optional[int] = None,
     return_history: Optional[bool] = False,
 ) -> Union[QTable, List[QTable]]:
     """Derive a value estimate for state-action pairs by means of Q learning.
@@ -441,6 +442,7 @@
         epsilon: Exploration-exploitation threshold. A random action is taken with
             probability epsilon, the best action otherwise.
         iterations: Number of iterations.
+        seed: Random seed for reproducibility (default: None).
         return_history: Whether to return the whole history of value estimates
             instead of just the final estimate.
 
@@ -448,6 +450,9 @@
         The final value estimate, if return_history is false. The history of value
         estimates as list, if return_history is true.
     """
+    if seed is not None:
+        np.random.seed(seed)
+
     q_table = {}
     for state in mdp.get_states():
         for action in mdp.get_actions(state):
@@ -455,8 +460,6 @@
     q_table_history = [q_table.copy()]
     state = mdp.initial_state
 
-    np.random.seed(1337)
-
     for _ in range(iterations):
         # available actions:
         avail_actions = mdp.get_actions(state)
@@ -528,14 +531,15 @@ def mean_episode_length(self) -> float:
 def policy_gradient(
     *,
     mdp: MDP,
-    policy: CategorialPolicy,
+    policy: CategoricalPolicy,
     lr: float = 1e-2,
     iterations: int = 50,
     batch_size: int = 5000,
+    seed: Optional[int] = None,
     return_history: bool = False,
     use_random_init_state: bool = False,
     verbose: bool = True,
-) -> Union[List[CategorialPolicy], CategorialPolicy]:
+) -> Union[List[CategoricalPolicy], CategoricalPolicy]:
     """Train a paramterized policy using vanilla policy gradient.
 
     Adapted from: https://github.com/openai/spinningup/blob/master/spinup/examples/pytorch/pg_math/1_simple_pg.py
@@ -556,6 +560,7 @@
         lr: Learning rate.
         iterations: Number of iterations.
         batch_size: Number of samples generated for each policy update.
+        seed: Random seed for reproducibility (default: None).
         return_history: Whether to return the whole history of value estimates
             instead of just the final estimate.
         use_random_init_state: bool, if the agent should be initialized randomly.
@@ -565,8 +570,9 @@
         The final policy, if return_history is false. The history of policies as
         list, if return_history is true.
     """
-    np.random.seed(1337)
-    torch.manual_seed(1337)
+    if seed is not None:
+        np.random.seed(seed)
+        torch.manual_seed(seed)
 
     # add untrained model to model_checkpoints
     model_checkpoints = [deepcopy(policy)]
@@ -650,7 +656,7 @@
     return policy
 
 
-def derive_deterministic_policy(mdp: MDP, policy: CategorialPolicy) -> Dict[Any, Any]:
+def derive_deterministic_policy(mdp: MDP, policy: CategoricalPolicy) -> Dict[Any, Any]:
     """Compute the best policy for an MDP given the stochastic policy.
 
     Args:
diff --git a/src/behavior_generation_lecture_python/mdp/policy.py b/src/behavior_generation_lecture_python/mdp/policy.py
index 28c7618..5855f6d 100644
--- a/src/behavior_generation_lecture_python/mdp/policy.py
+++ b/src/behavior_generation_lecture_python/mdp/policy.py
@@ -1,6 +1,6 @@
 """This module contains the CategoricalPolicy implementation."""
 
-from typing import List, Type
+from typing import Any, List, Optional, Type
 
 import torch
 from torch import nn
@@ -23,36 +23,78 @@ def multi_layer_perceptron(
     return mlp
 
 
-class CategorialPolicy:
-    def __init__(self, sizes: List[int], actions: List):
+class CategoricalPolicy:
+    """A categorical policy parameterized by a neural network."""
+
+    def __init__(
+        self, sizes: List[int], actions: List[Any], seed: Optional[int] = None
+    ) -> None:
+        """Initialize the categorical policy.
+
+        Args:
+            sizes: List of layer sizes for the MLP.
+            actions: List of available actions.
+            seed: Random seed for reproducibility (default: None).
+        """
         assert sizes[-1] == len(actions)
-        torch.manual_seed(1337)
+        if seed is not None:
+            torch.manual_seed(seed)
         self.net = multi_layer_perceptron(sizes=sizes)
         self.actions = actions
         self._actions_tensor = torch.tensor(actions, dtype=torch.long).view(
             len(actions), -1
         )
 
-    def _get_distribution(self, state: torch.Tensor):
-        """Calls the model and returns a categorial distribution over the actions."""
+    def _get_distribution(self, state: torch.Tensor) -> Categorical:
+        """Calls the model and returns a categorical distribution over the actions.
+
+        Args:
+            state: The current state tensor.
+
+        Returns:
+            A categorical distribution over actions.
+        """
         logits = self.net(state)
         return Categorical(logits=logits)
 
-    def get_action(self, state: torch.Tensor, deterministic: bool = False):
-        """Returns an action sample for the given state"""
+    def get_action(self, state: torch.Tensor, deterministic: bool = False) -> Any:
+        """Returns an action sample for the given state.
+
+        Args:
+            state: The current state tensor.
+            deterministic: If True, return the most likely action.
+
+        Returns:
+            The selected action.
+        """
         policy = self._get_distribution(state)
         if deterministic:
             return self.actions[policy.mode.item()]
         return self.actions[policy.sample().item()]
 
-    def get_log_prob(self, states: torch.Tensor, actions: torch.Tensor):
-        """Returns the log-probability for taking the action, when being the given state"""
+    def get_log_prob(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
+        """Returns the log-probability for taking the action, when being in the given state.
+
+        Args:
+            states: Batch of state tensors.
+            actions: Batch of action tensors.
+
+        Returns:
+            Log-probabilities of the actions.
+        """
         return self._get_distribution(states).log_prob(
             self._get_action_id_from_action(actions)
         )
 
-    def _get_action_id_from_action(self, actions: torch.Tensor):
-        """Returns the indices of the passed actions in self.actions"""
+    def _get_action_id_from_action(self, actions: torch.Tensor) -> torch.Tensor:
+        """Returns the indices of the passed actions in self.actions.
+
+        Args:
+            actions: Batch of action tensors.
+
+        Returns:
+            Tensor of action indices.
+        """
         reshaped_actions = actions.unsqueeze(1).expand(
             -1, self._actions_tensor.size(0), -1
         )
diff --git a/tests/test_mdp.py b/tests/test_mdp.py
index dde34d7..d391a84 100644
--- a/tests/test_mdp.py
+++ b/tests/test_mdp.py
@@ -14,7 +14,7 @@
     random_action,
     value_iteration,
 )
-from behavior_generation_lecture_python.mdp.policy import CategorialPolicy
+from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy
 
 
 def test_init_mdp():
@@ -151,6 +151,7 @@ def test_q_learning(return_history):
         alpha=0.1,
         epsilon=0.1,
         iterations=10000,
+        seed=1337,
         return_history=return_history,
     )
 
@@ -158,8 +159,10 @@
 @pytest.mark.parametrize("return_history", (True, False))
 def test_policy_gradient(return_history):
     mdp = GridMDP(**GRID_MDP_DICT)
-    pol = CategorialPolicy(
-        sizes=[len(mdp.initial_state), 32, len(mdp.actions)], actions=list(mdp.actions)
+    pol = CategoricalPolicy(
+        sizes=[len(mdp.initial_state), 32, len(mdp.actions)],
+        actions=list(mdp.actions),
+        seed=1337,
    )
     assert policy_gradient(
         mdp=mdp,
@@ -167,6 +170,7 @@
         lr=1e2,
         iterations=5,
         batch_size=5000,
+        seed=1337,
         return_history=return_history,
         verbose=False,
     )