From 42085e36b0b47209ca767a3b8300f689cb6ec0bf Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Mon, 9 Jan 2023 11:16:51 -0800 Subject: Add TOSA rank requirements to TOSA XML Adds new optional element to argument 'rank' - Must supply minimum and maximum rank - Integer values or the level based "MAX_RANK" - trailing modifiers allowed for "MAX_RANK" - Displays in a new column in the document - Document generation validates rank against specified shape Change-Id: I507dc51bfe012d3230af43103c6c423a6f1e92b5 Signed-off-by: Eric Kunze --- tools/genspec.py | 19 +++++++++++++++---- tools/tosa.py | 38 ++++++++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 8 deletions(-) (limited to 'tools') diff --git a/tools/genspec.py b/tools/genspec.py index f495296..c64f05b 100755 --- a/tools/genspec.py +++ b/tools/genspec.py @@ -3,7 +3,6 @@ import os import tosa - class TOSASpecAsciidocGenerator: def __init__(self, spec): self.spec = spec @@ -19,8 +18,9 @@ class TOSASpecAsciidocGenerator: def generate_operator(self, op, file): file.write("\n*Arguments:*\n") + file.write("[cols='2,1,1,1,2,4']") file.write("\n|===\n") - file.write("|Argument|Type|Name|Shape|Description\n\n") + file.write("|Argument|Type|Name|Shape|Rank|Description\n\n") for arg in op.arguments: cats = arg.categories if len(cats) > 1: @@ -33,8 +33,15 @@ class TOSASpecAsciidocGenerator: sep = " " else: cattext = cats[0].name.title() + if len(arg.rank) > 0: + if (arg.rank[0] == arg.rank[1]): + rank = f'{arg.rank[0]}' + else: + rank = f'{arg.rank[0]} to {arg.rank[1]}' + else: + rank = "" file.write( - f"|{cattext}|{arg.type}|{arg.name}|{arg.shape}|{arg.description}\n" + f"|{cattext}|{arg.type}|{arg.name}|{arg.shape}|{rank}|{arg.description}\n" ) file.write("|===\n") if op.typesupports: @@ -112,7 +119,11 @@ if __name__ == "__main__": parser.add_argument("--outdir", required=True, help="Output directory") args = parser.parse_args() - spec = tosa.TOSASpec(args.xml) + try: + spec = tosa.TOSASpec(args.xml) + except RuntimeError as e: + print(f"Failure reading/validating XML spec: {str(e)}") + exit(1) generator = TOSASpecAsciidocGenerator(spec) generator.generate(args.outdir) diff --git a/tools/tosa.py b/tools/tosa.py index bc6faa6..218412f 100644 --- a/tools/tosa.py +++ b/tools/tosa.py @@ -1,6 +1,22 @@ import re import xml.etree.ElementTree as ET +# possible shapes: shape1, [2], [N,H,W,C] +# returns (checkable, rank) +# checkable is false if shape doesn't contain [] +def get_rank_from_shape(shape): + if '[' not in shape or '[]' in shape: + return (False, -1) + # Check for fixed rank requirement [N] + m = re.match(r'\[(\d+)\]', shape) + if m: + return (True, 1) + # Check for comma separated rank descriptors, return count + m = re.match(r'\[(.*)\]', shape) + if m: + return (True, len(m.group(1).split(','))) + else: + raise RuntimeError(f'Unable to parse shape {shape}') class TOSAOperatorArgumentCategory: def __init__(self, name, profiles=None): @@ -20,13 +36,14 @@ class TOSALevel: self.maximums = maximums class TOSAOperatorArgument: - def __init__(self, name, description, categories, ty, shape, levellimits): + def __init__(self, name, description, categories, ty, shape, levellimits, rank): self.name = name self.description = description self.categories = categories self.type = ty self.shape = shape self.levellimits = levellimits + self.rank = rank class TOSAOperatorDataTypeSupport: @@ -103,7 +120,7 @@ class TOSASpec: types = [] typesupports = [] for arg in op.findall("arguments/argument"): - args.append(self.__load_operator_argument(arg)) + args.append(self.__load_operator_argument(arg, name)) # TODO add pseudo-code to operator object? @@ -122,13 +139,26 @@ class TOSASpec: typesupports.append(TOSAOperatorDataTypeSupport(tsmode, tsmap, tsprofiles)) return TOSAOperator(name, args, types, typesupports) - def __load_operator_argument(self, arg): + def __load_operator_argument(self, arg, op_name): name = arg.get("name") desc = arg.find("description").text.strip() argcats = [] argtype = arg.get("type") shape = arg.get("shape") levellimits = [] + rank = [] + r = arg.find("rank") + if r != None: + if shape == "-": + raise RuntimeError(f"rank is not empty, but shape is '-' for {op_name}: {name}") + rank = [r.get('min'),r.get('max')] + # validate rank against the shape argument + (shape_check, shape_rank) = get_rank_from_shape(shape) + if shape_check and (shape_rank < int(rank[0]) or shape_rank > int(rank[1])): + raise RuntimeError(f"Description of shape rank doesn't match XML rank min/max: {op_name} {name} shape: {shape} shape_rank: {shape_rank} min/max: {rank[0]} {rank[1]}") + else: + if shape != "-": + raise RuntimeError(f"Rank not present for {op_name}: {name} when shape is {shape}") for levellimit in arg.findall("levellimit"): value = levellimit.get("value") limit = levellimit.get("limit") @@ -140,7 +170,7 @@ class TOSASpec: for cat in cats: argcats.append(TOSAOperatorArgumentCategory(cat[0], cat[1].split(","))) - return TOSAOperatorArgument(name, desc, argcats, argtype, shape, levellimits) + return TOSAOperatorArgument(name, desc, argcats, argtype, shape, levellimits, rank) def __load_enum(self, arg): name = arg.get("name") -- cgit v1.2.1