import ctypes as ct
import itertools
import math
import random
from typing import Optional, Any
import numpy as np
from hisss.cpp.lib import CPP_LIB
from hisss.game.state import (
CAUSE_INT_TO_STR,
CAUSE_STR_TO_INT,
BattleSnakeState,
EliminationEvent,
)
from hisss.game.config import (
BattleSnakeConfig,
encoding_layer_indices,
post_init_battlesnake_cfg,
validate_battlesnake_cfg,
)
from hisss.game.encoding import num_layers_general, layers_per_player, layers_per_enemy
from hisss.game.rewards import get_battlesnake_reward_func_from_cfg
from hisss.game.utils import int_to_perm
# Action indices for the four movement directions.
#: Action index for moving UP in the grid world
UP: int = 0
#: Action index for moving RIGHT in the grid world
RIGHT: int = 1
#: Action index for moving DOWN in the grid world
DOWN: int = 2
#: Action index for moving LEFT in the grid world
LEFT: int = 3
[docs]
class BattleSnakeGame:
"""Battlesnake game environment backed by a C++ simulation engine.
Wraps a heap-allocated C++ ``GameState`` object. Always call :meth:`close`
when the environment is no longer needed — or use it as a context manager —
to free the underlying C++ memory. :meth:`__del__` provides a fallback but
is not guaranteed to run promptly.
Attributes:
cfg: Game configuration used to create this environment.
turns_played: Number of turns elapsed since the last :meth:`reset`.
is_closed: Whether :meth:`close` has already been called.
layer_explanation: Mapping from channel name to index in the observation
tensor returned by :meth:`get_obs`.
"""
[docs]
def __init__(
    self,
    cfg: BattleSnakeConfig,
    state_p=None,  # Optional[ct.POINTER(Struct)],
):
    """Initialise a new BattleSnakeGame.

    Args:
        cfg: Configuration dataclass that controls board size, number of
            players, food rules, game mode, encoding, and rewards.
        state_p: Optional pre-existing C++ game-state pointer. When
            ``None`` (the default) a fresh game is created from *cfg*.
            Pass an existing pointer only when cloning an environment
            internally.

    Note:
        :func:`post_init_battlesnake_cfg` and
        :func:`validate_battlesnake_cfg` are called automatically when
        *state_p* is ``None``. The C++ state is allocated only after all
        Python-side attributes have been set up, so an exception during
        that setup does not leave an unreferenced C++ state behind.
    """
    self.cfg = cfg
    self.is_closed = False
    create_cpp_state = state_p is None
    if create_cpp_state:
        # first time the game is started: complete and check the config so
        # that every read below sees the post-initialised values
        post_init_battlesnake_cfg(self.cfg)
        validate_battlesnake_cfg(self.cfg)
    # episode bookkeeping
    self._cum_rewards = np.zeros(shape=(self.cfg.num_players,), dtype=float)
    self._last_actions: Optional[tuple[int, ...]] = None
    # same starting turn that _init_cpp passes to the C++ engine
    self.turns_played = self.cfg.init_turns_played
    # attributes for saving the current game state (per-turn caches)
    self.obs_dict: dict[int, np.ndarray] = dict()
    self.available_actions_save: dict[int, list[int]] = dict()
    self.players_at_turn_save: Optional[list[int]] = None
    self.players_at_turn_last: Optional[list[int]] = None  # property of last step
    self.players_alive_save: Optional[list[int]] = None
    self.players_alive_last: Optional[list[int]] = None  # property of last step
    self.reward_func = get_battlesnake_reward_func_from_cfg(self.cfg.reward_cfg)
    self.layer_explanation = encoding_layer_indices(self.cfg)
    # state pointer: either the supplied clone or a freshly allocated state
    self.state_p = state_p
    if create_cpp_state:
        self._init_cpp()
