aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorGrant Watson <grant.watson@arm.com>2022-11-16 15:32:39 +0000
committerEric Kunze <eric.kunze@arm.com>2022-12-15 16:41:27 +0000
commit64285a1f25e2c7b85ed1f00b7947403e92baea00 (patch)
tree6d29c54f6497741449339e808508c854ba6a2267 /scripts
parentb45db9a696f5df7b233f374248f329c16ee7ae64 (diff)
downloadreference_model-64285a1f25e2c7b85ed1f00b7947403e92baea00.tar.gz
Extend reference model API with eager operator execution entrypoints
- Adds a script to generate operators.h and operators.cc - Adds jinja2 templates for generating operators.h and operators.cc - Adds unit tests for a subset of the operators generated - Includes the TOSA specification as a submodule - Adds supporting C++ and header files Signed-off-by: Grant Watson <grant.watson@arm.com> Change-Id: I5b60db4c56113110d8e75fe1152525d258233f9c
Diffstat (limited to 'scripts')
-rw-r--r--scripts/operator_api/README.md19
-rw-r--r--scripts/operator_api/generate_api.py349
-rw-r--r--scripts/operator_api/templates/operators_cc.j2176
-rw-r--r--scripts/operator_api/templates/operators_h.j274
4 files changed, 618 insertions, 0 deletions
diff --git a/scripts/operator_api/README.md b/scripts/operator_api/README.md
new file mode 100644
index 0000000..381d90c
--- /dev/null
+++ b/scripts/operator_api/README.md
@@ -0,0 +1,19 @@
+# Generate eager operator execution entrypoints
+
+## Introduction
+
+The generate_api.py script will generate an extended reference model API with eager operator execution entrypoints.
+The following files will be generated: include/operators.h and src/operators.cc
+
+## Requirements
+
+* Python 3.6 or later
+* Jinja2 (install with ```pip install Jinja2```)
+* clang-format available on the PATH (the script runs it on each generated file)
+
+## Running from the command line
+
+The script must be run from the scripts/operator_api directory, because it locates its templates, inputs and output files using paths relative to that directory:
+
+```bash
+python generate_api.py
+```
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
new file mode 100644
index 0000000..1f89f74
--- /dev/null
+++ b/scripts/operator_api/generate_api.py
@@ -0,0 +1,349 @@
+"""Generate extended reference model API with eager operator execution entrypoints"""
+# Copyright (c) 2021-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import copy
+import os
+import subprocess
+from xml.dom import minidom
+
+from jinja2 import Environment
+from jinja2 import FileSystemLoader
+
+
def getTosaArgTypes(tosaXml):
    """
    Returns the set of TOSA argument type names from tosa.xml.

    The result always contains the generic placeholder names (in_t, out_t, mul_t,
    weight_t, in_out_t) plus the "name" attribute of every <type> element found in
    tosaXml. The TABLE_SIZE entry is excluded when present.

    tosaXml: parsed DOM node (e.g. from xml.dom.minidom) exposing getElementsByTagName.
    """
    argTypes = {"in_t", "out_t", "mul_t", "weight_t", "in_out_t"}
    argTypes.update(
        typeXml.getAttribute("name")
        for typeXml in tosaXml.getElementsByTagName("type")
    )
    # discard() rather than remove(): a specification without a TABLE_SIZE type
    # must not abort generation with a KeyError.
    argTypes.discard("TABLE_SIZE")
    return argTypes
+
+
def getTosaDataTypes(tosaXml):
    """
    Returns a sorted list of the TOSA data types from tosa.xml.

    Every <typesupport> element is probed for each known argument type name; each
    non-empty attribute value is collected (once) as "tosa_datatype_<value>".
    """
    knownArgTypes = getTosaArgTypes(tosaXml)
    collected = set()
    for supportXml in tosaXml.getElementsByTagName("typesupport"):
        for attrName in knownArgTypes:
            value = supportXml.getAttribute(attrName)
            # A missing attribute is reported as an empty string; skip those.
            if value == "":
                continue
            collected.add(f"tosa_datatype_{value}")
    return sorted(collected)
