#!/usr/bin/env python3
# Copyright (c) 2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
"""Load a TOSA specification XML document into plain Python objects."""
import re
import xml.etree.ElementTree as ET


# possible shapes: shape1, [2], [N,H,W,C]
# returns (checkable, rank)
# checkable is false if shape doesn't contain a bracketed dimension list
def get_rank_from_shape(shape):
    """Infer the rank described by an argument's ``shape`` attribute.

    Possible shapes: ``shape1`` (a symbolic name), ``[2]`` (a single numeric
    dimension) and ``[N,H,W,C]`` (comma-separated dimensions).

    Returns:
        A ``(checkable, rank)`` tuple. ``checkable`` is False, and ``rank`` is
        -1, when the shape contains no ``[`` or contains an empty ``[]``.

    Raises:
        RuntimeError: if the shape contains ``[`` but cannot be parsed.
    """
    if "[" not in shape or "[]" in shape:
        return (False, -1)

    # Check for fixed rank requirement, e.g. [2]: one dimension, so rank 1.
    m = re.match(r"\[(\d+)\]", shape)
    if m:
        return (True, 1)

    # Check for comma separated rank descriptors, return count
    m = re.match(r"\[(.*)\]", shape)
    if m:
        return (True, len(m.group(1).split(",")))
    raise RuntimeError(f"Unable to parse shape {shape}")


class TOSAOperatorArgumentCategory:
    """One category entry of an argument, e.g. ``input`` or ``attribute``."""

    def __init__(self, name, profiles=None):
        self.name = name  # "input", "output" or "attribute"
        # List of profile names for this category; see the loader for how it is built.
        self.profiles = profiles


class TOSAProfile:
    """A TOSA profile declared under ``<profiles>``."""

    def __init__(self, profile, name, description, status):
        self.profile = profile
        self.name = name
        self.description = description
        self.status = status
        self.ops = []  # left empty by the loader in this module


class TOSAProfileExtension:
    """A profile extension declared under ``<profile_extensions>``."""

    def __init__(self, name, description, status, profiles):
        self.name = name
        self.description = description
        self.status = status
        self.profiles = profiles  # text of each child element
        self.ops = []  # left empty by the loader in this module


class TOSAEnum:
    """An ``<enum>`` with its ``(name, value, description)`` value tuples."""

    def __init__(self, name, description, values):
        self.name = name
        self.description = description
        self.values = values


class TOSALevel:
    """A conformance level and its ``MAX_*`` limits (raw attribute strings)."""

    def __init__(self, name, desc, maximums):
        self.name = name
        self.desc = desc
        # Maps "MAX_RANK", "MAX_KERNEL", ... to the attribute string, or None.
        self.maximums = maximums


class TOSAOperatorArgument:
    """A single operator argument as described by an ``<argument>`` element."""

    def __init__(
        self,
        name,
        description,
        categories,
        ty,
        elty,
        shape,
        levellimits,
        rank,
        optional=False,
    ):
        assert isinstance(optional, bool)
        self.name = name
        self.description = description
        self.categories = categories  # list of TOSAOperatorArgumentCategory
        self.type = ty
        self.tensor_element_type = elty
        self.shape = shape
        self.levellimits = levellimits  # list of [value, limit] pairs
        self.rank = rank  # [min, max] as strings, or [] when absent
        self.optional = optional


class TOSAOperatorDataTypeSupport:
    """One ``<typesupport>`` row: a mode plus a type-name -> type mapping."""

    def __init__(self, mode, tymap, version_added, profiles):
        self.mode = mode
        self.tymap = tymap
        self.profiles = profiles
        self.version_added = version_added


class TOSAOperator:
    """An operator with its arguments, type names and supported type sets."""

    def __init__(self, name, arguments, types, typesupports):
        self.name = name
        self.arguments = arguments
        self.types = types
        self.typesupports = typesupports


class TOSAOperatorGroup:
    """A named ``<operatorgroup>`` of operators."""

    def __init__(self, name, operators):
        self.name = name
        self.operators = operators


