Source code for hisss.game.rewards

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum

import numpy as np


class BattleSnakeRewardType(Enum):
    """Enumeration of available BattleSnake reward function types.

    Every member's value is its own name as a string. This module defines
    one reward function family per member (standard, kill-based,
    cooperative), but the code visible here never maps a member to a class.
    """

    # NOTE(review): presumably selects BattleSnakeRewardFunctionStandard —
    # confirm against the caller that consumes this enum.
    STANDARD = "STANDARD"
    # NOTE(review): presumably selects BattleSnakeRewardFunctionKill — confirm.
    KILL = "KILL"
    # NOTE(review): presumably selects BattleSnakeRewardFunctionCooperation —
    # confirm.
    COOP = "COOP"


@dataclass
class BattleSnakeRewardConfig:
    """Settings common to every BattleSnake reward function.

    Concrete reward functions in this module take a subclass of this
    configuration; the two fields below are the only values they read.
    """

    #: Reward handed to a snake for each turn it survives.
    living_reward: float = 0.0

    #: Base magnitude of the reward gained on winning / lost on dying.
    terminal_reward: float = 1.0
class BattleSnakeRewardFunction(ABC):
    """Interface that every BattleSnake reward function implements.

    Instances are callables: given the sets of players alive on the previous
    and on the current turn, they return one reward per player.

    Attributes:
        cfg (BattleSnakeRewardConfig): The configuration governing this
            reward function.
    """

    def __init__(self, cfg: BattleSnakeRewardConfig):
        """Store the configuration on the instance as ``self.cfg``."""
        self.cfg = cfg

    @abstractmethod
    def __call__(
        self,
        is_terminal: bool,
        num_players: int,
        players_at_turn: list[int],
        players_at_turn_last: list[int],
    ) -> np.ndarray:
        """Compute the per-player rewards for one turn.

        Args:
            is_terminal: Presumably flags that the game is over; none of the
                implementations in this module read it.
            num_players: Length of the returned reward array.
            players_at_turn: Indices of the players alive on this turn.
            players_at_turn_last: Indices of the players that were alive on
                the previous turn.

        Returns:
            The implementations in this module return a float array of shape
            ``(num_players,)`` indexed by player.

        Raises:
            NotImplementedError: Always, when the abstract body is invoked.
        """
        raise NotImplementedError
@dataclass
class StandardBattleSnakeRewardConfig(BattleSnakeRewardConfig):
    """Configuration for the Standard BattleSnake reward function.

    Adds no fields of its own; it only identifies the standard scheme:

    - players who died this turn receive a negative terminal reward,
    - the last remaining player receives a positive terminal reward,
    - while the game is ongoing, surviving players receive a constant
      living reward (if configured).
    """
class BattleSnakeRewardFunctionStandard(BattleSnakeRewardFunction):
    """Standard reward function for BattleSnake.

    Players who died this turn are penalised with the terminal reward, a sole
    survivor is awarded the terminal reward, and while more than one player
    is alive every survivor receives the living reward.
    """

    def __init__(self, cfg: StandardBattleSnakeRewardConfig):
        """Store ``cfg`` on the instance via the base initializer."""
        super().__init__(cfg)
        # Same object the base initializer stored; the assignment is kept
        # from the original (presumably for static type narrowing).
        self.cfg = cfg

    def __call__(
        self,
        is_terminal: bool,
        num_players: int,
        players_at_turn: list[int],
        players_at_turn_last: list[int],
    ) -> np.ndarray:
        """Return the reward of every player for the current turn.

        Args:
            is_terminal: Not read by this implementation.
            num_players: Length of the returned array.
            players_at_turn: Indices of the players alive on this turn.
            players_at_turn_last: Indices of the players alive on the
                previous turn.

        Returns:
            Float array of shape ``(num_players,)``. It is all zeros when
            nobody is alive, so a simultaneous wipe-out is not penalised.
        """
        rewards = np.zeros(shape=(num_players,), dtype=float)
        alive_count = len(players_at_turn)

        # Everybody died: nobody is rewarded or punished.
        if alive_count == 0:
            return rewards

        # Players alive last turn but missing now pay the terminal penalty.
        for dead in set(players_at_turn_last).difference(players_at_turn):
            rewards[dead] = -self.cfg.terminal_reward

        if alive_count == 1:
            # A single survivor wins and collects the terminal reward.
            rewards[players_at_turn[0]] = self.cfg.terminal_reward
        else:
            # Game still running: each survivor collects the living reward.
            for alive in players_at_turn:
                rewards[alive] = self.cfg.living_reward
        return rewards
@dataclass
class KillBattleSnakeRewardConfig(BattleSnakeRewardConfig):
    """Configuration for the Kill-based BattleSnake reward function.

    Adds no fields of its own. The associated reward function behaves as a
    zero-sum or monotone game in which the cumulative un-discounted rewards
    balance out:

    - 2 players alive: zero-sum game around cum_reward of 2/3
    - 3 players alive: monotone game around cum_reward of 1/3
      (kill reward is 1/3)
    - 4 players alive: monotone game around starting point of zero
    """