+
+
def getSerializeOpType(tosaOpName):
    """
    Returns the Serialization library operator that matches the TOSA operator specified.

    tosaOpName is the lower-case TOSA operator name (e.g. "conv2d"). Operators with no
    entry in the table yield the string "None" (not the None object); the templates
    compare against that string and embed it in generated identifiers.
    """
    serializeOpTypes = {
        "avg_pool2d": "Pool",
        "conv2d": "Conv",
        "conv3d": "Conv",
        "depthwise_conv2d": "Conv",
        "fully_connected": "FullyConnected",
        "matmul": "MatMul",
        "max_pool2d": "Pool",
        "transpose_conv2d": "Conv",
        "clamp": "Clamp",
        "arithmetic_right_shift": "ArithmeticRightShift",
        "mul": "Mul",
        "table": "Table",
        "negate": "Negate",
        "pad": "Pad",
        "reshape": "Reshape",
        "slice": "Slice",
        "tile": "Tile",
        "transpose": "Transpose",
        "resize": "Resize",
        "rescale": "Rescale",
        "cond_if": "CondIf",
        "while_loop": "WhileLoop",
    }
    return serializeOpTypes.get(tosaOpName, "None")
+
+
def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs):
    """
    Returns the arguments required by the Serialization library for the TOSA operator specified.
    Generates code to initialize Serialization arguments. If a matching TOSA argument exists,
    that value is used for initialization, otherwise a default value e.g. 0 is used.

    tosaOpName: lower-case TOSA operator name.
    allSerializeArgs: dictionary as returned by getSerializeArgs().
    tosaArgs: list of argument dictionaries as returned by getTosaArgs().

    Returns a list of argument dictionaries (deep copies of the entries in
    allSerializeArgs) each extended with an "init" key holding the C++ initializer
    text. An empty list is returned when the operator has no Serialization attribute.
    """
    serOpType = getSerializeOpType(tosaOpName)
    if serOpType not in allSerializeArgs:
        # Always return a list so that both code paths have the same type.
        return []

    # Work on a copy: several TOSA operators share one Serialization attribute.
    serOpArgs = copy.deepcopy(allSerializeArgs[serOpType])
    tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
    serTosaTypeMap = {"ResizeMode": "tosa_mode"}
    for arg in serOpArgs:
        argName = arg["name"]
        isVector = arg["SV"] == "V"
        if arg["dType"] in serTosaTypeMap:
            # Translate TOSA data types to Serialization data types for initialization
            init = f" = translate_client_{serTosaTypeMap[arg['dType']]}(client_{argName})"
        elif argName in tosaArgsDict:
            # Initialize Serialization arguments to their matching function parameter
            if isVector:
                shape = tosaArgsDict[argName]["shape"]
                if shape == "[]":
                    # Run-time sized array: the range end comes from the _len parameter
                    init = f"(&client_{argName}[0], &client_{argName}[0] + client_{argName}_len)"
                else:
                    # Fixed-size array: the shape text (e.g. "[4]") indexes one past the end
                    init = f"(&client_{argName}[0], &client_{argName}{shape})"
            else:
                init = f" = client_{argName}"
        elif isVector:
            # Vector with no matching function parameter: default-constructed
            init = ""
        elif arg["dType"] == "DType":
            # Scalar DType with no matching function parameter: qualify it and default to FP32
            arg["dType"] = "tosa::DType"
            init = " = tosa::DType::DType_FP32"
        else:
            # Any other scalar with no matching function parameter
            init = " = 0"
        arg["init"] = init
    return serOpArgs
