aboutsummaryrefslogtreecommitdiff
path: root/tools/tosa.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tosa.py')
-rw-r--r--tools/tosa.py72
1 files changed, 18 insertions, 54 deletions
diff --git a/tools/tosa.py b/tools/tosa.py
index 803e478..d01d9c2 100644
--- a/tools/tosa.py
+++ b/tools/tosa.py
@@ -1,64 +1,42 @@
-#!/usr/bin/env python3
-# Copyright (c) 2023, ARM Limited.
-# SPDX-License-Identifier: Apache-2.0
import re
import xml.etree.ElementTree as ET
-
# possible shapes: shape1, [2], [N,H,W,C]
# returns (checkable, rank)
# checkable is false if shape doesn't contain []
-
-
def get_rank_from_shape(shape):
- if "[" not in shape or "[]" in shape:
+ if '[' not in shape or '[]' in shape:
return (False, -1)
# Check for fixed rank requirement [N]
- m = re.match(r"\[(\d+)\]", shape)
+ m = re.match(r'\[(\d+)\]', shape)
if m:
return (True, 1)
# Check for comma separated rank descriptors, return count
- m = re.match(r"\[(.*)\]", shape)
+ m = re.match(r'\[(.*)\]', shape)
if m:
- return (True, len(m.group(1).split(",")))
+ return (True, len(m.group(1).split(',')))
else:
- raise RuntimeError(f"Unable to parse shape {shape}")
-
+ raise RuntimeError(f'Unable to parse shape {shape}')
class TOSAOperatorArgumentCategory:
def __init__(self, name, profiles=None):
self.name = name
self.profiles = profiles
-
class TOSAEnum:
def __init__(self, name, description, values):
self.name = name
self.description = description
self.values = values
-
class TOSALevel:
def __init__(self, name, desc, maximums):
self.name = name
self.desc = desc
self.maximums = maximums
-
class TOSAOperatorArgument:
- def __init__(
- self,
- name,
- description,
- categories,
- ty,
- elty,
- shape,
- levellimits,
- rank,
- optional=False,
- ):
- assert isinstance(optional, bool)
+ def __init__(self, name, description, categories, ty, elty, shape, levellimits, rank):
self.name = name
self.description = description
self.categories = categories
@@ -67,7 +45,6 @@ class TOSAOperatorArgument:
self.shape = shape
self.levellimits = levellimits
self.rank = rank
- self.optional = optional
class TOSAOperatorDataTypeSupport:
@@ -123,12 +100,11 @@ class TOSASpec:
name = level.get("name")
desc = level.text.strip()
maximums = {
- "MAX_RANK": level.get("max_rank"),
- "MAX_KERNEL": level.get("max_kernel"),
- "MAX_STRIDE": level.get("max_stride"),
- "MAX_SCALE": level.get("max_scale"),
- "MAX_LOG2_SIZE": level.get("max_log2_size"),
- "MAX_NESTING": level.get("max_nesting"),
+ 'MAX_RANK': level.get("max_rank"),
+ 'MAX_KERNEL': level.get("max_kernel"),
+ 'MAX_STRIDE': level.get("max_stride"),
+ 'MAX_SCALE': level.get("max_scale"),
+ 'MAX_LOG2_SIZE' : level.get("max_log2_size"),
}
return TOSALevel(name, desc, maximums)
@@ -173,28 +149,18 @@ class TOSASpec:
shape = arg.get("shape")
levellimits = []
rank = []
- optional = arg.get("optional", "false") == "true"
r = arg.find("rank")
- if r is not None:
- rank = [r.get("min"), r.get("max")]
- if shape == "-" and (rank[0] != "0" or rank[1] != "0"):
- raise RuntimeError(
- "rank is not empty or non-zero, but shape is '-'"
- f" for {op_name}: {name}"
- )
+ if r != None:
+ rank = [r.get('min'),r.get('max')]
+ if shape == "-" and (rank[0] != '0' or rank[1] != '0'):
+ raise RuntimeError(f"rank is not empty or non-zero, but shape is '-' for {op_name}: {name}")
# validate rank against the shape argument
(shape_check, shape_rank) = get_rank_from_shape(shape)
if shape_check and (shape_rank < int(rank[0]) or shape_rank > int(rank[1])):
- raise RuntimeError(
- "Description of shape rank doesn't match XML rank"
- f" min/max: {op_name} {name} shape: {shape} shape_rank: "
- f"{shape_rank} min/max: {rank[0]} {rank[1]}"
- )
+ raise RuntimeError(f"Description of shape rank doesn't match XML rank min/max: {op_name} {name} shape: {shape} shape_rank: {shape_rank} min/max: {rank[0]} {rank[1]}")
else:
if shape != "-":
- raise RuntimeError(
- f"Rank not present for {op_name}: {name} when shape is {shape}"
- )
+ raise RuntimeError(f"Rank not present for {op_name}: {name} when shape is {shape}")
for levellimit in arg.findall("levellimit"):
value = levellimit.get("value")
limit = levellimit.get("limit")
@@ -206,9 +172,7 @@ class TOSASpec:
for cat in cats:
argcats.append(TOSAOperatorArgumentCategory(cat[0], cat[1].split(",")))
- return TOSAOperatorArgument(
- name, desc, argcats, argtype, argtelty, shape, levellimits, rank, optional
- )
+ return TOSAOperatorArgument(name, desc, argcats, argtype, argtelty, shape, levellimits, rank)
def __load_enum(self, arg):
name = arg.get("name")