0ad/source/tools/rlclient/python/zero_ad/environment.py
Dunedan e36c6a31fe
Enable additional ruff rules
In the ruff config file added in #6954 explicitly selecting the ruff
rules to check was missed, resulting in ruff only checking a very small
subset of its available rules. That hasn't been desired, so this is the
first of a series of commits enabling more rules. In this PR all rules
whose violations can be either automatically fixed by ruff or are
trivial to fix manually get enabled. For the follow up PRs it's intended
to focus on one area of rules per PR to gradually improve the Python
code quality.
2024-08-25 06:29:39 +02:00

128 lines
3.5 KiB
Python

import json
from itertools import cycle
from xml.etree import ElementTree as ET
from .api import RLAPI
class ZeroAD:
def __init__(self, uri="http://localhost:6000"):
self.api = RLAPI(uri)
self.current_state = None
self.cache = {}
self.player_id = 1
def step(self, actions=None, player=None):
if actions is None:
actions = []
player_ids = cycle([self.player_id]) if player is None else cycle(player)
cmds = zip(player_ids, actions)
cmds = ((player, action) for (player, action) in cmds if action is not None)
state_json = self.api.step(cmds)
self.current_state = GameState(json.loads(state_json), self)
return self.current_state
def reset(self, config="", save_replay=False, player_id=1):
state_json = self.api.reset(config, player_id, save_replay)
self.current_state = GameState(json.loads(state_json), self)
return self.current_state
def evaluate(self, code):
return self.api.evaluate(code)
def get_template(self, name):
return self.get_templates([name])[0]
def get_templates(self, names):
templates = self.api.get_templates(names)
return [(name, EntityTemplate(content)) for (name, content) in templates]
def update_templates(self, types=None):
if types is None:
types = []
all_types = list({unit.type() for unit in self.current_state.units()})
all_types += types
template_pairs = self.get_templates(all_types)
self.cache = {}
for name, tpl in template_pairs:
self.cache[name] = tpl
return template_pairs
class GameState:
def __init__(self, data, game):
self.data = data
self.game = game
self.mapSize = self.data["mapSize"]
def units(self, owner=None, type=None):
def filter_fn(e):
return (owner is None or e["owner"] == owner) and (
type is None or type in e["template"]
)
return [Entity(e, self.game) for e in self.data["entities"].values() if filter_fn(e)]
def unit(self, id):
id = str(id)
return (
Entity(self.data["entities"][id], self.game) if id in self.data["entities"] else None
)
class Entity:
def __init__(self, data, game):
self.data = data
self.game = game
self.template = self.game.cache.get(self.type(), None)
def type(self):
return self.data["template"]
def id(self):
return self.data["id"]
def owner(self):
return self.data["owner"]
def max_health(self):
template = self.get_template()
return float(template.get("Health/Max"))
def health(self, ratio=False):
if ratio:
return self.data["hitpoints"] / self.max_health()
return self.data["hitpoints"]
def position(self):
return self.data["position"]
def get_template(self):
if self.template is None:
self.game.update_templates([self.type()])
self.template = self.game.cache[self.type()]
return self.template
class EntityTemplate:
def __init__(self, xml):
self.data = ET.fromstring(f"<Entity>{xml}</Entity>")
def get(self, path):
node = self.data.find(path)
return node.text if node is not None else None
def set(self, path, value):
node = self.data.find(path)
if node:
node.text = str(value)
return node is not None
def __str__(self):
return ET.tostring(self.data).decode("utf-8")