@property
def num_actions(self):
"""Number of actions available to each player (always 4: UP, RIGHT, DOWN, LEFT)."""
return self.cfg.num_actions
[docs]
def available_joint_actions(self) -> list[tuple[int, ...]]:
"""Return every legal combination of actions for all players currently at turn.
Legal actions may be restricted by actions that would lead to a certain death,
depending on the game configuration.
Returns:
List of tuples, one entry per player at turn. Each tuple element
is an action index for the corresponding player returned by
:meth:`players_at_turn`.
"""
action_lists = []
for player in range(self.num_players):
current_list = self.available_actions(player)
if current_list:
action_lists.append(current_list)
result = list(itertools.product(*action_lists))
return result
[docs]
def illegal_actions(self, player: int) -> list[int]:
"""Return the action indices that are illegal for *player*.
Args:
player: Zero-based player index.
Returns:
List of action indices not present in :meth:`available_actions`.
Raises:
ValueError: If *player* is out of range.
"""
if player < 0 or player >= self.num_players:
raise ValueError(f"Snake index out of range: {player}")
all_action_set = {i for i in range(self.num_actions)}
legal_action_set: set[int] = set(self.available_actions(player))
illegal_action_set = all_action_set - legal_action_set
illegal_actions = list(illegal_action_set)
return illegal_actions
[docs]
def illegal_joint_actions(self) -> list[tuple[int, ...]]:
"""Return every joint-action combination that is not fully legal.
Returns:
List of tuples that are absent from :meth:`available_joint_actions`.
"""
action_lists = [
[i for i in range(self.num_actions)] for _ in range(self.num_players)
]
all_action_set: set[tuple[int, ...]] = set(itertools.product(*action_lists))
available_joint_action_set: set[tuple[int, ...]] = set(
self.available_joint_actions()
)
illegal_action_set = all_action_set - available_joint_action_set
return list(illegal_action_set)
def _init_cpp(self):
    """Create a fresh C++ game state from ``self.cfg`` and store it in ``self.state_p``.

    Every array handed to C++ is built with ``np.ascontiguousarray`` so the
    raw pointer obtained through ``ctypes.data_as`` addresses a dense,
    C-ordered buffer even if the config holds non-contiguous numpy arrays.
    All arrays are bound to local variables, so they stay alive until
    ``init_cpp`` returns.
    """
    null_int_p = ct.cast(0, ct.POINTER(ct.c_int))  # NULL-Pointer
    # snakes
    spawn_snakes_randomly = self.cfg.init_snake_pos is None
    if spawn_snakes_randomly:
        body_len_p = null_int_p
        snake_pos_p = null_int_p
        max_body_len = -1
    else:
        # de-duplicated copies, so the config object itself is not altered
        snake_pos = []
        for s in range(self.cfg.num_players):
            cur_snake_pos = [(pos[0], pos[1]) for pos in self.cfg.init_snake_pos[s]]
            snake_pos.append(list(dict.fromkeys(cur_snake_pos)))  # remove duplicates
        body_lengths = [len(body) for body in snake_pos]
        # the longest body determines the array shape; unused slots stay -1
        max_body_len = max(body_lengths)
        body_len_arr = np.ascontiguousarray(body_lengths, dtype=ct.c_int)
        body_len_p = body_len_arr.ctypes.data_as(ct.POINTER(ct.c_int))
        snake_pos_arr = np.full(
            (self.cfg.num_players, max_body_len, 2), -1, dtype=ct.c_int
        )
        for s, body in enumerate(snake_pos):
            for i, (x, y) in enumerate(body):
                snake_pos_arr[s, i, 0] = x
                snake_pos_arr[s, i, 1] = y
        snake_pos_p = snake_pos_arr.ctypes.data_as(ct.POINTER(ct.c_int))
    # snake length
    snake_len_arr = np.ascontiguousarray(self.cfg.init_snake_len, dtype=ct.c_int)
    snake_len_p = snake_len_arr.ctypes.data_as(ct.POINTER(ct.c_int))
    # snakes alive
    snake_alive_arr = np.ascontiguousarray(self.cfg.init_snakes_alive, dtype=bool)
    snake_alive_p = snake_alive_arr.ctypes.data_as(ct.POINTER(ct.c_bool))
    # health
    snake_health_arr = np.ascontiguousarray(self.cfg.init_snake_health, dtype=ct.c_int)
    snake_health_p = snake_health_arr.ctypes.data_as(ct.POINTER(ct.c_int))
    snake_max_health_arr = np.ascontiguousarray(
        self.cfg.max_snake_health, dtype=ct.c_int
    )
    snake_max_health_p = snake_max_health_arr.ctypes.data_as(ct.POINTER(ct.c_int))
    # food
    if self.cfg.init_food_pos is None:
        num_init_food = -1  # -1 is signal for random food spawning
        food_pos_p = null_int_p
    elif len(self.cfg.init_food_pos) == 0:
        # len() instead of truthiness: `not arr` raises for multi-element numpy arrays
        num_init_food = -2  # flag to indicate that no food should be spawned at beginning
        food_pos_p = null_int_p
    else:
        food_arr = np.ascontiguousarray(self.cfg.init_food_pos, dtype=ct.c_int)
        num_init_food = len(food_arr)
        food_pos_p = food_arr.ctypes.data_as(ct.POINTER(ct.c_int))
    # food spawn turns: NULL lets C++ fill in init_turns_played as fallback
    food_spawn_turns_p = null_int_p
    # hazards: (h, w) layout indexed [y, x]; the C++ side receives the flattened buffer
    hazard_arr = np.zeros(shape=(self.cfg.h, self.cfg.w), dtype=bool)
    if self.cfg.init_hazards is not None:
        for hazard_tile in self.cfg.init_hazards:
            hazard_arr[hazard_tile[1], hazard_tile[0]] = True
    hazards_p = hazard_arr.ctypes.data_as(ct.POINTER(ct.c_bool))
    # c++ call
    self.state_p = CPP_LIB.lib.init_cpp(
        self.cfg.w,
        self.cfg.h,
        self.cfg.num_players,
        self.cfg.min_food,
        self.cfg.food_spawn_chance,
        self.cfg.init_turns_played,
        spawn_snakes_randomly,
        body_len_p,
        max_body_len,
        snake_pos_p,
        num_init_food,
        food_pos_p,
        food_spawn_turns_p,
        snake_alive_p,
        snake_health_p,
        snake_len_p,
        snake_max_health_p,
        self.cfg.wrapped,
        self.cfg.royale,
        self.cfg.shrink_n_turns,
        self.cfg.hazard_damage,
        hazards_p,
    )
