diff options
Diffstat (limited to 'tools/tosa.py')
-rw-r--r-- | tools/tosa.py | 72 |
1 files changed, 18 insertions, 54 deletions
diff --git a/tools/tosa.py b/tools/tosa.py index 803e478..d01d9c2 100644 --- a/tools/tosa.py +++ b/tools/tosa.py @@ -1,64 +1,42 @@ -#!/usr/bin/env python3 -# Copyright (c) 2023, ARM Limited. -# SPDX-License-Identifier: Apache-2.0 import re import xml.etree.ElementTree as ET - # possible shapes: shape1, [2], [N,H,W,C] # returns (checkable, rank) # checkable is false if shape doesn't contain [] - - def get_rank_from_shape(shape): - if "[" not in shape or "[]" in shape: + if '[' not in shape or '[]' in shape: return (False, -1) # Check for fixed rank requirement [N] - m = re.match(r"\[(\d+)\]", shape) + m = re.match(r'\[(\d+)\]', shape) if m: return (True, 1) # Check for comma separated rank descriptors, return count - m = re.match(r"\[(.*)\]", shape) + m = re.match(r'\[(.*)\]', shape) if m: - return (True, len(m.group(1).split(","))) + return (True, len(m.group(1).split(','))) else: - raise RuntimeError(f"Unable to parse shape {shape}") - + raise RuntimeError(f'Unable to parse shape {shape}') class TOSAOperatorArgumentCategory: def __init__(self, name, profiles=None): self.name = name self.profiles = profiles - class TOSAEnum: def __init__(self, name, description, values): self.name = name self.description = description self.values = values - class TOSALevel: def __init__(self, name, desc, maximums): self.name = name self.desc = desc self.maximums = maximums - class TOSAOperatorArgument: - def __init__( - self, - name, - description, - categories, - ty, - elty, - shape, - levellimits, - rank, - optional=False, - ): - assert isinstance(optional, bool) + def __init__(self, name, description, categories, ty, elty, shape, levellimits, rank): self.name = name self.description = description self.categories = categories @@ -67,7 +45,6 @@ class TOSAOperatorArgument: self.shape = shape self.levellimits = levellimits self.rank = rank - self.optional = optional class TOSAOperatorDataTypeSupport: @@ -123,12 +100,11 @@ class TOSASpec: name = level.get("name") desc = level.text.strip() maximums = { - "MAX_RANK": level.get("max_rank"), - "MAX_KERNEL": level.get("max_kernel"), - "MAX_STRIDE": level.get("max_stride"), - "MAX_SCALE": level.get("max_scale"), - "MAX_LOG2_SIZE": level.get("max_log2_size"), - "MAX_NESTING": level.get("max_nesting"), + 'MAX_RANK': level.get("max_rank"), + 'MAX_KERNEL': level.get("max_kernel"), + 'MAX_STRIDE': level.get("max_stride"), + 'MAX_SCALE': level.get("max_scale"), + 'MAX_LOG2_SIZE' : level.get("max_log2_size"), } return TOSALevel(name, desc, maximums) @@ -173,28 +149,18 @@ class TOSASpec: shape = arg.get("shape") levellimits = [] rank = [] - optional = arg.get("optional", "false") == "true" r = arg.find("rank") - if r is not None: - rank = [r.get("min"), r.get("max")] - if shape == "-" and (rank[0] != "0" or rank[1] != "0"): - raise RuntimeError( - "rank is not empty or non-zero, but shape is '-'" - f" for {op_name}: {name}" - ) + if r != None: + rank = [r.get('min'),r.get('max')] + if shape == "-" and (rank[0] != '0' or rank[1] != '0'): + raise RuntimeError(f"rank is not empty or non-zero, but shape is '-' for {op_name}: {name}") # validate rank against the shape argument (shape_check, shape_rank) = get_rank_from_shape(shape) if shape_check and (shape_rank < int(rank[0]) or shape_rank > int(rank[1])): - raise RuntimeError( - "Description of shape rank doesn't match XML rank" - f" min/max: {op_name} {name} shape: {shape} shape_rank: " - f"{shape_rank} min/max: {rank[0]} {rank[1]}" - ) + raise RuntimeError(f"Description of shape rank doesn't match XML rank min/max: {op_name} {name} shape: {shape} shape_rank: {shape_rank} min/max: {rank[0]} {rank[1]}") else: if shape != "-": - raise RuntimeError( - f"Rank not present for {op_name}: {name} when shape is {shape}" - ) + raise RuntimeError(f"Rank not present for {op_name}: {name} when shape is {shape}") for levellimit in arg.findall("levellimit"): value = levellimit.get("value") limit = levellimit.get("limit") @@ -206,9 +172,7 @@ class TOSASpec: for cat in cats: argcats.append(TOSAOperatorArgumentCategory(cat[0], cat[1].split(","))) - return TOSAOperatorArgument( - name, desc, argcats, argtype, argtelty, shape, levellimits, rank, optional - ) + return TOSAOperatorArgument(name, desc, argcats, argtype, argtelty, shape, levellimits, rank) def __load_enum(self, arg): name = arg.get("name") |