aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2024-01-02 15:41:01 +0000
committerEric Kunze <eric.kunze@arm.com>2024-01-05 19:29:29 +0000
commit2936f13d0e26c394333495ce909740eaf58a45cc (patch)
tree0b602d9389c93e1e1152b6abd18c66e8140f00a8 /scripts
parent54bb61effee583239d30ec6d4fda32c1a710050c (diff)
downloadreference_model-2936f13d0e26c394333495ce909740eaf58a45cc.tar.gz
Remove operators API
The operators API generated by the script is no longer used and could be removed from the project. Signed-off-by: Dmitrii Agibov <dmitrii.agibov@arm.com> Change-Id: Ia611b069463b3aded7d6546987c2323674184673
Diffstat (limited to 'scripts')
-rw-r--r--scripts/operator_api/README.md19
-rw-r--r--scripts/operator_api/generate_api.py413
-rw-r--r--scripts/operator_api/templates/operators_cc.j2248
-rw-r--r--scripts/operator_api/templates/operators_h.j251
4 files changed, 0 insertions, 731 deletions
diff --git a/scripts/operator_api/README.md b/scripts/operator_api/README.md
deleted file mode 100644
index 381d90c..0000000
--- a/scripts/operator_api/README.md
+++ /dev/null
@@ -1,19 +0,0 @@
-# Generate eager operator execution entrypoints
-
-## Introduction
-
-The generate_api.py script will generate an extended reference model API with eager operator execution entrypoints.
-The following files will be generated: include/operators.h and src/operators.cc
-
-## Requirements
-
-* Python 3.6 or later
-* Jinja2 (install with ```pip install Jinja2```)
-
-## Running from the command line
-
-The script can be run from the scripts/operator-api directory as follows:
-
-```bash
-python generate_api.py
-```
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
deleted file mode 100644
index e511f19..0000000
--- a/scripts/operator_api/generate_api.py
+++ /dev/null
@@ -1,413 +0,0 @@
-"""Generate extended reference model API with eager operator execution entrypoints"""
-# Copyright (c) 2021-2023, ARM Limited.
-# SPDX-License-Identifier: Apache-2.0
-import copy
-import os
-import subprocess
-from pathlib import Path
-from xml.dom import minidom
-
-from jinja2 import Environment
-from jinja2 import FileSystemLoader
-
-
-def getBasePath():
- return Path(__file__).resolve().parent.parent.parent
-
-
-def getTosaArgTypes(tosaXml):
- """
- Returns a list of the TOSA argument types from tosa.xml.
- """
- argTypes = {
- "tensor_t",
- "in_t",
- "out_t",
- "mul_t",
- "weight_t",
- "in_out_t",
- "tensor_list_t",
- }
- argTypesXml = tosaXml.getElementsByTagName("type")
- for argTypeXml in argTypesXml:
- argTypes.add(argTypeXml.getAttribute("name"))
- argTypes.remove("TABLE_SIZE")
- return argTypes
-
-
-def getTosaDataTypes(tosaXml):
- """
- Returns a list of the TOSA data types from tosa.xml.
- """
- argTypes = getTosaArgTypes(tosaXml)
- dataTypes = set()
- dataTypesXml = tosaXml.getElementsByTagName("typesupport")
- for dataTypeXml in dataTypesXml:
- for argType in argTypes:
- dataType = dataTypeXml.getAttribute(argType)
- if dataType != "":
- dataTypes.add(f"tosa_datatype_{dataType}")
- return sorted(dataTypes)
-
-
-def getSerializeOpType(tosaOpName):
- """
- Returns the Serialization library operator that matches the TOSA operator specified.
- """
- map = {
- "avg_pool2d": "Pool",
- "conv2d": "Conv",
- "conv3d": "Conv",
- "depthwise_conv2d": "Conv",
- "fully_connected": "FullyConnected",
- "fft2d": "FFT",
- "rfft2d": "RFFT",
- "matmul": "MatMul",
- "max_pool2d": "Pool",
- "transpose_conv2d": "TransposeConv",
- "clamp": "Clamp",
- "arithmetic_right_shift": "ArithmeticRightShift",
- "mul": "Mul",
- "table": "Table",
- "negate": "Negate",
- "pad": "Pad",
- "reshape": "Reshape",
- "slice": "Slice",
- "tile": "Tile",
- "transpose": "Transpose",
- "resize": "Resize",
- "rescale": "Rescale",
- "cond_if": "CondIf",
- "while_loop": "WhileLoop",
- }
- return map.get(tosaOpName, "None")
-
-
-def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
- """
- Returns the attributes required by the Serialization library for the TOSA operator specified.
- Generates code to initialize Serialization library attributes. If a matching TOSA argument exists,
- that value is used for initialization, otherwise a default value e.g. 0 is used.
- """
- serLibOpType = getSerializeOpType(tosaOpName)
- if serLibOpType not in allSerialLibAtts.keys():
- return {}
- else:
- serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
- tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
- serTosaTypeMap = {"ResizeMode": "tosa_mode"}
- serAttsToFix = {
- "reshape": {"new_shape": "shape"},
- "transpose_conv2d": {"output_shape": "out_shape"},
- }
- if tosaOpName in serAttsToFix:
- # Fix attributes names to match with tosa.xml
- for attDefName, tosaSpecName in serAttsToFix[tosaOpName].items():
- for opAtts in serLibOpAtts:
- if opAtts["name"] == attDefName:
- opAtts["name"] = tosaSpecName
- for att in serLibOpAtts:
- attName = att["name"]
- attType = att["dType"]
- init = ""
- # Translate TOSA data types to Serialization library data types for initialization
- if attType in serTosaTypeMap.keys():
- init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});"
- # Initialize Serialization library attributes to their matching function parameter
- elif tosaOpName == "avg_pool2d" and attName == "accum_dtype":
- init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);"
- att["dType"] = "tosa::DType"
- elif attName in tosaArgsDict:
- if att["SV"] == "V":
- if tosaArgsDict[attName]["type"] == "tosa_tensor_t":
- init = f"std::vector<{attType}> {attName};"
- init = (
- init
- + f"size_t {attName}_size = client_{attName}.size / sizeof({attType});"
- )
- init = (
- init
- + f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);"
- )
- init = (
- init
- + f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);"
- )
- else:
- init = f"const std::vector<{attType}> {attName}"
- shape = tosaArgsDict[attName]["shape"]
- if shape == "[]":
- init = (
- init
- + f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
- )
- else:
- init = (
- init
- + f"(&client_{attName}[0], &client_{attName}{shape});"
- )
- else:
- init = ""
- else:
- # Initialize Serialization library attributes with no matching fuction parameter
- if att["SV"] == "V":
- init = f"std::vector<int32_t> {attName};"
- else:
- if att["dType"] == "DType":
- att["dType"] = "tosa::DType"
- init = f"const tosa::DType {attName} = tosa::DType::DType_FP32;"
- else:
- init = f"const {attType} {attName} = 0;"
- att["init"] = init
- return serLibOpAtts
-
-
-def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml):
- """
- Replace TOSA argument data types with their matching Serialization attribute data types.
- Delete TOSA arguments where the type couldn't be determined.
- Add Serialization attributes that have no matching TOSA argument.
- """
- tosaArgTypes = getTosaArgTypes(tosaXml)
- serAttsDict = {att["name"]: att for att in serialLibAtts}
- tosaArgsNames = [arg["name"] for arg in tosaArgs]
- delTosaArgs = []
- # Replace TOSA argument data types with their matching Serialization attribute data types.
- for tosaArg in tosaArgs:
- if tosaArg["type"] in tosaArgTypes:
- if tosaArg["name"] in serAttsDict:
- tosaArg["type"] = serAttsDict[tosaArg["name"]]["dType"]
- else:
- # Delete TOSA argument whose data type can't be determined
- delTosaArgs.append(tosaArgsNames.index(tosaArg["name"]))
- # Delete corresponding length argument if one exists
- lenArgName = f"{tosaArg['name']}_len"
- if lenArgName in tosaArgsNames:
- delTosaArgs.append(tosaArgsNames.index(lenArgName))
- # Delete TOSA arguments where the type couldn't be determined
- for index in sorted(delTosaArgs, key=int, reverse=True):
- del tosaArgs[index]
- # Add Serialization attributes that have no matching TOSA argument
- tosaArgNames = [arg["name"] for arg in tosaArgs]
- for serAtt in serialLibAtts:
- attName = serAtt["name"]
- attType = serAtt["dType"]
- if (attName not in tosaArgNames) and (not attType == "tosa::DType"):
- serAttName = serAtt["name"]
- if serAtt["SV"] == "V":
- # For vector data types, insert a matching length argument
- tosaArgs.insert(
- len(tosaArgs) - 1,
- {
- "name": f"{serAttName}_len",
- "type": "int32_t",
- "shape": "",
- "category": "",
- },
- )
- init = f"const std::vector<{attType}> {attName}(&client_{serAttName}[0], &client_{serAttName}[0] + client_{serAttName}_len);"
- shape = "[]"
- else:
- init = ""
- shape = ""
- serAtt["init"] = init
- # Insert new argument
- tosaArgs.insert(
- len(tosaArgs) - 1,
- {
- "name": serAttName,
- "type": serAtt["dType"],
- "shape": shape,
- "category": "",
- },
- )
-
-
-def getOperators(tosaXml):
- """
- Return a list of TOSA operators as defined by tosa.xml.
- """
- operators = []
- ignoreOps = [
- "while_loop",
- "cond_if",
- "const",
- "custom",
- "variable",
- "variable_read",
- "variable_write",
- ]
- opsXml = tosaXml.getElementsByTagName("operator")
- allSerialLibAtts = getSerialLibAtts()
- for opXml in opsXml:
- opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
- if opName not in ignoreOps:
- operator = {"name": opName}
- operator["serializeAttType"] = getSerializeOpType(opName)
- tosaArgs = getTosaArgs(opXml)
- serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs)
- # Handle "axis" arguments
- axisList = [arg["name"] for arg in tosaArgs if arg["name"] == "axis"]
- if operator["serializeAttType"] == "None" and len(axisList) > 0:
- operator["serializeAttType"] = "Axis"
- serialLibAtts = [
- {
- "name": "axis",
- "dType": "int32_t",
- "SV": "S",
- "init": "",
- }
- ]
- updateTosaArgs(tosaArgs, serialLibAtts, tosaXml)
- operator["arguments"] = tosaArgs
- operator["serialLibAtts"] = serialLibAtts
- serializationAttNames = [att["name"] for att in serialLibAtts]
- operator["inputs"] = [
- {"name": arg["name"], "type": arg["type"]}
- for arg in tosaArgs
- if arg["category"] == "input"
- and arg["name"] not in serializationAttNames
- ]
- operator["outputs"] = [
- arg["name"] for arg in tosaArgs if arg["category"] == "output"
- ]
- operators.append(operator)
- return operators
-
-
-def getTosaArgs(opXml):
- """
- Return the arguments required for the TOSA operator specified.
- """
- arguments = []
- argsXml = opXml.getElementsByTagName("argument")
- tosaTensorTypes = getTosaArgTypes(tosaXml)
- tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
- tensorElemTypeMap = {
- "resize_mode_t": "tosa_mode_t",
- "acc_size_t": "tosa_acc_size_t",
- }
- for xmlArg in argsXml:
- argName = xmlArg.getAttribute("name").lower()
- tensorElemType = xmlArg.getAttribute("tensor-element-type")
- if tensorElemType in tensorElemTypeMap:
- argType = tensorElemTypeMap[tensorElemType]
- else:
- argType = xmlArg.getAttribute("type")
- argShape = xmlArg.getAttribute("shape")
- argCategory = xmlArg.getAttribute("category")
- # FullyConnected workaround
- if (argName == "weight" or argName == "bias") and (argCategory == "attribute"):
- argCategory = "input"
- # Update argument type
- if argType[-1:] == "*":
- argType = argType[:-1]
- if argCategory in ["input", "output"] and argType in tosaTensorTypes:
- argType = f"tosa_{argType}"
- argShape = ""
- if argType in tosaTypeMap:
- argType = tosaTypeMap[argType]
- # Add a length argument for arrays with unknown compile-time size
- if argShape != "" and argShape[0] == "[" and not argShape[1:-1].isnumeric():
- argShape = "[]"
- arguments.append(
- {
- "name": f"{argName}_len",
- "type": "int32_t",
- "shape": "",
- "category": "",
- }
- )
- elif argShape == "" or not argShape[0] == "[":
- argShape = ""
- # Append argument
- arguments.append(
- {
- "name": argName,
- "type": argType,
- "shape": argShape,
- "category": argCategory,
- }
- )
- return arguments
-
-
-def clangFormat(filename):
- cmd = ["clang-format", "-i", filename]
- with open(os.devnull, "w") as devnull:
- subprocess.check_call(cmd, stdout=devnull)
-
-
-def getSerialLibAtts():
- """
- Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
- The values are the arguments required by each Serialization library operator.
- """
- serialLibAtts = {}
- base_path = getBasePath()
- attr_def = (
- base_path / "thirdparty" / "serialization_lib" / "include" / "attribute.def"
- )
- with open(attr_def) as file:
- preamble = True
- inAtt = False
- opName = ""
- args = []
- for line in file:
- if preamble and not line[: len("DEF_ATTRIBUTE(")] == "DEF_ATTRIBUTE(":
- continue
- else:
- preamble = False
- line = line.lstrip().rstrip()
- if not inAtt and "DEF_ATTRIBUTE(" in line:
- opName = line[len("DEF_ATTRIBUTE(") : line.find(",")]
- inAtt = True
- elif inAtt:
- vals = line.split(",")
- argName = vals[2].lstrip().strip()
- if ")" in argName:
- argName = argName[:-1]
- arg = {
- "name": argName,
- "dType": vals[0].lstrip().strip(),
- "SV": vals[1].lstrip().strip(),
- }
- args.append(arg)
- if ")" in line:
- serialLibAtts[opName] = args
- opName = ""
- args = []
- inAtt = False
- return serialLibAtts
-
-
-def renderTemplate(dataTypes, operators, template, outfile):
- content = template.render(dataTypes=dataTypes, operators=operators)
- with open(outfile, mode="w", encoding="utf-8") as output:
- output.write(content)
- print(f"Created {outfile}")
-
- clangFormat(outfile)
-
-
-def generate(environment, dataTypes, operators, base_path):
- # Generate include/operators.h
- template = environment.get_template("operators_h.j2")
- outfile = base_path / "reference_model/include/operators.h"
- renderTemplate(dataTypes, operators, template, outfile)
-
- # Generate src/operators.cc
- template = environment.get_template("operators_cc.j2")
- outfile = base_path / "reference_model/src/operators.cc"
- renderTemplate(dataTypes, operators, template, outfile)
-
-
-if __name__ == "__main__":
- base_path = getBasePath()
- environment = Environment(
- loader=FileSystemLoader(Path(__file__).resolve().parent / "templates")
- )
- tosaXml = minidom.parse(str(base_path / "thirdparty/specification/tosa.xml"))
- dataTypes = getTosaDataTypes(tosaXml)
- operators = getOperators(tosaXml)
- generate(environment, dataTypes, operators, base_path)
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
deleted file mode 100644
index 0fc52ab..0000000
--- a/scripts/operator_api/templates/operators_cc.j2
+++ /dev/null
@@ -1,248 +0,0 @@
-
-// Copyright (c) 2022-2023, ARM Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// THIS FILE IS GENERATED. DO NOT EDIT!
-// See scripts/operator_api/generate_api.py
-
-#include "operators.h"
-#include "model_runner_impl.h"
-#include "ops/op_factory.h"
-
-
-#define TOSA_PROPAGATE_ERROR(status) \
- do \
- { \
- if (status != 0) \
- { \
- return status; \
- } \
- } while (false)
-
-
-#define TOSA_RETURN_ON_ERROR(status) \
- do \
- { \
- if (status != 0) \
- { \
- return tosa_status_error; \
- } \
- } while (false)
-
-#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status) \
- do \
- { \
- if (status != GraphStatus::TOSA_VALID) \
- { \
- auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(status); \
- return static_cast<tosa_status_t>(ustatus); \
- } \
- } while (false)
-
-namespace {
-
-tosa::DType translate_client_datatype(tosa_datatype_t type)
-{
- switch (type)
- {
- case tosa_datatype_bf16_t:
- return tosa::DType::DType_BF16;
- case tosa_datatype_bool_t:
- return tosa::DType::DType_BOOL;
- case tosa_datatype_fp16_t:
- return tosa::DType::DType_FP16;
- case tosa_datatype_fp32_t:
- return tosa::DType::DType_FP32;
- case tosa_datatype_int16_t:
- return tosa::DType::DType_INT16;
- case tosa_datatype_int32_t:
- return tosa::DType::DType_INT32;
- case tosa_datatype_int48_t:
- return tosa::DType::DType_INT48;
- case tosa_datatype_int4_t:
- return tosa::DType::DType_INT4;
- case tosa_datatype_int8_t:
- return tosa::DType::DType_INT8;
- case tosa_datatype_uint16_t:
- return tosa::DType::DType_UINT16;
- case tosa_datatype_uint8_t:
- return tosa::DType::DType_UINT8;
- case tosa_datatype_shape_t:
- return tosa::DType::DType_SHAPE;
- default:
- return tosa::DType::DType_UNKNOWN;
- }
-};
-
-using TosaTensorInfo = std::pair<tosa::TosaSerializationTensor*, tosa_tensor_t*>;
-
-tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
-{
- std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
- return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
-}
-
-void addTensor(std::vector<TosaTensorInfo> &tensors, tosa_tensor_t& tensor, std::string tensorName) {
- auto tensorDescr = translate_client_tensor(tensor, tensorName);
- tensors.push_back(std::make_pair(tensorDescr, &tensor));
-}
-
-int setInputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& inputTensors)
-{
- for (const auto& [tensorDescr, tensorData] : inputTensors)
- {
- auto status = runner.setInput(tensorDescr->GetName(), tensorData->data, tensorData->size);
- TOSA_PROPAGATE_ERROR(status);
- }
-
- return 0;
-}
-
-int getOutputTensors(TosaReference::ModelRunnerImpl& runner, std::vector<TosaTensorInfo>& outputTensors)
-{
- for (const auto& [tensorDescr, tensorData] : outputTensors)
- {
- auto status = runner.getOutput(tensorDescr->GetName(), tensorData->data, tensorData->size);
- TOSA_PROPAGATE_ERROR(status);
- }
-
- return 0;
-}
-
-std::vector<std::string> getTensorNames(std::vector<TosaTensorInfo>& tensors)
-{
- std::vector<std::string> tensorNames;
- const auto mapping = [](const TosaTensorInfo &info){ return info.first->GetName(); };
-
- std::transform(tensors.cbegin(), tensors.cend(), std::back_inserter(tensorNames), mapping);
- return tensorNames;
-}
-
-std::vector<TosaSerializationTensor*> allTensors(std::vector<TosaTensorInfo> &inputTensors, std::vector<TosaTensorInfo> &outputTensors) {
- std::vector<TosaSerializationTensor*> result;
- const auto mapping = [](const TosaTensorInfo &info){ return info.first; };
-
- std::transform(inputTensors.cbegin(), inputTensors.cend(), std::back_inserter(result), mapping);
- std::transform(outputTensors.cbegin(), outputTensors.cend(), std::back_inserter(result), mapping);
-
- return result;
-}
-
-tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
- switch(mode) {
- case tosa_mode_nearest:
- return tosa::ResizeMode_NEAREST;
- case tosa_mode_max:
- case tosa_mode_bilinear:
- return tosa::ResizeMode_BILINEAR;
- default:
- return tosa::ResizeMode_UNKNOWN;
- }
-}
-
-tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size) {
- switch(acc_size) {
- case tosa_acc_size_int32_t:
- return tosa::DType::DType_INT32;
- case tosa_acc_size_fp16_t:
- return tosa::DType::DType_FP16;
- case tosa_acc_size_fp32_t:
- return tosa::DType::DType_FP32;
- default:
- return tosa::DType::DType_UNKNOWN;
- }
-}
-
-} // namespace
-
-extern "C"
-{
- {% for operator in operators: %}
- tosa_status_t tosa_run_{{ operator.name }} (
- {%- for arg in operator.arguments: -%}
- {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
- {% if not loop.last %},{% endif %}
- {%- endfor -%},const func_ctx_t& func_ctx
- )
- {
- // Create operator attributes
- {% for att in operator.serialLibAtts: -%}
- {{att.init}}
- {%- endfor -%}
-
- Tosa{{operator.serializeAttType}}Attribute attr
- {%- if operator.serialLibAtts|length > 0 -%}
- (
- {%- for att in operator.serialLibAtts: -%}
- {%- if att.init == "" -%}
- client_{{att.name}}
- {%- else -%}
- {{att.name}}
- {%- endif -%}
- {% if not loop.last %}, {% endif %}
- {%- endfor -%}
- )
- {%- endif -%};
-
- // Create tensors
- std::vector<TosaTensorInfo> inputTensors;
- {% for input in operator.inputs: -%}
- {%- if input.type == "tosa_tensor_list_t" -%}
- for (int i = 0; i < client_{{input.name}}.size; i++) {
- addTensor(inputTensors, client_{{input.name}}.tensors[i], "{{input.name}}-" + std::to_string(i));
- }
- {%- else -%}
- addTensor(inputTensors, client_{{input.name}}, "{{input.name}}");
- {%- endif -%}
- {%- endfor %}
-
- std::vector<TosaTensorInfo> outputTensors;
- {% for output in operator.outputs: -%}
- addTensor(outputTensors, client_{{output}}, "{{output}}");
- {%- endfor %}
-
- // Create operator
- auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
- {%- if operator.serializeAttType != "None" -%}
- tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
- {%- else -%}
- tosa::Attribute::Attribute_NONE
- {%- endif -%},
- &attr,
- getTensorNames(inputTensors),
- getTensorNames(outputTensors));
-
- // Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("{{operator.name}}", "main", { op },
- allTensors(inputTensors, outputTensors),
- op->GetInputTensorNames(),
- op->GetOutputTensorNames());
-
- // Setup model
- TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
- TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
-
- TOSA_RETURN_ON_ERROR(setInputTensors(runner, inputTensors));
-
- // Execute
- TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
-
- // Extract outputs
- TOSA_RETURN_ON_ERROR(getOutputTensors(runner, outputTensors));
-
- return tosa_status_valid;
- }
- {% endfor %}
-
-} // extern "C" \ No newline at end of file
diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2
deleted file mode 100644
index 0c98da8..0000000
--- a/scripts/operator_api/templates/operators_h.j2
+++ /dev/null
@@ -1,51 +0,0 @@
-
-// Copyright (c) 2022-2023, ARM Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// THIS FILE IS GENERATED. DO NOT EDIT!
-// See scripts/operator_api/generate_api.py
-
-#ifndef OPERATORS_H_
-#define OPERATORS_H_
-
-#include "func_config.h"
-#include "func_debug.h"
-#include "types.h"
-
-#include <stddef.h>
-#include <stdint.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif /* __cplusplus */
-
- struct func_ctx_t
- {
- func_config_t func_config = func_config_t{};
- func_debug_t func_debug = func_debug_t{};
- };
-
- {% for operator in operators: %}
- tosa_status_t tosa_run_{{ operator.name }} (
- {%- for arg in operator.arguments: -%}
- {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
- {% if loop.index < operator.arguments|length %},{% endif %}
- {%- endfor -%},const func_ctx_t& func_ctx);
- {% endfor %}
-
-#ifdef __cplusplus
-}
-#endif /* __cplusplus */
-
-#endif // OPERATORS_H_ \ No newline at end of file