class TOSASpec:
    """Parsed representation of a TOSA specification XML document.

    Args:
        xmlpath: anything ``xml.etree.ElementTree.parse`` accepts (a file
            name or a file object).

    Raises:
        RuntimeError: if an argument's shape and rank descriptions disagree.
    """

    def __init__(self, xmlpath):
        tree = ET.parse(xmlpath)
        self.xmlroot = tree.getroot()
        self.profiles = []
        self.profile_extensions = []
        self.levels = []
        self.operatorgroups = []
        self.enums = []
        self.__load_spec()

    def __load_spec(self):
        """Populate the version attributes and all the top-level lists."""
        root = self.xmlroot
        self.__load_version()
        self.profiles.extend(
            self.__load_profile(p) for p in root.findall("./profiles/profile")
        )
        self.profile_extensions.extend(
            self.__load_profile_extension(ext)
            for ext in root.findall("./profile_extensions/profile_extension")
        )
        self.levels.extend(
            self.__load_level(level) for level in root.findall("./levels/level")
        )
        self.operatorgroups.extend(
            self.__load_operator_group(group)
            for group in root.findall("./operators/operatorgroup")
        )
        self.enums.extend(self.__load_enum(enum) for enum in root.findall("./enum"))

    def __load_version(self):
        """Read the major/minor/patch numbers and the draft flag."""
        version = self.xmlroot.find("./version")
        self.version_major = int(version.get("major"))
        self.version_minor = int(version.get("minor"))
        self.version_patch = int(version.get("patch"))
        self.version_is_draft = version.get("draft") == "true"

    def __load_profile(self, xml_profile):
        """Build a TOSAProfile from a ``<profile>`` element."""
        return TOSAProfile(
            xml_profile.get("profile"),
            xml_profile.get("name"),
            xml_profile.get("description"),
            xml_profile.get("status"),
        )

    def __load_profile_extension(self, ext):
        """Build a TOSAProfileExtension; its profiles are the children's text."""
        return TOSAProfileExtension(
            ext.get("name"),
            ext.get("description"),
            ext.get("status"),
            [x.text for x in ext],
        )

    def __load_level(self, level):
        """Build a TOSALevel from a ``<level>`` element."""
        name = level.get("name")
        # A <level> with no text content would otherwise fail on None.strip().
        desc = (level.text or "").strip()
        maximums = {
            "MAX_RANK": level.get("max_rank"),
            "MAX_KERNEL": level.get("max_kernel"),
            "MAX_STRIDE": level.get("max_stride"),
            "MAX_SCALE": level.get("max_scale"),
            "MAX_LOG2_SIZE": level.get("max_log2_size"),
            "MAX_NESTING": level.get("max_nesting"),
            "MAX_TENSOR_LIST_SIZE": level.get("max_tensor_list_size"),
        }
        return TOSALevel(name, desc, maximums)

    def __load_operator_group(self, group):
        """Build a TOSAOperatorGroup from an ``<operatorgroup>`` element."""
        operators = [self.__load_operator(op) for op in group.findall("operator")]
        return TOSAOperatorGroup(group.get("name"), operators)

    def __load_operator(self, op):
        """Build a TOSAOperator from an ``<operator>`` element."""
        name = op.find("name").text
        args = [
            self.__load_operator_argument(arg, name)
            for arg in op.findall("arguments/argument")
        ]
        # TODO add pseudo-code to operator object?
        types = [ty.get("name") for ty in op.findall("types/type")]
        typesupports = [
            self.__load_typesupport(tysup, types)
            for tysup in op.findall("typesupport")
        ]
        return TOSAOperator(name, args, types, typesupports)

    def __load_typesupport(self, tysup, types):
        """Build a TOSAOperatorDataTypeSupport from a ``<typesupport>`` element.

        ``types`` lists the operator's type names; each is looked up as an
        attribute of ``tysup`` (None when absent).
        """
        tsprofiles = [self.__op_profile_name(p) for p in tysup.findall("op_profile")]
        tsmap = {ty: tysup.get(ty) for ty in types}
        return TOSAOperatorDataTypeSupport(
            tysup.get("mode"), tsmap, tysup.get("version_added"), tsprofiles
        )

    @staticmethod
    def __op_profile_name(p):
        """Return the profile name of an ``<op_profile>`` element.

        When an ``and_name`` attribute is present the result is
        ``"<first> and <second>"`` with the two names in sorted order, so it
        does not depend on which attribute holds which name.
        """
        tsp_name = p.get("name")
        and_name = p.get("and_name")
        if and_name is None:
            return tsp_name
        first, second = sorted((tsp_name, and_name))
        return f"{first} and {second}"

    def __load_operator_argument(self, arg, op_name):
        """Build a TOSAOperatorArgument from an ``<argument>`` element.

        ``op_name`` is used only in error messages.

        Raises:
            RuntimeError: if ``shape`` is "-" but the rank is not 0..0, if the
                rank implied by a bracketed shape is outside the ``<rank>``
                bounds, or if ``<rank>`` is missing for a shape other than "-".
        """
        name = arg.get("name")
        desc = arg.find("description").text.strip()
        argtype = arg.get("type")
        argtelty = arg.get("tensor-element-type")
        shape = arg.get("shape")
        optional = arg.get("optional", "false") == "true"

        rank = []
        r = arg.find("rank")
        if r is not None:
            # Bounds are kept as the raw attribute strings.
            rank = [r.get("min"), r.get("max")]
            if shape == "-" and (rank[0] != "0" or rank[1] != "0"):
                raise RuntimeError(
                    "rank is not empty or non-zero, but shape is '-'"
                    f" for {op_name}: {name}"
                )
            # validate rank against the shape argument
            (shape_check, shape_rank) = get_rank_from_shape(shape)
            if shape_check and (
                shape_rank < int(rank[0]) or shape_rank > int(rank[1])
            ):
                raise RuntimeError(
                    "Description of shape rank doesn't match XML rank"
                    f" min/max: {op_name} {name} shape: {shape} shape_rank: "
                    f"{shape_rank} min/max: {rank[0]} {rank[1]}"
                )
        elif shape != "-":
            raise RuntimeError(
                f"Rank not present for {op_name}: {name} when shape is {shape}"
            )

        levellimits = [
            [ll.get("value"), ll.get("limit")] for ll in arg.findall("levellimit")
        ]

        # e.g. "input", "output(MT)", "attribute(BI,MI)". re.findall yields ""
        # for an absent profile list, which splits to [""].
        cats = re.findall(
            r"(input|output|attribute)\(?([A-Z,]+)?\)?", arg.get("category")
        )
        argcats = [
            TOSAOperatorArgumentCategory(cat_name, cat_profiles.split(","))
            for cat_name, cat_profiles in cats
        ]

        return TOSAOperatorArgument(
            name, desc, argcats, argtype, argtelty, shape, levellimits, rank, optional
        )

    def __load_enum(self, arg):
        """Build a TOSAEnum from an ``<enum>`` element."""
        name = arg.get("name")
        # A missing description attribute would otherwise fail on None.strip().
        desc = (arg.get("description") or "").strip()
        values = [
            (val.get("name"), val.get("value"), val.get("description"))
            for val in arg.findall("enumval")
        ]
        return TOSAEnum(name, desc, values)

    def get_enum_by_name(self, name):
        """Return the first TOSAEnum whose name is ``name``, or None."""
        return next((e for e in self.enums if e.name == name), None)