+
+
def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
    """
    Replace TOSA argument data types with their matching Serialization argument data types.
    Delete TOSA arguments where the type couldn't be determined.
    Add Serialization arguments that have no matching TOSA argument.

    Both tosaArgs and serializeArgs are modified in place (Serialization arguments
    that are added as new parameters get their "init" entry rewritten); nothing is
    returned.
    """
    tosaArgTypes = getTosaArgTypes(tosaXml)
    serArgsDict = {arg["name"]: arg for arg in serializeArgs}
    originalNames = [arg["name"] for arg in tosaArgs]

    # Replace TOSA argument data types with their matching Serialization argument data types.
    # Positions to drop are collected in a set so that no index is deleted twice.
    indicesToDelete = set()
    for index, tosaArg in enumerate(tosaArgs):
        if tosaArg["type"] not in tosaArgTypes:
            continue
        matchingSerArg = serArgsDict.get(tosaArg["name"])
        if matchingSerArg is not None:
            tosaArg["type"] = matchingSerArg["dType"]
            continue
        # Delete TOSA argument whose data type can't be determined
        indicesToDelete.add(index)
        # Delete corresponding length argument if one exists
        lenArgName = f"{tosaArg['name']}_len"
        if lenArgName in originalNames:
            indicesToDelete.add(originalNames.index(lenArgName))

    # Delete from the back so that the remaining indices stay valid
    for index in sorted(indicesToDelete, reverse=True):
        del tosaArgs[index]

    # Add Serialization arguments that have no matching TOSA argument
    remainingNames = {arg["name"] for arg in tosaArgs}
    for serArg in serializeArgs:
        serArgName = serArg["name"]
        if serArgName in remainingNames or serArg["dType"] == "tosa::DType":
            continue
        if serArg["SV"] == "V":
            # For vector data types, insert a matching length argument.
            # New arguments go just before the current last argument.
            tosaArgs.insert(
                len(tosaArgs) - 1,
                {
                    "name": f"{serArgName}_len",
                    "type": "int32_t",
                    "shape": "",
                    "category": "",
                },
            )
            init = f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)"
            shape = "[]"
        else:
            init = f" = client_{serArg['name']}"
            shape = ""
        serArg["init"] = init
        # Insert new argument
        tosaArgs.insert(
            len(tosaArgs) - 1,
            {
                "name": serArgName,
                "type": serArg["dType"],
                "shape": shape,
                "category": "",
            },
        )
+
+
def getOperators(tosaXml):
    """
    Return a list of TOSA operators as defined by tosa.xml.

    Each entry is a dictionary with the keys "name", "serializeAttType",
    "arguments", "serializeArgs", "inputs" and "outputs". Operators named in the
    skip list below are left out.
    """
    skippedOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
    operatorNodes = tosaXml.getElementsByTagName("operator")
    allSerializeArgs = getSerializeArgs()
    collected = []
    for opXml in operatorNodes:
        opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
        if opName in skippedOps:
            continue
        attributeType = getSerializeOpType(opName)
        tosaArgs = getTosaArgs(opXml)
        serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)
        # Reconciles both argument lists in place before they are stored
        updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
        collected.append(
            {
                "name": opName,
                "serializeAttType": attributeType,
                "arguments": tosaArgs,
                "serializeArgs": serializeArgs,
                "inputs": [a["name"] for a in tosaArgs if a["category"] == "input"],
                "outputs": [a["name"] for a in tosaArgs if a["category"] == "output"],
            }
        )
    return collected
+
+
def getTosaArgs(opXml):
    """
    Return the arguments required for the TOSA operator specified.

    opXml: the <operator> DOM element taken from the parsed tosa.xml document.

    Returns a list of dictionaries with the keys "name", "type", "shape" and
    "category". An array argument whose size is not a numeric literal is preceded
    by an extra int32_t "<name>_len" argument and gets the shape "[]".
    """
    # Look the specification up through the element's owning document instead of a
    # module-level global, so the function also works when this script is imported.
    tosaTensorTypes = getTosaArgTypes(opXml.ownerDocument)
    tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
    arguments = []
    for xmlArg in opXml.getElementsByTagName("argument"):
        argName = xmlArg.getAttribute("name").lower()
        argType = xmlArg.getAttribute("type")
        argShape = xmlArg.getAttribute("shape")
        argCategory = xmlArg.getAttribute("category")
        # Update argument type: drop a trailing pointer marker
        if argType.endswith("*"):
            argType = argType[:-1]
        if argCategory in ("input", "output") and argType in tosaTensorTypes:
            argType = "tosa_tensor_t"
            argShape = ""
        argType = tosaTypeMap.get(argType, argType)
        if argShape.startswith("["):
            if not argShape[1:-1].isnumeric():
                # Add a length argument for arrays with unknown compile-time size
                argShape = "[]"
                arguments.append(
                    {
                        "name": f"{argName}_len",
                        "type": "int32_t",
                        "shape": "",
                        "category": "",
                    }
                )
            # A numeric size such as "[4]" is kept verbatim.
        else:
            # Anything that is not a bracketed array shape is dropped
            argShape = ""
        # Append argument
        arguments.append(
            {
                "name": argName,
                "type": argType,
                "shape": argShape,
                "category": argCategory,
            }
        )
    return arguments
+
+
def clangFormat(filename):
    """
    Reformat the given file in place with clang-format, discarding its standard output.
    Raises subprocess.CalledProcessError if clang-format exits with a non-zero status.
    """
    subprocess.check_call(
        ["clang-format", "-i", filename], stdout=subprocess.DEVNULL
    )