[docs]
def reset_saved_properties(self):
"""Clear all intra-turn caches (observations, available actions, player lists).
Called automatically after each :meth:`step` and :meth:`reset`. Only
call this manually if you mutate the game state externally.
"""
self.obs_dict: dict[int, np.ndarray] = dict()
self.available_actions_save: dict[int, list[int]] = dict()
self.players_at_turn_save = None
self.players_at_turn_last = None
self.players_alive_save = None
self.players_alive_last = None
[docs]
def get_obs_shape(self, never_flatten=False) -> tuple[int, ...]:
"""Return the shape of a single-player observation tensor.
Args:
never_flatten: When ``True``, return the spatial shape
``(width, height, channels)`` even if the encoding config has
``flatten=True``.
Returns:
Tuple of ints describing the observation shape. Returns a
1-element tuple ``(n,)`` when the config requests flattening (and
*never_flatten* is ``False``), otherwise ``(width, height, channels)``.
Raises:
ValueError: If the game has been closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
# number layers
num_enemies = 1 if self.cfg.ec.compress_enemies else (self.num_players - 1)
offset = num_layers_general(self.cfg.ec) + layers_per_player(self.cfg.ec)
z_dim = offset + num_enemies * layers_per_enemy(self.cfg.ec)
# width and height
width = self.cfg.w
height = self.cfg.h
if self.cfg.ec.centered:
width = 2 * self.cfg.w - 1
height = 2 * self.cfg.h - 1
elif not self.cfg.wrapped: # wrapped does not have a border
# +1 on every side for border of field
width = self.cfg.w + 2
height = self.cfg.h + 2
# number of dim
if self.cfg.ec.flatten and not never_flatten:
dim = width * height * z_dim
return tuple(
[
dim,
]
)
return width, height, z_dim
[docs]
def step(self, actions: tuple[int, ...]) -> tuple[np.ndarray, bool, dict]:
"""Advance the game by one turn.
Args:
actions: Joint action tuple — one action per player currently at
turn, in the order returned by :meth:`players_at_turn`.
Actions must come from :meth:`available_joint_actions`.
Returns:
A 3-tuple ``(rewards, done, info)`` where *rewards* is a
``float64`` array of shape ``(num_players,)``, *done* is ``True``
when the game has reached a terminal state, and *info* is an empty
dict (reserved for future use).
Raises:
Exception: If the game is already in a terminal state.
ValueError: If the length of *actions* does not match the number
of players at turn, or if *actions* is not a legal joint action.
"""
if self.is_terminal():
raise Exception("Cannot call step on terminal state")
if len(actions) != self.num_players_at_turn():
raise ValueError(f"Invalid action length: {actions}")
if actions not in self.available_joint_actions():
raise ValueError(f"Calling step with non-legal actions: {actions}")
reward, done, info = self._step(actions)
self._cum_rewards += reward
self._last_actions = actions
self.turns_played += 1
return reward, done, info
def _step(
self,
actions: tuple[int, ...],
) -> tuple[np.ndarray, bool, dict]:
# test if actions are actually legal to perform
if self.is_closed:
raise ValueError("Cannot call function on closed game")
# fill actions of players not at turn with zeros
action_arr = np.zeros(shape=(self.num_players,), dtype=ct.c_int)
for idx, player in enumerate(self.players_at_turn()):
action_arr[player] = actions[idx]
# shift player alive and at turn
self.players_at_turn_last = self.players_at_turn()
self.players_alive_last = self.players_alive()
# perform step
action_p = action_arr.ctypes.data_as(ct.POINTER(ct.c_int))
CPP_LIB.lib.step_cpp(self.state_p, action_p)
# reset saved properties
self.obs_dict = dict()
self.available_actions_save = dict()
self.players_at_turn_save = None
self.players_alive_save = None
# compute return values
done = self.is_terminal()
rewards = self.reward_func(
done, self.num_players, self.players_at_turn(), self.players_at_turn_last
)
return rewards, done, {}
[docs]
def get_copy(self) -> "BattleSnakeGame":
"""Return an independent deep copy of this environment.
The copy shares the same :attr:`cfg` reference but has its own C++
game-state object, cumulative rewards, and caches. Both the original
and the copy must be closed independently.
Returns:
A new :class:`BattleSnakeGame` with identical state.
"""
cpy = self._get_copy()
cpy._last_actions = self._last_actions
cpy._cum_rewards = self._cum_rewards.copy()
cpy.turns_played = self.turns_played
return cpy
[docs]
def is_player_at_turn(self, player: int) -> bool:
"""Return whether *player* must provide an action this turn.
Args:
player: Zero-based player index.
Returns:
``True`` if the player is alive and has at least one legal action.
"""
return player in self.players_at_turn()
@property
def num_players(self) -> int:
"""Total number of players (snakes) in this game, including dead ones."""
return self.cfg.num_players
[docs]
def num_players_at_turn(self) -> int:
"""Return the number of players who must act this turn."""
return len(self.players_at_turn())
[docs]
def players_not_alive(self) -> list[int]:
"""Return the indices of all dead (eliminated) players.
Returns:
List of player indices that are no longer alive.
"""
result = set(range(self.num_players)) - set(self.players_alive())
return list(result)
[docs]
def num_players_alive(self) -> int:
"""Return the number of players that are still alive."""
return len(self.players_alive())
[docs]
def is_player_alive(self, player: int) -> bool:
"""Return whether *player* is still alive.
Args:
player: Zero-based player index.
"""
return player in self.players_alive()
[docs]
def get_last_actions(self) -> Optional[tuple[int, ...]]:
"""Return the joint action that was passed to the most recent :meth:`step`.
Returns:
Tuple of action indices, or ``None`` if no step has been taken yet.
"""
return self._last_actions
[docs]
def set_last_actions(self, last_actions: Optional[tuple[int, ...]]):
"""Override the recorded last-step joint action.
Args:
last_actions: Tuple of action indices, or ``None`` to clear.
"""
self._last_actions = last_actions
[docs]
def get_cum_rewards(self) -> np.ndarray:
"""Return the cumulative rewards accumulated since the last :meth:`reset`.
Returns:
Float array of shape ``(num_players,)``.
"""
return self._cum_rewards
[docs]
def set_cum_rewards(self, cum_rewards: np.ndarray):
"""Override the cumulative rewards array.
Args:
cum_rewards: Float array of shape ``(num_players,)``.
"""
self._cum_rewards = cum_rewards
[docs]
def play_random_steps(self, steps: int):
"""Advance the game by up to *steps* turns using uniformly random actions.
Stops early if the game reaches a terminal state before all steps are
consumed.
Args:
steps: Maximum number of turns to play.
"""
rndm = random.Random()
while (not self.is_terminal()) and steps > 0:
steps -= 1
self.step(rndm.choice(self.available_joint_actions()))
[docs]
def get_last_action(self) -> Optional[tuple[int, ...]]:
"""Alias for :meth:`get_last_actions`."""
return self._last_actions
def _get_copy(self) -> "BattleSnakeGame":
# clone the c++ env and initialize it on python side
if self.is_closed:
raise ValueError("Cannot call function on closed game")
state_p2 = CPP_LIB.lib.clone_cpp(self.state_p)
cpy = BattleSnakeGame(
cfg=self.cfg,
state_p=state_p2,
)
# copy properties
cpy.available_actions_save = self.available_actions_save.copy()
cpy.obs_dict = self.obs_dict.copy()
if self.players_at_turn_save is not None:
cpy.players_at_turn_save = self.players_at_turn_save.copy()
if self.players_at_turn_last is not None:
cpy.players_at_turn_last = self.players_at_turn_last.copy()
if self.players_alive_save is not None:
cpy.players_alive_save = self.players_alive_save.copy()
if self.players_alive_last is not None:
cpy.players_alive_last = self.players_alive_last.copy()
return cpy
[docs]
def close(self):
"""Free the underlying C++ game-state object.
Must be called when the environment is no longer needed to avoid
memory leaks. Calling any other method after ``close()`` will raise a
``ValueError``.
"""
CPP_LIB.lib.close_cpp(self.state_p)
self.is_closed = True
[docs]
def reset(self):
"""Reinitialise the game to its starting state.
Resets cumulative rewards, turn counter, and last actions, then
re-creates the C++ game state from the original config.
"""
self._cum_rewards = np.zeros(shape=(self.num_players,), dtype=float)
self._last_actions = None
self.turns_played = 0
self._reset()
def _reset(self):
if self.is_closed:
raise ValueError("Cannot call function on closed game")
CPP_LIB.lib.close_cpp(self.state_p)
self._init_cpp()
self.reset_saved_properties()
[docs]
def available_actions(self, player: int) -> list[int]:
"""Return the legal action indices for *player* in the current state.
Results are cached for the lifetime of the current turn.
Args:
player: Zero-based player index.
Returns:
List of legal action indices. Returns an empty list if the player
is dead.
Raises:
ValueError: If the game is closed or *player* is out of range.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
if player < 0 or player >= self.cfg.num_players:
raise ValueError(f"Snake index out of range: {player}")
if not self.is_player_alive(player):
return [] # what is dead cannot move
if self.cfg.all_actions_legal:
return [0, 1, 2, 3]
if (
player in self.available_actions_save
): # use saved actions from last function call
return self.available_actions_save[player]
# ask c++ lib what actions are legal
legal_actions = np.zeros(shape=(self.cfg.num_actions,), dtype=ct.c_int)
CPP_LIB.lib.actions_cpp(self.state_p, player, legal_actions)
action_list = []
for a in range(self.cfg.num_actions):
if legal_actions[a]:
action_list.append(a)
self.available_actions_save[player] = action_list
return action_list
[docs]
def players_at_turn(self) -> list[int]:
"""Return the indices of all players that must act this turn.
A player is at turn if they are alive and have at least one legal
action. Results are cached until the next :meth:`step`.
Returns:
Sorted list of player indices.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
if self.players_at_turn_save is None:
# only snakes with available actions are at turn
self.players_at_turn_save = []
for player in range(self.num_players):
if self.is_player_alive(player) and self.available_actions(player):
self.players_at_turn_save.append(player)
return self.players_at_turn_save
[docs]
def players_alive(self) -> list[int]:
"""Return the indices of all living players.
Results are cached until the next :meth:`step`.
Returns:
List of player indices that are still alive.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
# call c++
if self.players_alive_save is None:
res_arr = np.zeros(shape=(self.num_players,), dtype=bool)
res_p = res_arr.ctypes.data_as(ct.POINTER(ct.c_bool))
CPP_LIB.lib.alive_cpp(self.state_p, res_p)
self.players_alive_save = []
for player in range(self.num_players):
if res_arr[player]:
self.players_alive_save.append(player)
return self.players_alive_save
[docs]
def player_lengths(self) -> list[int]:
"""Return the length (number of segments) of each snake.
Returns:
List of length ``num_players``. Dead snakes retain their last
known length value from C++.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
res_arr = np.zeros(shape=(self.num_players,), dtype=ct.c_int)
res_p = res_arr.ctypes.data_as(ct.POINTER(ct.c_int))
CPP_LIB.lib.snake_length_cpp(self.state_p, res_p)
return list(res_arr)
[docs]
def player_healths(self) -> list[int]:
"""Return the current health value of each snake.
Returns:
List of health values of length ``num_players``.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
res_arr = np.zeros(shape=(self.num_players,), dtype=ct.c_int)
res_p = res_arr.ctypes.data_as(ct.POINTER(ct.c_int))
CPP_LIB.lib.snake_health_cpp(self.state_p, res_p)
return list(res_arr)
[docs]
def player_pos(
self, player: int
) -> list[tuple[int, int]]: # returns list of length BODY_LEN != SNAKE_LEN
"""Return the body positions of *player*'s snake as ``(x, y)`` tuples.
The first element is the head. Note that the body array length
(``BODY_LEN``) may differ from the logical snake length
(``SNAKE_LEN``) during the turn the snake just ate food.
This method calls into C++ on every invocation and should only be
used for debugging or infrequent queries.
Args:
player: Zero-based player index.
Returns:
List of ``(x, y)`` coordinate pairs, head first.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
# this is inefficient and should only be used for debugging
body_len = CPP_LIB.lib.snake_body_length_cpp(self.state_p, player)
res_arr = np.zeros(shape=(body_len, 2), dtype=ct.c_int)
CPP_LIB.lib.snake_pos_cpp(
self.state_p, player, res_arr.ctypes.data_as(ct.POINTER(ct.c_int))
)
res_list = []
for idx in range(res_arr.shape[0]):
res_list.append((res_arr[idx, 0], res_arr[idx, 1]))
return res_list
[docs]
def all_player_pos(self) -> dict[int, list[tuple[int, int]]]:
"""Return body positions for every snake.
This method calls into C++ for each alive snake and should only be
used for debugging or infrequent queries.
Returns:
Dict mapping player index to a list of ``(x, y)`` tuples (head
first). Dead players map to an empty list.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
# this is inefficient and should only be used for debugging
res_dict = {}
for player in range(self.num_players):
if self.is_player_alive(player):
res_dict[player] = self.player_pos(player)
else:
res_dict[player] = []
return res_dict
[docs]
def num_food(self) -> int:
"""Return the number of food items currently on the board.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
return CPP_LIB.lib.num_food_cpp(self.state_p)
[docs]
def get_hazards(self) -> np.ndarray:
"""Return a boolean array indicating hazard tiles.
Returns:
Bool array of shape ``(h, w)`` — ``True`` where a hazard tile
exists (e.g. the shrinking ring in royale mode).
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
arr = np.zeros(shape=(self.cfg.w, self.cfg.h), dtype=bool)
arr_p = arr.ctypes.data_as(ct.POINTER(ct.c_bool))
CPP_LIB.lib.hazards_cpp(self.state_p, arr_p)
return arr.T
[docs]
def food_pos(self) -> np.ndarray:
"""Return the coordinates of all food items on the board.
Returns:
Integer array of shape ``(num_food, 2)`` where each row is
``[x, y]``.
Raises:
ValueError: If the game is closed.
"""
# returns array of shape (num_food, 2)
if self.is_closed:
raise ValueError("Cannot call function on closed game")
n = self.num_food()
res_arr = np.zeros(shape=(n, 2), dtype=ct.c_int)
CPP_LIB.lib.food_pos_cpp(
self.state_p, res_arr.ctypes.data_as(ct.POINTER(ct.c_int))
)
return res_arr
[docs]
def food_spawn_turns(self) -> np.ndarray:
"""Return the turn each food item spawned, in the same order as food_pos().
Returns:
Integer array of shape ``(num_food,)`` where each element is the
turn the corresponding food item first appeared on the board.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
n = self.num_food()
res_arr = np.zeros(shape=(n,), dtype=ct.c_int)
CPP_LIB.lib.food_spawn_turns_cpp(
self.state_p, res_arr.ctypes.data_as(ct.POINTER(ct.c_int))
)
return res_arr
[docs]
def is_terminal(self) -> bool:
"""Return whether the game has ended.
The game is terminal when fewer than two players are at turn (in a
multi-player game) or when no player is at turn (in a single-player
game).
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
# a game has ended if no / only the last player alive is at turn
if len(self.available_joint_actions()) == 0:
return True
if self.num_players == 1:
return self.num_players_at_turn() == 0
else:
return self.num_players_at_turn() <= 1
[docs]
def area_control(
self,
weight: float = 1.0, # weight of normal tile
food_weight: float = 1.0,
hazard_weight: float = 1.0,
food_in_hazard_weight: float = 1.0,
) -> dict[str, np.ndarray]:
"""
Args:
weight (): weight of normal tile
food_weight (): weight of food tile
hazard_weight (): weight of hazard tile
food_in_hazard_weight (): weight of hazard tile that contains food
Returns:
Dictionary of:
- area control for each player
- food distance for each player (w+h if not reachable)
- tail distance for each player (w+h if not reachable)
- bool array indicating if tail is reachable for each player
- bool array indicating if food is reachable for each player
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
ac, fd, td, tr, fr = CPP_LIB.get_area_control(
self.num_players,
self.state_p,
weight,
food_weight,
hazard_weight,
food_in_hazard_weight,
)
res_dict = {
"area_control": ac,
"food_distance": fd,
"tail_distance": td,
"tail_reachable": tr,
"food_reachable": fr,
}
return res_dict
[docs]
def render(self):
"""Print an ASCII representation of the current board to stdout.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
str_repr = self.get_str_repr()
print(str_repr)
[docs]
def get_str_repr(self) -> str:
"""Return an ASCII string representation of the current board.
Returns:
Multi-line string visualising the board (snakes, food, hazards).
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
arr = ct.create_string_buffer(self.cfg.w * self.cfg.h * 3)
CPP_LIB.lib.str_cpp(self.state_p, arr)
str_repr = arr.value.decode("utf-8")
return str_repr
def _get_cpp_encoding(
self,
player: int,
temperatures: Optional[list[float]],
single_temperature: Optional[bool],
):
obs_arr = np.zeros(
shape=self.get_obs_shape(never_flatten=True), dtype=np.float32
)
obs_p = obs_arr.ctypes.data_as(ct.POINTER(ct.c_float))
t_arr = np.asarray(temperatures, dtype=ct.c_float)
t_p = t_arr.ctypes.data_as(ct.POINTER(ct.c_float))
CPP_LIB.lib.custom_encode_cpp(
self.state_p,
obs_p,
self.cfg.ec.include_current_food,
self.cfg.ec.include_next_food,
self.cfg.ec.include_board,
self.cfg.ec.include_number_of_turns,
self.cfg.ec.compress_enemies,
player,
self.cfg.ec.include_snake_body_as_one_hot,
self.cfg.ec.include_snake_body,
self.cfg.ec.include_snake_head,
self.cfg.ec.include_snake_tail,
self.cfg.ec.include_snake_health,
self.cfg.ec.include_snake_length,
self.cfg.ec.centered,
self.cfg.ec.include_distance_map,
self.cfg.ec.include_area_control,
self.cfg.ec.include_food_distance,
self.cfg.ec.include_hazards,
self.cfg.ec.include_tail_distance,
self.cfg.ec.include_num_food_on_board,
self.cfg.ec.fixed_food_spawn_chance,
self.cfg.ec.temperature_input,
self.cfg.ec.single_temperature_input
if single_temperature is None
else single_temperature,
t_p,
)
return obs_arr
def _get_custom_state_encoding(
self,
player: int,
perm: Optional[np.ndarray],
temperatures: Optional[list[float]],
single_temperature: Optional[bool],
) -> np.ndarray:
if self.is_closed:
raise ValueError("Cannot call function on closed game")
player_list = self.players_at_turn()
# if the player is dead we give him an arbitrary encoding (just to match shapes, it is discarded later)
if player not in player_list:
raise Exception("Cannot get an encoding for a dead snake (duh)")
# check if we already computed an encoding for this player
if player not in self.obs_dict or temperatures is not None:
obs_arr = self._get_cpp_encoding(
player=player,
temperatures=temperatures,
single_temperature=single_temperature,
)
self.obs_dict[player] = obs_arr.copy()
else:
obs_arr = np.copy(self.obs_dict[player])
# rotate encodings of enemy players according to permutation
if (not self.cfg.ec.compress_enemies) and self.cfg.num_players > 2:
if perm is None:
perm = np.random.permutation(self.cfg.num_players - 1)
left, right = [], []
offset = num_layers_general(self.cfg.ec) + layers_per_player(self.cfg.ec)
num_enemy_layer = layers_per_enemy(self.cfg.ec)
for sub_layer in range(num_enemy_layer):
for enemy in range(0, self.cfg.num_players - 1):
left.append(offset + enemy * num_enemy_layer + sub_layer)
for idx in perm:
right.append(offset + idx * num_enemy_layer + sub_layer)
obs_arr[:, :, left] = obs_arr[:, :, right]
return obs_arr
def __eq__(self, other: Any):
"""Return whether this game state is identical to *other*.
Delegates the comparison to the C++ engine.
Args:
other: Object to compare against.
Returns:
``True`` if *other* is a :class:`BattleSnakeGame` with the same
C++ state, ``False`` otherwise.
Raises:
ValueError: If this game is closed.
"""
if not isinstance(other, BattleSnakeGame):
return False
if self.is_closed:
raise ValueError("Cannot call function on closed game")
equal = CPP_LIB.lib.equals_cpp(self.state_p, other.state_p)
return equal
[docs]
def get_symmetry_count(self):
"""Return the total number of distinct board symmetries.
Symmetries combine 8 spatial transformations (4 rotations × 2 flips)
with permutations of enemy player slots. When ``compress_enemies`` is
enabled there is only one enemy slot, so only 8 symmetries exist.
Otherwise there are ``8 × (num_players − 1)!`` symmetries.
Returns:
Integer count of symmetries.
Raises:
ValueError: If the game is closed.
"""
if self.is_closed:
raise ValueError("Cannot call function on closed game")
if self.cfg.ec.compress_enemies:
return 8
else:
return 8 * math.factorial(self.cfg.num_players - 1)
[docs]
def get_obs(
    self,
    symmetry: Optional[int] = 0,
    *,
    temperatures: Optional[list[float]] = None,
    single_temperature: Optional[bool] = None,
) -> tuple[
    np.ndarray,
    dict[int, int],
    dict[int, int],
]:
    """Return observation tensors for all players currently at turn.

    The observations are stacked along axis 0 in :meth:`players_at_turn`
    order. A symmetry transformation (rotation and/or flip, plus an
    optional enemy-slot permutation) may be applied.

    Args:
        symmetry: Integer index selecting the transformation to apply.
            ``0`` is the identity. ``None`` samples a transformation
            uniformly at random. Valid range is
            ``[0, get_symmetry_count())``.
        temperatures: Temperature values fed into the encoding. Required
            when ``cfg.ec.temperature_input`` is set and rejected otherwise.
            Must hold exactly one value when single-temperature input is
            active, and ``num_players`` values otherwise.
        single_temperature: Overrides ``cfg.ec.single_temperature_input``
            when not ``None``.

    Returns:
        A 3-tuple ``(obs, perm, inv_perm)`` where:

        - *obs* — float32 array of shape
          ``(num_players_at_turn, width, height, channels)`` or
          ``(num_players_at_turn, width * height * channels)`` when
          flattening is enabled.
        - *perm* — dict mapping original action indices to transformed
          action indices.
        - *inv_perm* — inverse of *perm*.

    Raises:
        ValueError: If the game is closed or already in a terminal state,
            if *symmetry* is out of range, or if *temperatures* does not
            match the encoding configuration.
    """
    if self.is_closed:
        raise ValueError("Cannot call function on closed game")
    if self.is_terminal():
        raise ValueError("Cannot get encoding on terminal state")
    if not self.cfg.ec.temperature_input and temperatures is not None:
        raise ValueError("Cannot process temperatures if ec specifies no input")
    if self.cfg.ec.temperature_input:
        single_temp = (
            self.cfg.ec.single_temperature_input
            if single_temperature is None
            else single_temperature
        )
        if temperatures is None:
            raise ValueError("Need temperatures to generate encoding")
        if single_temp and len(temperatures) != 1:
            raise ValueError(
                "Cannot process multiple temperatures if single temperature input specified"
            )
        if not single_temp and len(temperatures) != self.num_players:
            raise ValueError(f"Invalid temperature length: {temperatures}")

    num_symmetries = self.get_symmetry_count()
    if symmetry is None:
        symmetry = np.random.randint(num_symmetries)
    elif not 0 <= symmetry < num_symmetries:
        raise ValueError(
            f"Invalid symmetry {symmetry}, expected a value in [0, {num_symmetries})"
        )
    # The lowest 3 bits select the spatial transform: bit 0 = mirror,
    # bits 1-2 = number of 90-degree rotations. The rest selects the
    # permutation of the enemy slots.
    sym_player, sym_rot = divmod(int(symmetry), 8)
    num_rot, flip_bit = divmod(sym_rot, 2)
    flip = flip_bit == 1
    enemy_perm = int_to_perm(sym_player, self.num_players - 1)

    obs = np.stack(
        [
            self._get_custom_state_encoding(
                player=player,
                perm=enemy_perm,
                temperatures=temperatures,
                single_temperature=single_temperature,
            )
            for player in self.players_at_turn()
        ]
    )
    obs = np.rot90(obs, k=num_rot, axes=(-3, -2))
    if flip:
        obs = np.flip(obs, axis=-2)

    # Each 90-degree rotation shifts the action index by -1
    # (UP=0, RIGHT=1, DOWN=2, LEFT=3). Mirroring the height axis swaps UP and DOWN.
    action_perm: dict[int, int] = {}
    inv_action_perm: dict[int, int] = {}
    for action in range(self.cfg.num_actions):
        new_action = (action - num_rot) % self.cfg.num_actions
        if flip and new_action in (UP, DOWN):
            new_action = DOWN if new_action == UP else UP
        action_perm[action] = new_action
        inv_action_perm[new_action] = action

    # sanity check
    if obs.shape[0] != self.num_players_at_turn():
        raise RuntimeError("Unknown Exception with observation shape")
    # rot90/flip return negative-stride views; copy into a writable, contiguous array
    result = obs.copy()
    if self.cfg.view_radius is not None:
        result = self._apply_view_radius(result, num_rot, flip)
    # Flatten last so the view-radius masking can still address the spatial axes.
    if self.cfg.ec.flatten:
        result = result.reshape(result.shape[0], -1)
    return result, action_perm, inv_action_perm

def _apply_view_radius(
    self, result: np.ndarray, num_rot: int, flip: bool
) -> np.ndarray:
    """Hide enemy and food information outside ``cfg.view_radius``.

    Args:
        result: Writable observation batch of shape
            ``(num_players_at_turn, width, height, channels)``. Row ``i``
            belongs to ``players_at_turn()[i]``. The batch is modified in place.
        num_rot: Number of 90-degree rotations already applied to *result*.
        flip: Whether *result* has already been mirrored.

    Returns:
        *result*, with one view-mask channel appended when
        ``cfg.ec.include_view_mask`` is set.
    """
    layers = self.layer_explanation
    distance_idx = layers["distance_map"]
    food_idx = layers.get("current_food")
    # Newly-spawned food is the same for every player, so compute it once.
    new_food_pos = None
    if food_idx is not None:
        new_food_pos = self.food_pos()[self.food_spawn_turns() == self.turns_played]
    # distance_map is stored scaled; multiplying by (w + h - 2) recovers tile
    # distances that can be compared with view_radius.
    max_distance = self.cfg.w + self.cfg.h - 2
    zeroed_suffixes = ("snake_health", "snake_length", "snake_tail_distance")
    masked_suffixes = (
        "snake_body",
        "snake_body_as_one_hot",
        "snake_head",
        "snake_tail",
    )
    masks = []
    # Rows follow players_at_turn() order, so index by row, not by player id.
    for row, player in enumerate(self.players_at_turn()):
        distance_map = result[row, :, :, distance_idx] * max_distance
        mask = (distance_map <= self.cfg.view_radius).astype(result.dtype)
        masks.append(mask)
        if food_idx is not None:
            result[row, :, :, food_idx] *= mask
            # Food that spawned this turn is always visible for one step.
            if len(new_food_pos) > 0:
                spawn_mask = self._new_food_obs_mask(
                    new_food_pos, player, num_rot, flip
                )
                result[row, :, :, food_idx] = np.maximum(
                    result[row, :, :, food_idx], spawn_mask
                )
        # Slot 0 is the observing player's own snake and is never restricted.
        for slot in range(1, self.num_players):
            for suffix in zeroed_suffixes:
                idx = layers.get(f"{slot}_{suffix}")
                if idx is not None:
                    result[row, :, :, idx] = 0
            for suffix in masked_suffixes:
                idx = layers.get(f"{slot}_{suffix}")
                if idx is not None:
                    result[row, :, :, idx] *= mask
    if self.cfg.ec.include_view_mask:
        result = np.concatenate((result, np.stack(masks)[..., None]), axis=-1)
    return result
def _new_food_obs_mask(
    self,
    new_food_pos: np.ndarray,
    player: int,
    num_rot: int,
    flip: bool,
) -> np.ndarray:
    """Build a 2D mask with 1.0 at newly-spawned food positions in observation space.

    Replicates the coordinate transformation applied by the C++ encoder. That is
    centering, plus tiling for wrapped boards, or a one-tile padding for
    non-centered, non-wrapped boards. It then applies the same rot90/flip that
    ``get_obs`` applies to the full observation, so the result aligns with the
    food layer.

    Args:
        new_food_pos: Food coordinates; columns 0 and 1 are read as x and y.
        player: Player whose head is used as the centre of a centered view.
        num_rot: Number of 90-degree rotations to apply.
        flip: Whether to mirror the mask like the observation.

    Returns:
        float32 mask whose shape matches the spatial axes of the observation.
    """
    w, h = self.cfg.w, self.cfg.h
    if self.cfg.ec.centered:
        pre_w, pre_h = 2 * w - 1, 2 * h - 1
        head_x, head_y = self.player_pos(player)[0]
        # shift so the player's head lands on the centre tile (w - 1, h - 1)
        x_off, y_off = w - head_x - 1, h - head_y - 1
    elif self.cfg.wrapped:
        pre_w, pre_h = w, h
        x_off, y_off = 0, 0
    else:
        # board padded by one tile on every side
        pre_w, pre_h = w + 2, h + 2
        x_off, y_off = 1, 1
    # A centered view of a wrapped board shows the board tiled around the
    # head, so each food item may appear at several shifted copies.
    if self.cfg.wrapped and self.cfg.ec.centered:
        shifts = list(itertools.product((-w, 0, w), (-h, 0, h)))
    else:
        shifts = [(0, 0)]
    mask = np.zeros((pre_w, pre_h), dtype=np.float32)
    for pos in new_food_pos:
        fx, fy = int(pos[0]) + x_off, int(pos[1]) + y_off
        for dx, dy in shifts:
            cx, cy = fx + dx, fy + dy
            if 0 <= cx < pre_w and 0 <= cy < pre_h:
                mask[cx, cy] = 1.0
    # Mirror get_obs exactly: observation axes (-3, -2) are mask axes (0, 1),
    # and get_obs flips observation axis -2, which is mask axis 1.
    mask = np.rot90(mask, k=num_rot, axes=(0, 1))
    if flip:
        mask = np.flip(mask, axis=1)
    return mask
def __del__(self):
    """Free the C++ game state if :meth:`close` was never called.

    Finalizers can run on partially-initialised objects, e.g. when
    ``__init__`` raised before ``is_closed`` was assigned. A missing
    ``is_closed`` is therefore treated as "nothing to release", so no
    AttributeError is raised during garbage collection.
    """
    if not getattr(self, "is_closed", True):
        self.close()
    self.is_closed = True
[docs]
def get_bool_board_matrix(self) -> np.ndarray:
    """Return the board occupancy matrix (constrictor mode only).

    Returns:
        Int8 array of shape ``(w, h)``, filled by the C++ engine, encoding
        which tiles are occupied.

    Raises:
        ValueError: If the game is closed, or if it is not configured for
            constrictor mode.
    """
    # A closed game's C++ state has been released; never hand it to the engine.
    if self.is_closed:
        raise ValueError("Cannot call function on closed game")
    if not self.cfg.constrictor:
        raise ValueError("Board matrix currently only supported in constrictor")
    # NOTE(review): allocated as (w, h), whereas _set_state passes hazards to the
    # engine as an (h, w) grid. Confirm which axis order char_game_matrix_cpp writes.
    board = np.zeros((self.cfg.w, self.cfg.h), dtype=ct.c_int8)
    CPP_LIB.lib.char_game_matrix_cpp(
        self.state_p, board.ctypes.data_as(ct.POINTER(ct.c_int8))
    )
    return board
[docs]
def get_state(self) -> BattleSnakeState:
    """Capture the current game state as a serialisable snapshot.

    The returned :class:`~hisss.game.state.BattleSnakeState` can be
    passed to :meth:`set_state` (on any compatible environment) to restore
    this exact position. Food coordinates and spawn turns are converted to
    built-in ``int`` so the snapshot does not carry numpy scalar types.

    Returns:
        A :class:`~hisss.game.state.BattleSnakeState` describing the
        current board.

    Raises:
        ValueError: If the game is closed.
    """
    if self.is_closed:
        raise ValueError("Cannot call function on closed game")
    alive = set(self.players_alive())
    snakes_alive_bool = [i in alive for i in range(self.num_players)]
    player_pos = {i: self.player_pos(i) for i in range(self.num_players)}
    food_pos_arr = self.food_pos()
    food_list = [
        [int(food_pos_arr[i, 0]), int(food_pos_arr[i, 1])]
        for i in range(food_pos_arr.shape[0])
    ]
    food_spawn_turns = [int(t) for t in self.food_spawn_turns()]
    snake_health = self.player_healths()
    snake_len = self.player_lengths()
    # Only eliminated snakes carry elimination info. Causes not present in
    # CAUSE_INT_TO_STR are left out of the snapshot.
    elimination_events: dict[int, EliminationEvent] = {}
    for i, is_alive in enumerate(snakes_alive_bool):
        if is_alive:
            continue
        cause_int = CPP_LIB.lib.snake_elim_cause_cpp(self.state_p, i)
        killer_int = CPP_LIB.lib.snake_elim_killer_cpp(self.state_p, i)
        elim_turn = CPP_LIB.lib.snake_elim_turn_cpp(self.state_p, i)
        cause_str = CAUSE_INT_TO_STR.get(cause_int)
        if cause_str is None:
            continue
        # a negative killer id means no other snake was responsible
        by_str = f"snake-{killer_int}" if killer_int >= 0 else None
        elimination_events[i] = EliminationEvent(
            cause=cause_str, turn=elim_turn, by=by_str
        )
    return BattleSnakeState(
        snakes_alive=snakes_alive_bool,
        snake_pos=player_pos,
        food_pos=food_list,
        snake_health=snake_health,
        snake_len=snake_len,
        turn=self.turns_played,
        food_spawn_turns=food_spawn_turns,
        elimination_events=elimination_events or None,
    )
[docs]
def set_state(self, state: BattleSnakeState):
    """Restore the game to a previously captured state.

    Resets cumulative rewards and last actions, then re-initialises the
    C++ game state from *state*.

    Args:
        state: Snapshot obtained from :meth:`get_state` (on any
            compatible environment with the same config).

    Raises:
        ValueError: If the game is closed. ``_set_state`` releases the
            current C++ state pointer, which for a closed game has already
            been freed.
    """
    if self.is_closed:
        raise ValueError("Cannot call function on closed game")
    self._cum_rewards = np.zeros(shape=(self.cfg.num_players,), dtype=float)
    self._last_actions = None
    self._set_state(state)
def _set_state(self, state: BattleSnakeState):
    """Replace the C++ game state with one built from *state*.

    All input buffers are prepared and validated before the new C++ state is
    created. The previous state is released only after its replacement
    exists, so a malformed *state* leaves this environment untouched instead
    of holding a dangling pointer.

    Args:
        state: Snapshot to load. ``snake_pos`` is read for each of the
            ``cfg.num_players`` snakes; ``snake_len``, ``snakes_alive`` and
            ``snake_health`` are passed to the engine as per-snake arrays.

    Raises:
        ValueError: If ``state.food_spawn_turns`` is given but does not hold
            exactly one entry per food item.
    """
    num_players = self.cfg.num_players
    # numpy arrays backing raw pointers; they must stay alive until init_cpp returns
    buffers: list[np.ndarray] = []

    def to_c_pointer(values, dtype, c_type):
        # C-contiguous so the engine can walk the buffer linearly
        arr = np.ascontiguousarray(values, dtype=dtype)
        buffers.append(arr)
        return arr.ctypes.data_as(ct.POINTER(c_type))

    # Snake bodies: copy into new lists (the state object is never altered),
    # drop repeated segments, and pad every body to the longest one with -1.
    snake_bodies = [
        list(dict.fromkeys((pos[0], pos[1]) for pos in state.snake_pos[s]))
        for s in range(num_players)
    ]
    body_lengths = [len(body) for body in snake_bodies]
    max_body_len = max(body_lengths, default=0)
    snake_pos_arr = np.full((num_players, max_body_len, 2), -1, dtype=ct.c_int)
    for s, body in enumerate(snake_bodies):
        if body:
            snake_pos_arr[s, : len(body)] = body
    body_len_p = to_c_pointer(body_lengths, ct.c_int, ct.c_int)
    snake_pos_p = to_c_pointer(snake_pos_arr, ct.c_int, ct.c_int)

    # per-snake scalars
    snake_len_p = to_c_pointer(state.snake_len, ct.c_int, ct.c_int)
    snake_alive_p = to_c_pointer(state.snakes_alive, bool, ct.c_bool)
    snake_health_p = to_c_pointer(state.snake_health, ct.c_int, ct.c_int)
    snake_max_health_p = to_c_pointer(self.cfg.max_snake_health, ct.c_int, ct.c_int)

    # food
    num_init_food = len(state.food_pos)
    food_pos_p = to_c_pointer(state.food_pos, ct.c_int, ct.c_int)
    if state.food_spawn_turns is not None:
        if len(state.food_spawn_turns) != num_init_food:
            raise ValueError(
                f"food_spawn_turns has {len(state.food_spawn_turns)} entries "
                f"but there are {num_init_food} food items"
            )
        spawn_turns = state.food_spawn_turns
    else:
        # without spawn info, treat all food as spawned on the current turn
        spawn_turns = np.full(num_init_food, state.turn)
    food_spawn_turns_p = to_c_pointer(spawn_turns, ct.c_int, ct.c_int)

    # Hazards are stored transposed, as an (h, w) grid indexed [y, x], because
    # the C++ side consumes them as one flattened array.
    hazard_arr = np.zeros(shape=(self.cfg.h, self.cfg.w), dtype=bool)
    if self.cfg.init_hazards is not None:
        for hazard_tile in self.cfg.init_hazards:
            hazard_arr[hazard_tile[1], hazard_tile[0]] = True
    hazards_p = to_c_pointer(hazard_arr, bool, ct.c_bool)

    new_state_p = CPP_LIB.lib.init_cpp(
        self.cfg.w,
        self.cfg.h,
        num_players,
        self.cfg.min_food,
        self.cfg.food_spawn_chance,
        state.turn,
        False,
        body_len_p,
        max_body_len,
        snake_pos_p,
        num_init_food,
        food_pos_p,
        food_spawn_turns_p,
        snake_alive_p,
        snake_health_p,
        snake_len_p,
        snake_max_health_p,
        self.cfg.wrapped,
        self.cfg.royale,
        self.cfg.shrink_n_turns,
        self.cfg.hazard_damage,
        hazards_p,
    )
    # Release the old engine state only once its replacement exists.
    CPP_LIB.lib.close_cpp(self.state_p)
    self.state_p = new_state_p

    if state.elimination_events:
        for snake_id, event in state.elimination_events.items():
            cause_int = CAUSE_STR_TO_INT.get(event.cause, 0)
            # "snake-<id>" -> <id>; -1 when no other snake was responsible
            killer_int = int(event.by.split("-")[1]) if event.by is not None else -1
            CPP_LIB.lib.set_elim_info_cpp(
                self.state_p, snake_id, cause_int, killer_int, event.turn
            )
    self.reset_saved_properties()