diff options
Diffstat (limited to 'tools/tosa.py')
-rw-r--r-- | tools/tosa.py | 82 |
1 files changed, 4 insertions, 78 deletions
diff --git a/tools/tosa.py b/tools/tosa.py index d01d9c2..74d43d6 100644 --- a/tools/tosa.py +++ b/tools/tosa.py @@ -1,50 +1,21 @@ import re import xml.etree.ElementTree as ET -# possible shapes: shape1, [2], [N,H,W,C] -# returns (checkable, rank) -# checkable is false if shape doesn't contain [] -def get_rank_from_shape(shape): - if '[' not in shape or '[]' in shape: - return (False, -1) - # Check for fixed rank requirement [N] - m = re.match(r'\[(\d+)\]', shape) - if m: - return (True, 1) - # Check for comma separated rank descriptors, return count - m = re.match(r'\[(.*)\]', shape) - if m: - return (True, len(m.group(1).split(','))) - else: - raise RuntimeError(f'Unable to parse shape {shape}') class TOSAOperatorArgumentCategory: def __init__(self, name, profiles=None): self.name = name self.profiles = profiles -class TOSAEnum: - def __init__(self, name, description, values): - self.name = name - self.description = description - self.values = values - -class TOSALevel: - def __init__(self, name, desc, maximums): - self.name = name - self.desc = desc - self.maximums = maximums class TOSAOperatorArgument: - def __init__(self, name, description, categories, ty, elty, shape, levellimits, rank): + def __init__(self, name, description, categories, ty, shape, levellimits): self.name = name self.description = description self.categories = categories self.type = ty - self.tensor_element_type = elty self.shape = shape self.levellimits = levellimits - self.rank = rank class TOSAOperatorDataTypeSupport: @@ -72,19 +43,13 @@ class TOSASpec: def __init__(self, xmlpath): tree = ET.parse(xmlpath) self.xmlroot = tree.getroot() - self.levels = [] self.operatorgroups = [] - self.enums = [] self.__load_spec() def __load_spec(self): self.__load_version() - for level in self.xmlroot.findall("./levels/level"): - self.levels.append(self.__load_level(level)) for group in self.xmlroot.findall("./operators/operatorgroup"): self.operatorgroups.append(self.__load_operator_group(group)) - for enum in self.xmlroot.findall("./enum"): - self.enums.append(self.__load_enum(enum)) def __load_version(self): version = self.xmlroot.find("./version") @@ -96,18 +61,6 @@ class TOSASpec: else: self.version_is_draft = False - def __load_level(self, level): - name = level.get("name") - desc = level.text.strip() - maximums = { - 'MAX_RANK': level.get("max_rank"), - 'MAX_KERNEL': level.get("max_kernel"), - 'MAX_STRIDE': level.get("max_stride"), - 'MAX_SCALE': level.get("max_scale"), - 'MAX_LOG2_SIZE' : level.get("max_log2_size"), - } - return TOSALevel(name, desc, maximums) - def __load_operator_group(self, group): name = group.get("name") operators = [] @@ -121,7 +74,7 @@ class TOSASpec: types = [] typesupports = [] for arg in op.findall("arguments/argument"): - args.append(self.__load_operator_argument(arg, name)) + args.append(self.__load_operator_argument(arg)) # TODO add pseudo-code to operator object? @@ -140,27 +93,13 @@ class TOSASpec: typesupports.append(TOSAOperatorDataTypeSupport(tsmode, tsmap, tsprofiles)) return TOSAOperator(name, args, types, typesupports) - def __load_operator_argument(self, arg, op_name): + def __load_operator_argument(self, arg): name = arg.get("name") desc = arg.find("description").text.strip() argcats = [] argtype = arg.get("type") - argtelty = arg.get("tensor-element-type") shape = arg.get("shape") levellimits = [] - rank = [] - r = arg.find("rank") - if r != None: - rank = [r.get('min'),r.get('max')] - if shape == "-" and (rank[0] != '0' or rank[1] != '0'): - raise RuntimeError(f"rank is not empty or non-zero, but shape is '-' for {op_name}: {name}") - # validate rank against the shape argument - (shape_check, shape_rank) = get_rank_from_shape(shape) - if shape_check and (shape_rank < int(rank[0]) or shape_rank > int(rank[1])): - raise RuntimeError(f"Description of shape rank doesn't match XML rank min/max: {op_name} {name} shape: {shape} shape_rank: {shape_rank} min/max: {rank[0]} {rank[1]}") - else: - if shape != "-": - raise RuntimeError(f"Rank not present for {op_name}: {name} when shape is {shape}") for levellimit in arg.findall("levellimit"): value = levellimit.get("value") limit = levellimit.get("limit") @@ -172,17 +111,4 @@ class TOSASpec: for cat in cats: argcats.append(TOSAOperatorArgumentCategory(cat[0], cat[1].split(","))) - return TOSAOperatorArgument(name, desc, argcats, argtype, argtelty, shape, levellimits, rank) - - def __load_enum(self, arg): - name = arg.get("name") - desc = arg.get("description").strip() - values = [] - for val in arg.findall("enumval"): - values.append((val.get("name"), val.get("value"), val.get("description"))) - return TOSAEnum(name, desc, values) - - def get_enum_by_name(self, name): - for e in self.enums: - if e.name == name: - return e + return TOSAOperatorArgument(name, desc, argcats, argtype, shape, levellimits) |