From 4ce101797cfd9fbe287c0217c4da80ea76d51b74 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Thu, 7 Sep 2023 01:36:07 +0000 Subject: Add check of operator API to precommit Attempt to avoid API getting out of sync. Signed-off-by: Eric Kunze Change-Id: Ic7b72c3f906e4a38cb26159bb67e9b1c4e22ca96 --- .pre-commit-config.yaml | 11 ++++++++++- reference_model/src/operators.cc | 2 +- scripts/operator_api/generate_api.py | 29 +++++++++++++++++++++-------- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a6c9c7..2c72a26 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,4 +37,13 @@ repos: language: system entry: clang-format types: ["c++"] - args: ["-i"] \ No newline at end of file + args: ["-i"] + +- repo: local + hooks: + - id: check-operator-api + name: check-operator-api + language: system + entry: python3 scripts/operator_api/generate_api.py + pass_filenames: false + always_run: true diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index 14065ad..ecebe52 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -2497,4 +2497,4 @@ extern "C" return tosa_status_valid; } -} // extern "C" +} // extern "C" \ No newline at end of file diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 31ee151..afe12c1 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -4,6 +4,7 @@ import copy import os import subprocess +from pathlib import Path from xml.dom import minidom from jinja2 import Environment @@ -12,6 +13,10 @@ from jinja2 import FileSystemLoader # Note: main script designed to be run from the scripts/operator_api/ directory +def getBasePath(): + return Path(__file__).resolve().parent.parent.parent + + def getTosaArgTypes(tosaXml): """ Returns a list of the TOSA argument types from tosa.xml. @@ -334,7 +339,11 @@ def getSerialLibAtts(): The values are the arguments required by each Serialization library operator. """ serialLibAtts = {} - with open("../../thirdparty/serialization_lib/include/attribute.def") as file: + base_path = getBasePath() + attr_def = ( + base_path / "thirdparty" / "serialization_lib" / "include" / "attribute.def" + ) + with open(attr_def) as file: preamble = True inAtt = False opName = "" @@ -376,15 +385,15 @@ def renderTemplate(environment, dataTypes, operators, template, outfile): clangFormat(outfile) -def generate(environment, dataTypes, operators): +def generate(environment, dataTypes, operators, base_path): # Generate include/operators.h template = environment.get_template("operators_h.j2") - outfile = os.path.join("..", "..", "reference_model", "include", "operators.h") + outfile = base_path / "reference_model/include/operators.h" renderTemplate(environment, dataTypes, operators, template, outfile) # Generate src/operators.cc template = environment.get_template("operators_cc.j2") - outfile = os.path.join("..", "..", "reference_model", "src", "operators.cc") + outfile = base_path / "reference_model/src/operators.cc" renderTemplate(environment, dataTypes, operators, template, outfile) @@ -400,7 +409,8 @@ def getSerializeOpTypeMap(): for name in allSerialLibAtts.keys() ] serAtts = sorted(serAtts, key=len, reverse=True) - tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml") + base_path = getBasePath() + tosaXml = minidom.parse(base_path / "thirdparty/specification/tosa.xml") opsXml = tosaXml.getElementsByTagName("operator") opNames = [ op.getElementsByTagName("name")[0].firstChild.data.lower() for op in opsXml @@ -415,8 +425,11 @@ def getSerializeOpTypeMap(): if __name__ == "__main__": - environment = Environment(loader=FileSystemLoader("templates/")) - tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml") + base_path = getBasePath() + environment = Environment( + loader=FileSystemLoader(Path(__file__).resolve().parent / "templates") + ) + tosaXml = minidom.parse(str(base_path / "thirdparty/specification/tosa.xml")) dataTypes = getTosaDataTypes(tosaXml) operators = getOperators(tosaXml) - generate(environment, dataTypes, operators) + generate(environment, dataTypes, operators, base_path) -- cgit v1.2.1