aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2023-09-07 01:36:07 +0000
committerEric Kunze <eric.kunze@arm.com>2023-11-29 05:08:44 +0000
commit4ce101797cfd9fbe287c0217c4da80ea76d51b74 (patch)
tree63a409bf97aa820f6258092b2203c09300479bce
parent0c0a263bf6742e943bebd42ccf97dcdbd8f4e1c8 (diff)
downloadreference_model-4ce101797cfd9fbe287c0217c4da80ea76d51b74.tar.gz
Add check of operator API to precommit
Attempt to avoid API getting out of sync. Signed-off-by: Eric Kunze <eric.kunze@arm.com> Change-Id: Ic7b72c3f906e4a38cb26159bb67e9b1c4e22ca96
-rw-r--r--.pre-commit-config.yaml11
-rw-r--r--reference_model/src/operators.cc2
-rw-r--r--scripts/operator_api/generate_api.py29
3 files changed, 32 insertions, 10 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4a6c9c7..2c72a26 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -37,4 +37,13 @@ repos:
language: system
entry: clang-format
types: ["c++"]
- args: ["-i"] \ No newline at end of file
+ args: ["-i"]
+
+- repo: local
+ hooks:
+ - id: check-operator-api
+ name: check-operator-api
+ language: system
+ entry: python3 scripts/operator_api/generate_api.py
+ pass_filenames: false
+ always_run: true
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 14065ad..ecebe52 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -2497,4 +2497,4 @@ extern "C"
return tosa_status_valid;
}
-} // extern "C"
+} // extern "C" \ No newline at end of file
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 31ee151..afe12c1 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -4,6 +4,7 @@
import copy
import os
import subprocess
+from pathlib import Path
from xml.dom import minidom
from jinja2 import Environment
@@ -12,6 +13,10 @@ from jinja2 import FileSystemLoader
# Note: main script designed to be run from the scripts/operator_api/ directory
+def getBasePath():
+ return Path(__file__).resolve().parent.parent.parent
+
+
def getTosaArgTypes(tosaXml):
"""
Returns a list of the TOSA argument types from tosa.xml.
@@ -334,7 +339,11 @@ def getSerialLibAtts():
The values are the arguments required by each Serialization library operator.
"""
serialLibAtts = {}
- with open("../../thirdparty/serialization_lib/include/attribute.def") as file:
+ base_path = getBasePath()
+ attr_def = (
+ base_path / "thirdparty" / "serialization_lib" / "include" / "attribute.def"
+ )
+ with open(attr_def) as file:
preamble = True
inAtt = False
opName = ""
@@ -376,15 +385,15 @@ def renderTemplate(environment, dataTypes, operators, template, outfile):
clangFormat(outfile)
-def generate(environment, dataTypes, operators):
+def generate(environment, dataTypes, operators, base_path):
# Generate include/operators.h
template = environment.get_template("operators_h.j2")
- outfile = os.path.join("..", "..", "reference_model", "include", "operators.h")
+ outfile = base_path / "reference_model/include/operators.h"
renderTemplate(environment, dataTypes, operators, template, outfile)
# Generate src/operators.cc
template = environment.get_template("operators_cc.j2")
- outfile = os.path.join("..", "..", "reference_model", "src", "operators.cc")
+ outfile = base_path / "reference_model/src/operators.cc"
renderTemplate(environment, dataTypes, operators, template, outfile)
@@ -400,7 +409,8 @@ def getSerializeOpTypeMap():
for name in allSerialLibAtts.keys()
]
serAtts = sorted(serAtts, key=len, reverse=True)
- tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
+ base_path = getBasePath()
+ tosaXml = minidom.parse(base_path / "thirdparty/specification/tosa.xml")
opsXml = tosaXml.getElementsByTagName("operator")
opNames = [
op.getElementsByTagName("name")[0].firstChild.data.lower() for op in opsXml
@@ -415,8 +425,11 @@ def getSerializeOpTypeMap():
if __name__ == "__main__":
- environment = Environment(loader=FileSystemLoader("templates/"))
- tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
+ base_path = getBasePath()
+ environment = Environment(
+ loader=FileSystemLoader(Path(__file__).resolve().parent / "templates")
+ )
+ tosaXml = minidom.parse(str(base_path / "thirdparty/specification/tosa.xml"))
dataTypes = getTosaDataTypes(tosaXml)
operators = getOperators(tosaXml)
- generate(environment, dataTypes, operators)
+ generate(environment, dataTypes, operators, base_path)