Merged
6 changes: 3 additions & 3 deletions notebooks/mdp_policy_gradient.ipynb
@@ -8,7 +8,7 @@
"source": [
"import os\n",
"\n",
"from behavior_generation_lecture_python.mdp.policy import CategorialPolicy\n",
"from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy\n",
"from behavior_generation_lecture_python.utils.grid_plotting import (\n",
" make_plot_policy_step_function,\n",
")\n",
@@ -47,7 +47,7 @@
"metadata": {},
"outputs": [],
"source": [
"policy = CategorialPolicy(\n",
"policy = CategoricalPolicy(\n",
" sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],\n",
" actions=list(grid_mdp.actions),\n",
")"
@@ -146,7 +146,7 @@
"metadata": {},
"outputs": [],
"source": [
"policy = CategorialPolicy(\n",
"policy = CategoricalPolicy(\n",
" sizes=[len(highway_mdp.initial_state), 32, len(highway_mdp.actions)],\n",
" actions=list(highway_mdp.actions),\n",
")"
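For context, the corrected notebook cell imports and constructs the renamed class as sketched below; the optional `seed` argument (introduced in policy.py further down) is not used by the notebook itself, but could be passed for reproducible weight initialization. `grid_mdp` is the GridMDP instance built earlier in the notebook.

```python
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy

policy = CategoricalPolicy(
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
    actions=list(grid_mdp.actions),
    # seed=1337,  # optional: reproducible network initialization (see policy.py changes)
)
```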
26 changes: 16 additions & 10 deletions src/behavior_generation_lecture_python/mdp/mdp.py
@@ -8,7 +8,7 @@
import numpy as np
import torch

from behavior_generation_lecture_python.mdp.policy import CategorialPolicy
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy

SIMPLE_MDP_DICT = {
"states": [1, 2],
@@ -147,9 +147,9 @@ def get_transitions_with_probabilities(
def sample_next_state(self, state, action) -> Any:
"""Randomly sample the next state given the current state and taken action."""
if self.is_terminal(state):
return ValueError("No next state for terminal states.")
raise ValueError("No next state for terminal states.")
if action is None:
return ValueError("Action must not be None.")
raise ValueError("Action must not be None.")
prob_per_transition = self.get_transitions_with_probabilities(state, action)
num_actions = len(prob_per_transition)
choice = np.random.choice(
@@ -431,6 +431,7 @@ def q_learning(
alpha: float,
epsilon: float,
iterations: int,
seed: Optional[int] = None,
return_history: Optional[bool] = False,
) -> Union[QTable, List[QTable]]:
"""Derive a value estimate for state-action pairs by means of Q learning.
@@ -441,22 +442,24 @@
epsilon: Exploration-exploitation threshold. A random action is taken with
probability epsilon, the best action otherwise.
iterations: Number of iterations.
seed: Random seed for reproducibility (default: None).
return_history: Whether to return the whole history of value estimates
instead of just the final estimate.

Returns:
The final value estimate, if return_history is false. The
history of value estimates as list, if return_history is true.
"""
if seed is not None:
np.random.seed(seed)

q_table = {}
for state in mdp.get_states():
for action in mdp.get_actions(state):
q_table[(state, action)] = 0.0
q_table_history = [q_table.copy()]
state = mdp.initial_state

np.random.seed(1337)

for _ in range(iterations):
# available actions:
avail_actions = mdp.get_actions(state)
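A minimal usage sketch of the new optional `seed` parameter of `q_learning`. The import location of `GridMDP` and `GRID_MDP_DICT` is assumed to match the tests further down.

```python
from behavior_generation_lecture_python.mdp.mdp import GRID_MDP_DICT, GridMDP, q_learning

grid_mdp = GridMDP(**GRID_MDP_DICT)  # GRID_MDP_DICT import location is an assumption

# With seed given, the run is reproducible; with seed=None (the default),
# the global NumPy RNG state is left untouched.
q_table = q_learning(
    mdp=grid_mdp,
    alpha=0.1,
    epsilon=0.1,
    iterations=10000,
    seed=1337,
)
```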
@@ -528,14 +531,15 @@ def mean_episode_length(self) -> float:
def policy_gradient(
*,
mdp: MDP,
policy: CategorialPolicy,
policy: CategoricalPolicy,
lr: float = 1e-2,
iterations: int = 50,
batch_size: int = 5000,
seed: Optional[int] = None,
return_history: bool = False,
use_random_init_state: bool = False,
verbose: bool = True,
) -> Union[List[CategorialPolicy], CategorialPolicy]:
) -> Union[List[CategoricalPolicy], CategoricalPolicy]:
"""Train a paramterized policy using vanilla policy gradient.

Adapted from: https://github.com/openai/spinningup/blob/master/spinup/examples/pytorch/pg_math/1_simple_pg.py
@@ -556,6 +560,7 @@ def policy_gradient(
lr: Learning rate.
iterations: Number of iterations.
batch_size: Number of samples generated for each policy update.
seed: Random seed for reproducibility (default: None).
return_history: Whether to return the whole history of value estimates
instead of just the final estimate.
use_random_init_state: bool, if the agent should be initialized randomly.
@@ -565,8 +570,9 @@ def policy_gradient(
The final policy, if return_history is false. The
history of policies as list, if return_history is true.
"""
np.random.seed(1337)
torch.manual_seed(1337)
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)

# add untrained model to model_checkpoints
model_checkpoints = [deepcopy(policy)]
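A hedged usage sketch of `policy_gradient` with the new `seed` argument, mirroring the call in the tests; the import location of `GridMDP` and `GRID_MDP_DICT` is an assumption.

```python
from behavior_generation_lecture_python.mdp.mdp import (
    GRID_MDP_DICT,  # assumed to live in this module, as SIMPLE_MDP_DICT does
    GridMDP,
    derive_deterministic_policy,
    policy_gradient,
)
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy

grid_mdp = GridMDP(**GRID_MDP_DICT)
policy = CategoricalPolicy(
    sizes=[len(grid_mdp.initial_state), 32, len(grid_mdp.actions)],
    actions=list(grid_mdp.actions),
    seed=1337,  # seeds torch for the network initialization
)

# The same seed also fixes the numpy/torch sampling inside the training loop;
# omit it (or pass None) to leave the global RNG state untouched.
trained_policy = policy_gradient(
    mdp=grid_mdp,
    policy=policy,
    lr=1e-2,
    iterations=50,
    batch_size=5000,
    seed=1337,
    verbose=False,
)
best_action_per_state = derive_deterministic_policy(mdp=grid_mdp, policy=trained_policy)
```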
@@ -650,7 +656,7 @@ def policy_gradient(
return policy


def derive_deterministic_policy(mdp: MDP, policy: CategorialPolicy) -> Dict[Any, Any]:
def derive_deterministic_policy(mdp: MDP, policy: CategoricalPolicy) -> Dict[Any, Any]:
"""Compute the best policy for an MDP given the stochastic policy.

Args:
66 changes: 54 additions & 12 deletions src/behavior_generation_lecture_python/mdp/policy.py
@@ -1,6 +1,6 @@
"""This module contains the CategoricalPolicy implementation."""

from typing import List, Type
from typing import Any, List, Optional, Type

import torch
from torch import nn
@@ -23,36 +23,78 @@ def multi_layer_perceptron(
return mlp


class CategorialPolicy:
def __init__(self, sizes: List[int], actions: List):
class CategoricalPolicy:
"""A categorical policy parameterized by a neural network."""

def __init__(
self, sizes: List[int], actions: List[Any], seed: Optional[int] = None
) -> None:
"""Initialize the categorical policy.

Args:
sizes: List of layer sizes for the MLP.
actions: List of available actions.
seed: Random seed for reproducibility (default: None).
"""
assert sizes[-1] == len(actions)
torch.manual_seed(1337)
if seed is not None:
torch.manual_seed(seed)
self.net = multi_layer_perceptron(sizes=sizes)
self.actions = actions
self._actions_tensor = torch.tensor(actions, dtype=torch.long).view(
len(actions), -1
)

def _get_distribution(self, state: torch.Tensor):
"""Calls the model and returns a categorial distribution over the actions."""
def _get_distribution(self, state: torch.Tensor) -> Categorical:
"""Calls the model and returns a categorical distribution over the actions.

Args:
state: The current state tensor.

Returns:
A categorical distribution over actions.
"""
logits = self.net(state)
return Categorical(logits=logits)

def get_action(self, state: torch.Tensor, deterministic: bool = False):
"""Returns an action sample for the given state"""
def get_action(self, state: torch.Tensor, deterministic: bool = False) -> Any:
"""Returns an action sample for the given state.

Args:
state: The current state tensor.
deterministic: If True, return the most likely action.

Returns:
The selected action.
"""
policy = self._get_distribution(state)
if deterministic:
return self.actions[policy.mode.item()]
return self.actions[policy.sample().item()]

def get_log_prob(self, states: torch.Tensor, actions: torch.Tensor):
"""Returns the log-probability for taking the action, when being the given state"""
def get_log_prob(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
"""Returns the log-probability for taking the action, when being in the given state.

Args:
states: Batch of state tensors.
actions: Batch of action tensors.

Returns:
Log-probabilities of the actions.
"""
return self._get_distribution(states).log_prob(
self._get_action_id_from_action(actions)
)

def _get_action_id_from_action(self, actions: torch.Tensor):
"""Returns the indices of the passed actions in self.actions"""
def _get_action_id_from_action(self, actions: torch.Tensor) -> torch.Tensor:
"""Returns the indices of the passed actions in self.actions.

Args:
actions: Batch of action tensors.

Returns:
Tensor of action indices.
"""
reshaped_actions = actions.unsqueeze(1).expand(
-1, self._actions_tensor.size(0), -1
)
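To illustrate the public interface of the renamed class, a small standalone sketch; the two-dimensional state and the integer action set here are hypothetical, not taken from the repository.

```python
import torch

from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy

# Hypothetical toy setup: 2-dimensional states, three integer actions.
policy = CategoricalPolicy(sizes=[2, 32, 3], actions=[0, 1, 2], seed=1337)

state = torch.tensor([0.0, 1.0])
sampled_action = policy.get_action(state)                     # stochastic sample
greedy_action = policy.get_action(state, deterministic=True)  # most likely action
```

`get_log_prob` takes batched state and action tensors and returns the corresponding log-probabilities, as needed for the loss computation in `policy_gradient`.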
10 changes: 7 additions & 3 deletions tests/test_mdp.py
@@ -14,7 +14,7 @@
random_action,
value_iteration,
)
from behavior_generation_lecture_python.mdp.policy import CategorialPolicy
from behavior_generation_lecture_python.mdp.policy import CategoricalPolicy


def test_init_mdp():
@@ -151,22 +151,26 @@ def test_q_learning(return_history):
alpha=0.1,
epsilon=0.1,
iterations=10000,
seed=1337,
return_history=return_history,
)


@pytest.mark.parametrize("return_history", (True, False))
def test_policy_gradient(return_history):
mdp = GridMDP(**GRID_MDP_DICT)
pol = CategorialPolicy(
sizes=[len(mdp.initial_state), 32, len(mdp.actions)], actions=list(mdp.actions)
pol = CategoricalPolicy(
sizes=[len(mdp.initial_state), 32, len(mdp.actions)],
actions=list(mdp.actions),
seed=1337,
)
assert policy_gradient(
mdp=mdp,
policy=pol,
lr=1e2,
iterations=5,
batch_size=5000,
seed=1337,
return_history=return_history,
verbose=False,
)