mirror of
https://gitea.wildfiregames.com/0ad/0ad
synced 2026-06-16 05:13:58 -07:00
114 lines
3.4 KiB
Python
114 lines
3.4 KiB
Python
|
|
from .api import RLAPI
|
||
|
|
import json
|
||
|
|
import math
|
||
|
|
from xml.etree import ElementTree
|
||
|
|
from itertools import cycle
|
||
|
|
|
||
|
|
class ZeroAD():
    """Client-side handle on a 0 A.D. game driven through the RL interface.

    Keeps the most recent GameState in ``current_state`` and a
    template-name -> EntityTemplate cache in ``cache`` that Entity objects
    consult (and lazily refresh through ``update_templates``).
    """

    def __init__(self, uri='http://localhost:6000'):
        self.api = RLAPI(uri)
        self.current_state = None  # set by step() / reset()
        self.cache = {}  # template name -> EntityTemplate
        self.player_id = 1  # player used by step() when no `player` is given

    def step(self, actions=(), player=None):
        """Send ``actions`` to the game, advance it and return the new GameState.

        Args:
            actions: iterable of commands; ``None`` entries are skipped.
            player: optional iterable of player ids. It is cycled and paired
                element-wise with ``actions``. When omitted, every action is
                issued as ``self.player_id``.

        Returns:
            The GameState built from the server's JSON reply (also stored in
            ``self.current_state``).
        """
        player_ids = cycle([self.player_id]) if player is None else cycle(player)
        # Pair each action with its player id and drop the None (no-op) actions.
        cmds = ((pid, action) for (pid, action) in zip(player_ids, actions)
                if action is not None)
        state_json = self.api.step(cmds)
        self.current_state = GameState(json.loads(state_json), self)
        return self.current_state

    def reset(self, config='', save_replay=False, player_id=1):
        """Restart the game with ``config`` and return the initial GameState.

        ``player_id`` is forwarded to the server and also becomes the default
        player for subsequent step() calls.
        """
        state_json = self.api.reset(config, player_id, save_replay)
        # Keep step()'s default player in sync with the player the game was
        # reset for; previously self.player_id stayed at 1 regardless.
        self.player_id = player_id
        self.current_state = GameState(json.loads(state_json), self)
        return self.current_state

    def get_template(self, name):
        """Fetch one template.

        Returns the ``(name, EntityTemplate)`` pair (not just the template),
        i.e. the first element of ``get_templates([name])``.
        """
        return self.get_templates([name])[0]

    def get_templates(self, names):
        """Fetch templates for ``names``; return a list of (name, EntityTemplate)."""
        templates = self.api.get_templates(names)
        return [(name, EntityTemplate(content)) for (name, content) in templates]

    def update_templates(self, types=()):
        """Refetch templates for every unit type in the current state plus ``types``.

        ``self.cache`` is replaced by the fetched templates. Works before the
        first step()/reset() too, in which case only ``types`` are fetched.

        Returns:
            The list of (name, EntityTemplate) pairs that was fetched.
        """
        names = []
        if self.current_state is not None:
            names.extend(unit.type() for unit in self.current_state.units())
        names.extend(types)
        # Drop duplicates (e.g. a type both on the map and explicitly requested,
        # which is always the case for Entity.get_template) keeping first-seen order.
        all_types = list(dict.fromkeys(names))
        template_pairs = self.get_templates(all_types)

        self.cache = dict(template_pairs)
        return template_pairs
|
||
|
|
|
||
|
|
class GameState():
    """Snapshot of the game as returned by the RL interface.

    ``data`` is the decoded JSON reply; ``game`` is the owning ZeroAD client,
    handed to every Entity this state produces.
    """

    def __init__(self, data, game):
        self.data = data
        self.game = game
        self.mapSize = data['mapSize']

    def units(self, owner=None, type=None):
        """Return Entity objects, optionally filtered.

        owner: keep only entities whose 'owner' equals this value.
        type: keep only entities whose template name contains this substring.
        """
        selected = []
        for record in self.data['entities'].values():
            if owner is not None and record['owner'] != owner:
                continue
            if type is not None and type not in record['template']:
                continue
            selected.append(Entity(record, self.game))
        return selected

    def unit(self, id):
        """Return the Entity with this id (int or str), or None if absent."""
        key = str(id)  # JSON object keys are strings
        entities = self.data['entities']
        if key not in entities:
            return None
        return Entity(entities[key], self.game)
|
||
|
|
|
||
|
|
class Entity():
    """View over a single entity record from a GameState.

    ``data`` is the entity's dict from the state JSON; ``game`` is the ZeroAD
    client whose template cache is used for template-derived values.
    """

    def __init__(self, data, game):
        self.data = data
        self.game = game
        # None until the template is in the game's cache; get_template()
        # fetches it on demand.
        self.template = game.cache.get(self.type(), None)

    def type(self):
        """Template name of this entity."""
        return self.data['template']

    def id(self):
        return self.data['id']

    def owner(self):
        return self.data['owner']

    def max_health(self):
        """Maximum hitpoints read from the template's 'Health/Max' value."""
        raw = self.get_template().get('Health/Max')
        return float(raw)

    def health(self, ratio=False):
        """Current hitpoints, or hitpoints / max_health() when ``ratio`` is true."""
        hp = self.data['hitpoints']
        return hp / self.max_health() if ratio else hp

    def position(self):
        return self.data['position']

    def get_template(self):
        """Return this entity's template, asking the game to refresh its cache if needed."""
        if self.template is not None:
            return self.template
        name = self.type()
        self.game.update_templates([name])
        self.template = self.game.cache[name]
        return self.template
|
||
|
|
|
||
|
|
class EntityTemplate():
    """Editable wrapper around an entity template's XML.

    The fragment received from the RL interface is wrapped in an <Entity>
    root element, so component paths such as 'Health/Max' can be resolved
    with ElementTree's ``find``.
    """

    def __init__(self, xml):
        self.data = ElementTree.fromstring(f'<Entity>{xml}</Entity>')

    def get(self, path):
        """Return the text of the element at ``path`` (relative to <Entity>), or None."""
        node = self.data.find(path)
        return node.text if node is not None else None

    def set(self, path, value):
        """Set the text at ``path`` to ``str(value)``.

        Returns:
            True if the element exists (and was updated), False otherwise.
        """
        node = self.data.find(path)
        # Must compare with None: an Element's truth value is its number of
        # children, so a leaf like <Max>100</Max> is falsy and `if node:`
        # silently skipped the update while still returning True.
        if node is None:
            return False
        node.text = str(value)
        return True

    def __str__(self):
        return ElementTree.tostring(self.data).decode('utf-8')
|