+
+
def getSerializeArgs(
    attributeDefPath="../../thirdparty/serialization_lib/include/attribute.def",
):
    """
    Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
    The values are the arguments required by each Serialization library operator.

    attributeDefPath: location of attribute.def; the default is relative to the
    scripts/operator_api directory.

    Each argument is a dictionary with the keys "name", "dType" and "SV", taken from
    the third, first and second comma-separated field of an argument line.
    """
    marker = "DEF_ATTRIBUTE("
    serializeArgs = {}
    with open(attributeDefPath, encoding="utf-8") as file:
        preamble = True
        inAtt = False
        opName = ""
        args = []
        for line in file:
            if preamble:
                # Skip everything before the first line that starts with DEF_ATTRIBUTE(
                if not line.startswith(marker):
                    continue
                preamble = False
            line = line.strip()
            if not inAtt and marker in line:
                opName = line[len(marker) : line.find(",")]
                inAtt = True
            elif inAtt:
                vals = line.split(",")
                # Lines without three fields (e.g. blank lines) carry no argument;
                # tolerate them instead of failing with an IndexError.
                if len(vals) >= 3:
                    argName = vals[2].strip()
                    if ")" in argName:
                        # Last argument: drop the closing parenthesis of the macro
                        argName = argName[:-1]
                    args.append(
                        {
                            "name": argName,
                            "dType": vals[0].strip(),
                            "SV": vals[1].strip(),
                        }
                    )
                if ")" in line:
                    # End of this DEF_ATTRIBUTE(...) entry
                    serializeArgs[opName] = args
                    opName = ""
                    args = []
                    inAtt = False
    return serializeArgs
+
+
def renderTemplate(environment, dataTypes, operators, template, outfile):
    """
    Render template with the given data types and operators, write the result to
    outfile as UTF-8, report the file on stdout and run clang-format on it.
    The environment argument is accepted but not used here.
    """
    rendered = template.render(dataTypes=dataTypes, operators=operators)
    with open(outfile, mode="w", encoding="utf-8") as target:
        target.write(rendered)
        print(f"Created {outfile}")

    # Normalise the layout of the generated C/C++ source
    clangFormat(outfile)
+
+
def generate(environment, dataTypes, operators):
    """
    Generate include/operators.h and src/operators.cc of the reference model from
    their templates. Output paths are relative to the scripts/operator_api directory.
    """
    # (template name, reference_model sub-directory, output file name)
    targets = (
        ("operators_h.j2", "include", "operators.h"),
        ("operators_cc.j2", "src", "operators.cc"),
    )
    for templateName, subDir, fileName in targets:
        template = environment.get_template(templateName)
        outfile = os.path.join("..", "..", "reference_model", subDir, fileName)
        renderTemplate(environment, dataTypes, operators, template, outfile)
+
+
def getSerializeOpTypeMap():
    """
    Utility function for generating the map used in getSerializeOpType()

    Serialization attribute names are converted from CamelCase to snake_case and
    matched as substrings against the TOSA operator names. When several attribute
    names occur in one operator name, the longest one is used.

    Returns a dictionary from TOSA operator name to Serialization attribute name.
    The result is a starting point only: substring matching cannot find names that
    are spelled differently on both sides (e.g. "mat_mul" versus "matmul").
    """
    import re

    allSerializeArgs = getSerializeArgs()
    serArgs = [
        re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() for name in allSerializeArgs
    ]
    # Longest candidates first, so the first hit below is the most specific one
    serArgs.sort(key=len, reverse=True)
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    opNames = [
        op.getElementsByTagName("name")[0].firstChild.data.lower()
        for op in tosaXml.getElementsByTagName("operator")
    ]
    opTypeMap = {}
    for opName in opNames:
        for serArg in serArgs:
            if serArg in opName:
                opTypeMap[opName] = "".join(x.title() for x in serArg.split("_"))
                # Stop at the longest match; continuing would let a shorter
                # candidate (e.g. "conv" for "transpose_conv2d") overwrite it.
                break
    return opTypeMap
