2020-08-01 03:52:59 -07:00
|
|
|
import json
|
|
|
|
|
from itertools import cycle
|
2024-08-24 21:29:39 -07:00
|
|
|
from xml.etree import ElementTree as ET
|
|
|
|
|
|
|
|
|
|
from .api import RLAPI
|
2020-08-01 03:52:59 -07:00
|
|
|
|
2024-08-22 00:18:20 -07:00
|
|
|
|
|
|
|
|
class ZeroAD:
    """Client for a 0 A.D. game exposed through the RL interface.

    Holds the RL API connection, the most recent ``GameState`` and a cache of
    entity templates keyed by template name.
    """

    def __init__(self, uri="http://localhost:6000"):
        self.api = RLAPI(uri)
        # Latest GameState; None until reset() or step() has been called.
        self.current_state = None
        # Template name -> EntityTemplate, rebuilt by update_templates().
        self.cache = {}
        # Player used for step() commands when no explicit players are given.
        self.player_id = 1

    def step(self, actions=None, player=None):
        """Send ``actions`` to the game and return the resulting state.

        Args:
            actions: Commands to issue; ``None`` entries are skipped.
            player: Optional iterable of player ids, cycled to pair with
                ``actions``. When omitted, every action is issued as
                ``self.player_id``.

        Returns:
            The new ``GameState``, also stored in ``self.current_state``.
        """
        if actions is None:
            actions = []
        player_ids = cycle([self.player_id]) if player is None else cycle(player)
        # cycle() is infinite, so zip() stops at the end of ``actions``.
        cmds = (
            (pid, action)
            for pid, action in zip(player_ids, actions, strict=False)
            if action is not None
        )
        state_json = self.api.step(cmds)
        self.current_state = GameState(json.loads(state_json), self)
        return self.current_state

    def reset(self, config="", save_replay=False, player_id=1):
        """Start a new game and return its initial state.

        Args:
            config: Scenario configuration forwarded to the RL API.
            save_replay: Forwarded to the RL API's reset call.
            player_id: Player id forwarded to the RL API. When truthy it also
                becomes the default player for subsequent ``step()`` calls.

        Returns:
            The initial ``GameState``, also stored in ``self.current_state``.
        """
        state_json = self.api.reset(config, player_id, save_replay)
        if player_id:
            # Keep step()'s default player in sync with the player the game
            # was reset for; otherwise step() would keep acting as player 1.
            self.player_id = player_id
        self.current_state = GameState(json.loads(state_json), self)
        return self.current_state

    def evaluate(self, code):
        """Forward ``code`` to the RL API's ``evaluate`` call and return its result."""
        return self.api.evaluate(code)

    def get_template(self, name):
        """Return the ``(name, EntityTemplate)`` pair for a single template name."""
        return self.get_templates([name])[0]

    def get_templates(self, names):
        """Fetch templates by name.

        Returns:
            A list of ``(name, EntityTemplate)`` pairs built from the
            ``(name, content)`` pairs returned by the RL API.
        """
        templates = self.api.get_templates(names)
        return [(name, EntityTemplate(content)) for name, content in templates]

    def update_templates(self, types=None):
        """Rebuild ``self.cache`` with templates for all visible unit types.

        Fetches templates for every unit type in the current state (if any)
        plus the extra ``types`` requested. Each name is requested only once.
        The cache is replaced wholesale with the fetched templates.

        Args:
            types: Optional iterable of additional template names to fetch.

        Returns:
            The list of ``(name, EntityTemplate)`` pairs that was fetched.
        """
        requested = list(types) if types is not None else []
        # Before the first reset()/step() there is no state to inspect.
        visible = (
            [unit.type() for unit in self.current_state.units()]
            if self.current_state is not None
            else []
        )
        # dict.fromkeys drops duplicates while keeping first-seen order.
        all_types = list(dict.fromkeys(visible + requested))

        template_pairs = self.get_templates(all_types)
        self.cache = dict(template_pairs)
        return template_pairs
