aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2023-08-03 14:25:50 -0700
committerEric Kunze <eric.kunze@arm.com>2023-08-03 14:35:28 -0700
commit318ac6fc76b9efabb60f5d9c1abf84508e1a8a01 (patch)
tree3bea9e8a42bc2b2e63935fa8104fa8525f30ec91
parent7b0f1c9a090fb7e4c39afad5bdb09a2036b389a6 (diff)
downloadspecification-318ac6fc76b9efabb60f5d9c1abf84508e1a8a01.tar.gz
Use black and flake8 to format python (NFC)
No functional changes, cleanup only Add missing copyright notices Signed-off-by: Eric Kunze <eric.kunze@arm.com> Change-Id: I0ddc8738f16aeced28fe7aa2ccc4fb715a84bd18
-rwxr-xr-xtools/genspec.py65
-rw-r--r--tools/tosa.py60
2 files changed, 76 insertions, 49 deletions
diff --git a/tools/genspec.py b/tools/genspec.py
index 11a3e72..b8e961b 100755
--- a/tools/genspec.py
+++ b/tools/genspec.py
@@ -1,8 +1,11 @@
#!/usr/bin/env python3
+# Copyright (c) 2023, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
import os
import tosa
+
class TOSASpecAsciidocGenerator:
def __init__(self, spec):
self.spec = spec
@@ -36,28 +39,29 @@ class TOSASpecAsciidocGenerator:
cattext = cats[0].name.title()
# Type
- if arg.type == 'tensor_t':
- argtype = 'T<{}>'.format(arg.tensor_element_type)
- elif arg.type == 'tensor_list_t':
- if arg.tensor_element_type == '-':
- argtype = 'tensor_list_t'
+ if arg.type == "tensor_t":
+ argtype = "T<{}>".format(arg.tensor_element_type)
+ elif arg.type == "tensor_list_t":
+ if arg.tensor_element_type == "-":
+ argtype = "tensor_list_t"
else:
- argtype = 'tensor_list_t<T<{}>>'.format(arg.tensor_element_type)
+ argtype = "tensor_list_t<T<{}>>".format(arg.tensor_element_type)
else:
argtype = arg.type
# Rank
if len(arg.rank) > 0:
- if (arg.rank[0] == arg.rank[1]):
- rank = f'{arg.rank[0]}'
+ if arg.rank[0] == arg.rank[1]:
+ rank = f"{arg.rank[0]}"
else:
- rank = f'{arg.rank[0]} to {arg.rank[1]}'
+ rank = f"{arg.rank[0]} to {arg.rank[1]}"
else:
rank = ""
# Format and write line
file.write(
- f"|{cattext}|{argtype}|{arg.name}|{arg.shape}|{rank}|{arg.description}\n"
+ f"|{cattext}|{argtype}|{arg.name}|{arg.shape}"
+ f"|{rank}|{arg.description}\n"
)
file.write("|===\n")
@@ -80,10 +84,10 @@ class TOSASpecAsciidocGenerator:
file.write("\n*Operation Function:*\n\n")
leveltext = ""
for arg in op.arguments:
- if (len(arg.levellimits) > 0):
+ if len(arg.levellimits) > 0:
for limit in arg.levellimits:
- leveltext += "LEVEL_CHECK(" + limit[0] + " <= " + limit[1] + ");\n"
- if (len(leveltext) > 0):
+ leveltext += "LEVEL_CHECK(" + limit[0] + " <= " + limit[1] + ");\n"
+ if len(leveltext) > 0:
file.write(f"[source,c++]\n----\n{leveltext}\n----\n")
def generate(self, outdir):
@@ -93,29 +97,29 @@ class TOSASpecAsciidocGenerator:
major = self.spec.version_major
minor = self.spec.version_minor
patch = self.spec.version_patch
- with open(os.path.join(outdir, "version.adoc"), 'w') as f:
- f.write(':tosa-version-string: {}.{}.{}'.format(major, minor, patch))
+ with open(os.path.join(outdir, "version.adoc"), "w") as f:
+ f.write(":tosa-version-string: {}.{}.{}".format(major, minor, patch))
if self.spec.version_is_draft:
- f.write(' draft')
- f.write('\n')
+ f.write(" draft")
+ f.write("\n")
# Generate level maximums table
- with open(os.path.join(outdir, "levels.adoc"), 'w') as f:
- f.write('|===\n')
- f.write('|tosa_level_t')
+ with open(os.path.join(outdir, "levels.adoc"), "w") as f:
+ f.write("|===\n")
+ f.write("|tosa_level_t")
for level in self.spec.levels:
- f.write('|tosa_level_{}'.format(level.name))
- f.write('\n')
- f.write('|Description')
+ f.write("|tosa_level_{}".format(level.name))
+ f.write("\n")
+ f.write("|Description")
for level in self.spec.levels:
- f.write('|{}'.format(level.desc))
- f.write('\n')
+ f.write("|{}".format(level.desc))
+ f.write("\n")
for param in self.spec.levels[0].maximums:
- f.write('|{}'.format(param))
+ f.write("|{}".format(param))
for level in self.spec.levels:
- f.write('|{}'.format(level.maximums[param]))
- f.write('\n')
- f.write('|===\n')
+ f.write("|{}".format(level.maximums[param]))
+ f.write("\n")
+ f.write("|===\n")
# Generator operators
opdir = os.path.join(outdir, "operators")
@@ -124,10 +128,11 @@ class TOSASpecAsciidocGenerator:
for op in group.operators:
with open(os.path.join(opdir, op.name + ".adoc"), "w") as f:
self.generate_operator(op, f)
- with open(os.path.join(outdir, "enums.adoc"), 'w') as f:
+ with open(os.path.join(outdir, "enums.adoc"), "w") as f:
for enum in self.spec.enums:
self.generate_enum(enum, f)
+
if __name__ == "__main__":
import argparse
diff --git a/tools/tosa.py b/tools/tosa.py
index acea04d..26e501f 100644
--- a/tools/tosa.py
+++ b/tools/tosa.py
@@ -1,42 +1,53 @@
+#!/usr/bin/env python3
+# Copyright (c) 2023, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
import re
import xml.etree.ElementTree as ET
# possible shapes: shape1, [2], [N,H,W,C]
# returns (checkable, rank)
# checkable is false if shape doesn't contain []
+
+
def get_rank_from_shape(shape):
- if '[' not in shape or '[]' in shape:
+ if "[" not in shape or "[]" in shape:
return (False, -1)
# Check for fixed rank requirement [N]
- m = re.match(r'\[(\d+)\]', shape)
+ m = re.match(r"\[(\d+)\]", shape)
if m:
return (True, 1)
# Check for comma separated rank descriptors, return count
- m = re.match(r'\[(.*)\]', shape)
+ m = re.match(r"\[(.*)\]", shape)
if m:
- return (True, len(m.group(1).split(',')))
+ return (True, len(m.group(1).split(",")))
else:
- raise RuntimeError(f'Unable to parse shape {shape}')
+ raise RuntimeError(f"Unable to parse shape {shape}")
+
class TOSAOperatorArgumentCategory:
def __init__(self, name, profiles=None):
self.name = name
self.profiles = profiles
+
class TOSAEnum:
def __init__(self, name, description, values):
self.name = name
self.description = description
self.values = values
+
class TOSALevel:
def __init__(self, name, desc, maximums):
self.name = name
self.desc = desc
self.maximums = maximums
+
class TOSAOperatorArgument:
- def __init__(self, name, description, categories, ty, elty, shape, levellimits, rank):
+ def __init__(
+ self, name, description, categories, ty, elty, shape, levellimits, rank
+ ):
self.name = name
self.description = description
self.categories = categories
@@ -100,12 +111,12 @@ class TOSASpec:
name = level.get("name")
desc = level.text.strip()
maximums = {
- 'MAX_RANK': level.get("max_rank"),
- 'MAX_KERNEL': level.get("max_kernel"),
- 'MAX_STRIDE': level.get("max_stride"),
- 'MAX_SCALE': level.get("max_scale"),
- 'MAX_LOG2_SIZE' : level.get("max_log2_size"),
- 'MAX_NESTING' : level.get("max_nesting"),
+ "MAX_RANK": level.get("max_rank"),
+ "MAX_KERNEL": level.get("max_kernel"),
+ "MAX_STRIDE": level.get("max_stride"),
+ "MAX_SCALE": level.get("max_scale"),
+ "MAX_LOG2_SIZE": level.get("max_log2_size"),
+ "MAX_NESTING": level.get("max_nesting"),
}
return TOSALevel(name, desc, maximums)
@@ -151,17 +162,26 @@ class TOSASpec:
levellimits = []
rank = []
r = arg.find("rank")
- if r != None:
- rank = [r.get('min'),r.get('max')]
- if shape == "-" and (rank[0] != '0' or rank[1] != '0'):
- raise RuntimeError(f"rank is not empty or non-zero, but shape is '-' for {op_name}: {name}")
+ if r is not None:
+ rank = [r.get("min"), r.get("max")]
+ if shape == "-" and (rank[0] != "0" or rank[1] != "0"):
+ raise RuntimeError(
+ "rank is not empty or non-zero, but shape is '-'"
+ f" for {op_name}: {name}"
+ )
# validate rank against the shape argument
(shape_check, shape_rank) = get_rank_from_shape(shape)
if shape_check and (shape_rank < int(rank[0]) or shape_rank > int(rank[1])):
- raise RuntimeError(f"Description of shape rank doesn't match XML rank min/max: {op_name} {name} shape: {shape} shape_rank: {shape_rank} min/max: {rank[0]} {rank[1]}")
+ raise RuntimeError(
+ "Description of shape rank doesn't match XML rank"
+ f" min/max: {op_name} {name} shape: {shape} shape_rank: "
+ f"{shape_rank} min/max: {rank[0]} {rank[1]}"
+ )
else:
if shape != "-":
- raise RuntimeError(f"Rank not present for {op_name}: {name} when shape is {shape}")
+ raise RuntimeError(
+ f"Rank not present for {op_name}: {name} when shape is {shape}"
+ )
for levellimit in arg.findall("levellimit"):
value = levellimit.get("value")
limit = levellimit.get("limit")
@@ -173,7 +193,9 @@ class TOSASpec:
for cat in cats:
argcats.append(TOSAOperatorArgumentCategory(cat[0], cat[1].split(",")))
- return TOSAOperatorArgument(name, desc, argcats, argtype, argtelty, shape, levellimits, rank)
+ return TOSAOperatorArgument(
+ name, desc, argcats, argtype, argtelty, shape, levellimits, rank
+ )
def __load_enum(self, arg):
name = arg.get("name")