+
+
if __name__ == "__main__":
    # All paths are relative, so the script has to be started from its own directory.
    # The parsed specification is bound to the module-level name tosaXml.
    tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
    environment = Environment(loader=FileSystemLoader("templates/"))

    # Collect the template inputs, then render both output files
    dataTypes = getTosaDataTypes(tosaXml)
    operators = getOperators(tosaXml)
    generate(environment, dataTypes, operators)
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
new file mode 100644
index 0000000..6b0ed6e
--- /dev/null
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -0,0 +1,176 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// THIS FILE IS GENERATED. DO NOT EDIT!
+// See scripts/operator_api/generate_api.py
+
+#include "operators.h"
+#include "model_runner_impl.h"
+#include "ops/op_factory.h"
+
// Returns tosa_status_error from the enclosing function when `status` is non-zero.
// The argument is parenthesized and evaluated exactly once.
#define TOSA_RETURN_ON_ERROR(status) \
    do \
    { \
        if ((status) != 0) \
        { \
            return tosa_status_error; \
        } \
    } while (false)

// Returns from the enclosing function when `status` is not GraphStatus::TOSA_VALID,
// converting the graph status to tosa_status_t through its underlying integer value.
// The argument is stored in a local first: it is usually a call such as runner.run(),
// which must not be executed a second time on the failure path.
#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status) \
    do \
    { \
        const auto tosa_graph_status_ = (status); \
        if (tosa_graph_status_ != GraphStatus::TOSA_VALID) \
        { \
            auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(tosa_graph_status_); \
            return static_cast<tosa_status_t>(ustatus); \
        } \
    } while (false)