|
|
|
|
|
|
2024-08-22 00:18:20 -07:00
|
|
|
|
|
|
|
|
class GameState:
    """Snapshot of the game decoded from one RL interface response.

    Keeps the decoded JSON as ``data`` and a reference to the owning ``game``
    so that entities built from this state can resolve their templates.
    """

    def __init__(self, data, game):
        self.data = data
        self.game = game
        # Read eagerly: a payload without "mapSize" fails at construction.
        self.mapSize = data["mapSize"]

    def units(self, owner=None, entity_type=None):
        """Return the entities that match every filter given.

        Args:
            owner: Keep only entities whose "owner" equals this, if not None.
            entity_type: Keep only entities whose "template" name contains
                this substring, if not None.

        Returns:
            A list of ``Entity`` wrappers, in the state's entity order.
        """
        matches = []
        for raw in self.data["entities"].values():
            if owner is not None and raw["owner"] != owner:
                continue
            if entity_type is not None and entity_type not in raw["template"]:
                continue
            matches.append(Entity(raw, self.game))
        return matches

    def unit(self, entity_id):
        """Return the ``Entity`` with ``entity_id``, or None if it is absent.

        The id is converted with ``str()`` because entities are looked up by
        string key in ``data["entities"]``.
        """
        entities = self.data["entities"]
        key = str(entity_id)
        if key not in entities:
            return None
        return Entity(entities[key], self.game)
|
2020-08-01 03:52:59 -07:00
|
|
|
|
|
|
|
|
|
2024-08-22 00:18:20 -07:00
|
|
|
class Entity:
    """Wrapper around one entity's raw dict from a ``GameState``.

    The raw dict is read through the keys "template", "id", "owner",
    "hitpoints" and "position". The entity's template is taken from the
    game's cache when available, and is otherwise fetched lazily by
    ``get_template()``.
    """

    def __init__(self, data, game):
        self.data = data
        self.game = game
        # None until the game's cache holds this entity's template.
        self.template = game.cache.get(self.type(), None)

    def type(self):
        """Return the entity's template name."""
        return self.data["template"]

    def id(self):
        """Return the entity's id as stored in its raw data."""
        return self.data["id"]

    def owner(self):
        """Return the owning player's id as stored in the raw data."""
        return self.data["owner"]

    def max_health(self):
        """Return ``float()`` of the template's "Health/Max" value."""
        max_hp = self.get_template().get("Health/Max")
        return float(max_hp)

    def health(self, ratio=False):
        """Return current hitpoints, or hitpoints / max health when ``ratio``."""
        hitpoints = self.data["hitpoints"]
        return hitpoints / self.max_health() if ratio else hitpoints

    def position(self):
        """Return the entity's position as stored in its raw data."""
        return self.data["position"]

    def get_template(self):
        """Return this entity's template, fetching it through the game if needed.

        On a cache miss this asks the game to refresh its template cache for
        this entity's type, then reads the template from that cache.
        """
        if self.template is None:
            name = self.type()
            self.game.update_templates([name])
            self.template = self.game.cache[name]
        return self.template
|
|
|
|
|
|
2024-08-22 00:18:20 -07:00
|
|
|
|
|
|
|
|
class EntityTemplate:
    """Mutable view over an entity template's XML content.

    The content is wrapped in a synthetic ``<Entity>`` root element. Paths
    such as ``"Health/Max"`` are resolved relative to that root.
    """

    def __init__(self, xml):
        # Wrap in a root element so the content need not have a single
        # top-level element of its own.
        self.data = ET.fromstring(f"<Entity>{xml}</Entity>")

    def get(self, path):
        """Return the text at ``path``.

        Returns None if no element matches, or if the element has no text.
        """
        node = self.data.find(path)
        return node.text if node is not None else None

    def set(self, path, value):
        """Set the text of the first element at ``path`` to ``str(value)``.

        Returns:
            True if a matching element was found and updated, else False.
        """
        node = self.data.find(path)
        # An Element's truth value reflects whether it has children, so a leaf
        # such as <Max>100</Max> is falsy; compare against None explicitly.
        if node is None:
            return False
        node.text = str(value)
        return True

    def __str__(self):
        """Serialize the template, including the wrapping ``<Entity>`` root."""
        return ET.tostring(self.data).decode("utf-8")
|