aboutsummaryrefslogtreecommitdiff
path: root/tools/tosa.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tosa.py')
-rw-r--r--tools/tosa.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/tools/tosa.py b/tools/tosa.py
index 218412f..d01d9c2 100644
--- a/tools/tosa.py
+++ b/tools/tosa.py
@@ -36,11 +36,12 @@ class TOSALevel:
self.maximums = maximums
class TOSAOperatorArgument:
- def __init__(self, name, description, categories, ty, shape, levellimits, rank):
+ def __init__(self, name, description, categories, ty, elty, shape, levellimits, rank):
self.name = name
self.description = description
self.categories = categories
self.type = ty
+ self.tensor_element_type = elty
self.shape = shape
self.levellimits = levellimits
self.rank = rank
@@ -144,14 +145,15 @@ class TOSASpec:
desc = arg.find("description").text.strip()
argcats = []
argtype = arg.get("type")
+ argtelty = arg.get("tensor-element-type")
shape = arg.get("shape")
levellimits = []
rank = []
r = arg.find("rank")
if r != None:
- if shape == "-":
- raise RuntimeError(f"rank is not empty, but shape is '-' for {op_name}: {name}")
rank = [r.get('min'),r.get('max')]
+ if shape == "-" and (rank[0] != '0' or rank[1] != '0'):
+ raise RuntimeError(f"rank is not empty or non-zero, but shape is '-' for {op_name}: {name}")
# validate rank against the shape argument
(shape_check, shape_rank) = get_rank_from_shape(shape)
if shape_check and (shape_rank < int(rank[0]) or shape_rank > int(rank[1])):
@@ -170,7 +172,7 @@ class TOSASpec:
for cat in cats:
argcats.append(TOSAOperatorArgumentCategory(cat[0], cat[1].split(",")))
- return TOSAOperatorArgument(name, desc, argcats, argtype, shape, levellimits, rank)
+ return TOSAOperatorArgument(name, desc, argcats, argtype, argtelty, shape, levellimits, rank)
def __load_enum(self, arg):
name = arg.get("name")