+
+namespace {
+
+tosa::DType translate_client_datatype(tosa_datatype_t type)
+{
+ switch (type)
+ {
+ case tosa_datatype_fp16_t:
+ return tosa::DType::DType_FP16;
+ case tosa_datatype_fp32_t:
+ return tosa::DType::DType_FP32;
+ default:
+ return tosa::DType::DType_UNKNOWN;
+ }
+};
+
+tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name)
+{
+ std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims);
+ return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {});
+}
+
+tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
+ switch(mode) {
+ case tosa_mode_nearest:
+ return tosa::ResizeMode_NEAREST;
+ case tosa_mode_max:
+ case tosa_mode_bilinear:
+ return tosa::ResizeMode_BILINEAR;
+ default:
+ return tosa::ResizeMode_UNKNOWN;
+ }
+}
+
+} // namespace
+
+extern "C"
+{
+ {# Emits one tosa_run_<name> entry point per element of `operators`, the list built by getOperators() in generate_api.py. #}
+ {# Each generated function wraps a single operator in a one-op basic block, runs it with ModelRunnerImpl and copies the outputs back. #}
+ {% for operator in operators: %}
+ tosa_status_t tosa_run_{{ operator.name }} (
+ {# Parameter list: every argument is named client_<name>; all types except tosa_tensor_t are const-qualified. #}
+ {# arg.shape is either empty, "[]" or a fixed size such as "[4]" (see getTosaArgs in generate_api.py). #}
+ {%- for arg in operator.arguments: -%}
+ {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
+ {% if loop.index < operator.arguments|length %},{% endif %}
+ {%- endfor -%}
+ )
+ {
+ // Create operator attributes
+ {# One local per serialization argument; "V" entries become std::vector, all others scalars. #}
+ {# arg.init is the initializer text prepared in generate_api.py and may be empty. #}
+ {% for arg in operator.serializeArgs: %}
+ {%- if arg.SV == "V": -%}
+ const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}};
+ {%- else: -%}
+ const {{arg.dType}} {{arg.name}}{{arg.init}};
+ {%- endif -%}
+ {%- endfor -%}
+
+ {# serializeAttType is the string "None" for operators without an attribute, which yields the type name TosaNoneAttribute. #}
+ {# The constructor argument list is only emitted when there is at least one serialization argument. #}
+ Tosa{{operator.serializeAttType}}Attribute attr
+ {%- if operator.serializeArgs|length > 0 -%}
+ (
+ {%- for arg in operator.serializeArgs: -%}
+ {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %}
+ {%- endfor -%}
+ )
+ {%- endif -%};
+
+ // Create tensors
+ {# operator.inputs / operator.outputs hold the names of the arguments whose category is "input" / "output". #}
+ {% for input in operator.inputs: -%}
+ tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}");
+ {%- endfor -%}
+ {% for output in operator.outputs: %}
+ tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}");
+ {%- endfor %}
+
+ // Create operator
+ {# The Op enumerator is the upper-cased operator name; the attribute enumerator falls back to Attribute_NONE. #}
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}},
+ {%- if operator.serializeAttType != "None" -%}
+ tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute
+ {%- else -%}
+ tosa::Attribute::Attribute_NONE
+ {%- endif -%},
+ &attr, {
+ {%- for input in operator.inputs: -%}
+ {{input}}->GetName()
+ {%- if loop.index < operator.inputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ },
+ {
+ {%- for output in operator.outputs: -%}
+ {{output}}->GetName()
+ {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ });
+
+ // Create a tosa single-op basic block
+ {# Arguments after the operator list: all tensors (inputs, then outputs), the input names, the output names. #}
+ {# Every input in the tensor list is followed by a comma, also when there are no outputs (a trailing comma in the braced list). #}
+ tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op },
+ {
+ {%- for input in operator.inputs: -%}
+ {{input}},
+ {%- endfor -%}
+ {%- for output in operator.outputs: -%}
+ {{output}}
+ {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ },
+ {
+ {%- for input in operator.inputs: -%}
+ {{input}}->GetName()
+ {%- if loop.index < operator.inputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ },
+ {
+ {%- for output in operator.outputs: -%}
+ {{output}}->GetName()
+ {%- if loop.index < operator.outputs|length -%},{%- endif -%}
+ {%- endfor -%}
+ });
+
+ // Setup model
+ {# Both macros return early from the generated function when the wrapped call reports a failure. #}
+ TosaReference::ModelRunnerImpl runner;
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
+ {% for input in operator.inputs: -%}
+ TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
+ {%- endfor %}
+
+ // Execute
+ TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());
+
+ // Extract outputs
+ {% for output in operator.outputs: -%}
+ TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size));
+ {%- endfor %}
+
+ return tosa_status_valid;
+ }
+ {% endfor %}
+
+} // extern "C" \ No newline at end of file
diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2
new file mode 100644
index 0000000..803b76a
--- /dev/null
+++ b/scripts/operator_api/templates/operators_h.j2
@@ -0,0 +1,74 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// THIS FILE IS GENERATED. DO NOT EDIT!
+// See scripts/operator_api/generate_api.py
+
+{# Template for include/operators.h: the public declarations of the eager operator API. #}
+{# Rendered by generate_api.py with `dataTypes` (sorted list of enumerator names) and `operators` (list of operator dictionaries). #}
+#ifndef OPERATORS_H_
+#define OPERATORS_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+ {# operators_cc.j2 converts a failing graph status to tosa_status_t by casting its integer value, hence the alignment note below. #}
+ // Note status needs to be aligned with graph_status
+ enum tosa_status_t
+ {
+ tosa_status_valid = 0,
+ tosa_status_unpredictable = 1,
+ tosa_status_error = 2
+ };
+
+ {# Translated by translate_client_tosa_mode in operators_cc.j2: nearest maps to NEAREST, bilinear and max map to BILINEAR, anything else to UNKNOWN. #}
+ enum tosa_mode_t
+ {
+ tosa_mode_unknown = 0,
+ tosa_mode_nearest = 1,
+ tosa_mode_bilinear = 2,
+ tosa_mode_min = 3,
+ tosa_mode_max = 4
+ };
+
+ {# One enumerator per entry of dataTypes, numbered from 0 in list order (loop.index is 1-based). #}
+ enum tosa_datatype_t
+ {
+ {% for dataType in dataTypes: -%}
+ {{dataType}} = {{loop.index-1}},
+ {% endfor -%}
+ };
+
+ {# shape is read as an array of num_dims entries by translate_client_tensor in operators_cc.j2. #}
+ {# NOTE(review): size is presumably the length of the data buffer in bytes; it is passed unchanged to runner.setInput / runner.getOutput - confirm. #}
+ struct tosa_tensor_t
+ {
+ int32_t* shape;
+ int32_t num_dims;
+ tosa_datatype_t data_type;
+ uint8_t* data;
+ size_t size;
+ };
+
+ {# Prototypes of the generated entry points; the parameter list is built exactly as in operators_cc.j2. #}
+ {% for operator in operators: %}
+ tosa_status_t tosa_run_{{ operator.name }} (
+ {%- for arg in operator.arguments: -%}
+ {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
+ {% if loop.index < operator.arguments|length %},{% endif %}
+ {%- endfor -%});
+ {% endfor %}
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+
+#endif // OPERATORS_H_ \ No newline at end of file