class BattleSnakeRewardFunctionKill(BattleSnakeRewardFunction):
    """Reward function handing out fractional rewards whenever snakes die.

    It behaves as a zero-sum or monotone game in which the cumulative
    un-discounted rewards balance out:

    - 2 players alive: zero-sum game around cum_reward of 2/3
    - 3 players alive: monotone game around cum_reward of 1/3
      (kill reward is 1/3)
    - 4 players alive: monotone game around starting point of zero
    """

    # Both tables are keyed by (players alive last turn, players alive now).
    # Every value lies in [-1, 1]; no clipping is applied anywhere.
    # Reward assigned to each player that died this turn.
    _DEATH_REWARD = {
        (4, 3): -1,
        (4, 2): -2 / 3,
        (4, 1): -1 / 3,
        (3, 2): -2 / 3,
        (3, 1): -1 / 3,
        (2, 1): -1 / 3,
    }
    # Reward assigned to each player that is still alive.
    _SURVIVAL_REWARD = {
        (4, 3): 1 / 3,
        (4, 2): 2 / 3,
        (4, 1): 1,
        (3, 2): 1 / 3,
        (3, 1): 2 / 3,
        (2, 1): 1 / 3,
    }

    def __init__(self, cfg: KillBattleSnakeRewardConfig):
        """Store ``cfg`` and validate that it fits the fixed reward tables.

        Raises:
            ValueError: If ``cfg.living_reward`` is not 0 or
                ``cfg.terminal_reward`` is not 1.
        """
        super().__init__(cfg)
        self.cfg = cfg
        if cfg.living_reward != 0:
            raise ValueError(
                "Kill reward function does not work with living reward due to fixed maximum reward"
            )
        if cfg.terminal_reward != 1:
            raise ValueError("Need terminal reward on one for kill reward function")

    def __call__(
        self,
        is_terminal: bool,
        num_players: int,
        players_at_turn: list[int],
        players_at_turn_last: list[int],
    ) -> np.ndarray:
        """Return the reward of every player for the current turn.

        Args:
            is_terminal: Not read by this implementation.
            num_players: Length of the returned array.
            players_at_turn: Indices of the players alive on this turn.
            players_at_turn_last: Indices of the players alive on the
                previous turn.

        Returns:
            Float array of shape ``(num_players,)``. It is all zeros when
            nobody is alive, when nobody died, or when the
            (alive before, alive now) pair is not in the reward tables
            (e.g. more than four players alive before).
        """
        rewards = np.zeros(shape=(num_players,), dtype=float)
        survivors_now = len(players_at_turn)
        eliminated = set(players_at_turn_last).difference(players_at_turn)

        # No reward if everyone died, and none if nobody died.
        if survivors_now == 0 or len(eliminated) == 0:
            return rewards

        transition = (len(players_at_turn_last), survivors_now)

        # Penalty for the players that died this round.
        death_reward = self._DEATH_REWARD.get(transition)
        if death_reward is not None:
            for snake in eliminated:
                rewards[snake] = death_reward

        # Bonus for the players that are still alive.
        survival_reward = self._SURVIVAL_REWARD.get(transition)
        if survival_reward is not None:
            for snake in players_at_turn:
                rewards[snake] = survival_reward

        return rewards


@dataclass
class CooperationBattleSnakeRewardConfig(BattleSnakeRewardConfig):
    """Configuration for the Cooperation BattleSnake reward function.

    Overrides the base defaults with a small positive living reward and a
    negative terminal reward that is applied when a snake dies.
    """

    living_reward: float = 0.02
    terminal_reward: float = -0.25


class BattleSnakeRewardFunctionCooperation(BattleSnakeRewardFunction):
    """Reward function in which every death affects all players.

    A player that died this turn receives the terminal reward multiplied by
    the number of players alive on the previous turn. A surviving player
    receives the terminal reward multiplied by the number of deaths, plus
    the living reward.
    """

    def __init__(self, cfg: CooperationBattleSnakeRewardConfig):
        """Store ``cfg`` on the instance via the base initializer."""
        super().__init__(cfg)
        self.cfg = cfg

    def __call__(
        self,
        is_terminal: bool,
        num_players: int,
        players_at_turn: list[int],
        players_at_turn_last: list[int],
    ) -> np.ndarray:
        """Return the reward of every player for the current turn.

        Args:
            is_terminal: Not read by this implementation.
            num_players: Length of the returned array.
            players_at_turn: Indices of the players alive on this turn.
            players_at_turn_last: Indices of the players alive on the
                previous turn.

        Returns:
            Float array of shape ``(num_players,)``.
        """
        rewards = np.zeros(shape=(num_players,), dtype=float)
        alive_before = len(players_at_turn_last)
        deaths = alive_before - len(players_at_turn)

        # Players that vanished since the last turn are charged the terminal
        # reward scaled by how many players were alive before.
        for snake in players_at_turn_last:
            if snake not in players_at_turn:
                rewards[snake] += self.cfg.terminal_reward * alive_before

        # Survivors are charged once per death and earn the living reward.
        for snake in players_at_turn:
            rewards[snake] += self.cfg.terminal_reward * deaths
            rewards[snake] += self.cfg.living_reward

        return rewards


def get_battlesnake_reward_func_from_cfg(
    cfg: BattleSnakeRewardConfig,
) -> BattleSnakeRewardFunction:
    """Build the reward function that matches the type of ``cfg``.

    Args:
        cfg: A standard, kill or cooperation reward configuration.

    Returns:
        The reward function instance constructed from ``cfg``.

    Raises:
        ValueError: If ``cfg`` is not an instance of one of the three
            supported configuration classes.
    """
    # Checked in order; the first matching configuration type wins.
    dispatch = (
        (StandardBattleSnakeRewardConfig, BattleSnakeRewardFunctionStandard),
        (KillBattleSnakeRewardConfig, BattleSnakeRewardFunctionKill),
        (CooperationBattleSnakeRewardConfig, BattleSnakeRewardFunctionCooperation),
    )
    for cfg_type, func_type in dispatch:
        if isinstance(cfg, cfg_type):
            return func_type(cfg)
    raise ValueError(f"Unknown reward function type: {cfg}")