aboutsummaryrefslogtreecommitdiff
path: root/tools/tosa.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tosa.py')
-rw-r--r--tools/tosa.py82
1 files changed, 4 insertions, 78 deletions
diff --git a/tools/tosa.py b/tools/tosa.py
index d01d9c2..74d43d6 100644
--- a/tools/tosa.py
+++ b/tools/tosa.py
@@ -1,50 +1,21 @@
import re
import xml.etree.ElementTree as ET
-# possible shapes: shape1, [2], [N,H,W,C]
-# returns (checkable, rank)
-# checkable is false if shape doesn't contain []
-def get_rank_from_shape(shape):
- if '[' not in shape or '[]' in shape:
- return (False, -1)
- # Check for fixed rank requirement [N]
- m = re.match(r'\[(\d+)\]', shape)
- if m:
- return (True, 1)
- # Check for comma separated rank descriptors, return count
- m = re.match(r'\[(.*)\]', shape)
- if m:
- return (True, len(m.group(1).split(',')))
- else:
- raise RuntimeError(f'Unable to parse shape {shape}')
class TOSAOperatorArgumentCategory:
def __init__(self, name, profiles=None):
self.name = name
self.profiles = profiles
-class TOSAEnum:
- def __init__(self, name, description, values):
- self.name = name
- self.description = description
- self.values = values
-
-class TOSALevel:
- def __init__(self, name, desc, maximums):
- self.name = name
- self.desc = desc
- self.maximums = maximums
class TOSAOperatorArgument:
- def __init__(self, name, description, categories, ty, elty, shape, levellimits, rank):
+ def __init__(self, name, description, categories, ty, shape, levellimits):
self.name = name
self.description = description
self.categories = categories
self.type = ty
- self.tensor_element_type = elty
self.shape = shape
self.levellimits = levellimits
- self.rank = rank
class TOSAOperatorDataTypeSupport:
@@ -72,19 +43,13 @@ class TOSASpec:
def __init__(self, xmlpath):
tree = ET.parse(xmlpath)
self.xmlroot = tree.getroot()
- self.levels = []
self.operatorgroups = []
- self.enums = []
self.__load_spec()
def __load_spec(self):
self.__load_version()
- for level in self.xmlroot.findall("./levels/level"):
- self.levels.append(self.__load_level(level))
for group in self.xmlroot.findall("./operators/operatorgroup"):
self.operatorgroups.append(self.__load_operator_group(group))
- for enum in self.xmlroot.findall("./enum"):
- self.enums.append(self.__load_enum(enum))
def __load_version(self):
version = self.xmlroot.find("./version")
@@ -96,18 +61,6 @@ class TOSASpec:
else:
self.version_is_draft = False
- def __load_level(self, level):
- name = level.get("name")
- desc = level.text.strip()
- maximums = {
- 'MAX_RANK': level.get("max_rank"),
- 'MAX_KERNEL': level.get("max_kernel"),
- 'MAX_STRIDE': level.get("max_stride"),
- 'MAX_SCALE': level.get("max_scale"),
- 'MAX_LOG2_SIZE' : level.get("max_log2_size"),
- }
- return TOSALevel(name, desc, maximums)
-
def __load_operator_group(self, group):
name = group.get("name")
operators = []
@@ -121,7 +74,7 @@ class TOSASpec:
types = []
typesupports = []
for arg in op.findall("arguments/argument"):
- args.append(self.__load_operator_argument(arg, name))
+ args.append(self.__load_operator_argument(arg))
# TODO add pseudo-code to operator object?
@@ -140,27 +93,13 @@ class TOSASpec:
typesupports.append(TOSAOperatorDataTypeSupport(tsmode, tsmap, tsprofiles))
return TOSAOperator(name, args, types, typesupports)
- def __load_operator_argument(self, arg, op_name):
+ def __load_operator_argument(self, arg):
name = arg.get("name")
desc = arg.find("description").text.strip()
argcats = []
argtype = arg.get("type")
- argtelty = arg.get("tensor-element-type")
shape = arg.get("shape")
levellimits = []
- rank = []
- r = arg.find("rank")
- if r != None:
- rank = [r.get('min'),r.get('max')]
- if shape == "-" and (rank[0] != '0' or rank[1] != '0'):
- raise RuntimeError(f"rank is not empty or non-zero, but shape is '-' for {op_name}: {name}")
- # validate rank against the shape argument
- (shape_check, shape_rank) = get_rank_from_shape(shape)
- if shape_check and (shape_rank < int(rank[0]) or shape_rank > int(rank[1])):
- raise RuntimeError(f"Description of shape rank doesn't match XML rank min/max: {op_name} {name} shape: {shape} shape_rank: {shape_rank} min/max: {rank[0]} {rank[1]}")
- else:
- if shape != "-":
- raise RuntimeError(f"Rank not present for {op_name}: {name} when shape is {shape}")
for levellimit in arg.findall("levellimit"):
value = levellimit.get("value")
limit = levellimit.get("limit")
@@ -172,17 +111,4 @@ class TOSASpec:
for cat in cats:
argcats.append(TOSAOperatorArgumentCategory(cat[0], cat[1].split(",")))
- return TOSAOperatorArgument(name, desc, argcats, argtype, argtelty, shape, levellimits, rank)
-
- def __load_enum(self, arg):
- name = arg.get("name")
- desc = arg.get("description").strip()
- values = []
- for val in arg.findall("enumval"):
- values.append((val.get("name"), val.get("value"), val.get("description")))
- return TOSAEnum(name, desc, values)
-
- def get_enum_by_name(self, name):
- for e in self.enums:
- if e.name == name:
- return e
+ return TOSAOperatorArgument(name, desc, argcats, argtype, shape, levellimits)