diff options
author | Grant Watson <grant.watson@arm.com> | 2022-11-16 15:32:39 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2022-12-15 16:41:27 +0000 |
commit | 64285a1f25e2c7b85ed1f00b7947403e92baea00 (patch) | |
tree | 6d29c54f6497741449339e808508c854ba6a2267 /scripts | |
parent | b45db9a696f5df7b233f374248f329c16ee7ae64 (diff) | |
download | reference_model-64285a1f25e2c7b85ed1f00b7947403e92baea00.tar.gz |
Extend reference model API with eager operator execution entrypoints
- Adds a script to generate operators.h and operators.cc
- Adds jinja2 templates for generating operators.h and operators.cc
- Adds unit tests for a subset of the operators generated
- Includes the TOSA specification as a submodule
- Adds supporting C++ and header files
Signed-off-by: Grant Watson <grant.watson@arm.com>
Change-Id: I5b60db4c56113110d8e75fe1152525d258233f9c
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/operator_api/README.md | 19 | ||||
-rw-r--r-- | scripts/operator_api/generate_api.py | 349 | ||||
-rw-r--r-- | scripts/operator_api/templates/operators_cc.j2 | 176 | ||||
-rw-r--r-- | scripts/operator_api/templates/operators_h.j2 | 74 |
4 files changed, 618 insertions, 0 deletions
diff --git a/scripts/operator_api/README.md b/scripts/operator_api/README.md new file mode 100644 index 0000000..381d90c --- /dev/null +++ b/scripts/operator_api/README.md @@ -0,0 +1,19 @@ +# Generate eager operator execution entrypoints + +## Introduction + +The generate_api.py script will generate an extended reference model API with eager operator execution entrypoints. +The following files will be generated: include/operators.h and src/operators.cc + +## Requirements + +* Python 3.6 or later +* Jinja2 (install with ```pip install Jinja2```) + +## Running from the command line + +The script can be run from the scripts/operator-api directory as follows: + +```bash +python generate_api.py +``` diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py new file mode 100644 index 0000000..1f89f74 --- /dev/null +++ b/scripts/operator_api/generate_api.py @@ -0,0 +1,349 @@ +"""Generate extended reference model API with eager operator execution entrypoints""" +# Copyright (c) 2021-2022, ARM Limited. +# SPDX-License-Identifier: Apache-2.0 +import copy +import os +import subprocess +from xml.dom import minidom + +from jinja2 import Environment +from jinja2 import FileSystemLoader + + +def getTosaArgTypes(tosaXml): + """ + Returns a list of the TOSA argument types from tosa.xml. + """ + argTypes = {"in_t", "out_t", "mul_t", "weight_t", "in_out_t"} + argTypesXml = tosaXml.getElementsByTagName("type") + for argTypeXml in argTypesXml: + argTypes.add(argTypeXml.getAttribute("name")) + argTypes.remove("TABLE_SIZE") + return argTypes + + +def getTosaDataTypes(tosaXml): + """ + Returns a list of the TOSA data types from tosa.xml. + """ + argTypes = getTosaArgTypes(tosaXml) + dataTypes = set() + dataTypesXml = tosaXml.getElementsByTagName("typesupport") + for dataTypeXml in dataTypesXml: + for argType in argTypes: + dataType = dataTypeXml.getAttribute(argType) + if dataType != "": + dataTypes.add(f"tosa_datatype_{dataType}") + return sorted(dataTypes) + + +def getSerializeOpType(tosaOpName): + """ + Returns the Serialization library operator that matches the TOSA operator specified. + """ + map = { + "avg_pool2d": "Pool", + "conv2d": "Conv", + "conv3d": "Conv", + "depthwise_conv2d": "Conv", + "fully_connected": "FullyConnected", + "matmul": "MatMul", + "max_pool2d": "Pool", + "transpose_conv2d": "Conv", + "clamp": "Clamp", + "arithmetic_right_shift": "ArithmeticRightShift", + "mul": "Mul", + "table": "Table", + "negate": "Negate", + "pad": "Pad", + "reshape": "Reshape", + "slice": "Slice", + "tile": "Tile", + "transpose": "Transpose", + "resize": "Resize", + "rescale": "Rescale", + "cond_if": "CondIf", + "while_loop": "WhileLoop", + } + if tosaOpName not in map.keys(): + return "None" + else: + return map[tosaOpName] + + +def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs): + """ + Returns the arguments required by the Serialization library for the TOSA operator specified. + Generates code to initialize Serialization arguments. If a matching TOSA argument exists, + that value is used for initialization, otherwise a default value e.g. 0 is used. + """ + serOpType = getSerializeOpType(tosaOpName) + if serOpType not in allSerializeArgs.keys(): + return {} + else: + serOpArgs = copy.deepcopy(allSerializeArgs[serOpType]) + tosaArgsDict = {arg["name"]: arg for arg in tosaArgs} + serTosaTypeMap = {"ResizeMode": "tosa_mode"} + for arg in serOpArgs: + argName = arg["name"] + init = "" + # Translate TOSA data types to Serialization data types for initialization + if arg["dType"] in serTosaTypeMap.keys(): + init = f" = translate_client_{serTosaTypeMap[arg['dType']]}(client_{argName})" + # Initialize Serialization arguments to their matching function parameter + elif argName in tosaArgsDict: + if arg["SV"] == "V": + shape = tosaArgsDict[argName]["shape"] + if shape == "[]": + init = f"(&client_{argName}[0], &client_{argName}[0] + client_{argName}_len)" + else: + init = f"(&client_{argName}[0], &client_{argName}{shape})" + else: + init = f" = client_{argName}" + else: + # Initialize Serialization arguments with no matching fuction parameter + if arg["SV"] == "V": + init = "" + else: + if arg["dType"] == "DType": + arg["dType"] = "tosa::DType" + init = " = tosa::DType::DType_FP32" + else: + init = " = 0" + arg["init"] = init + return serOpArgs + + +def updateTosaArgs(tosaArgs, serializeArgs, tosaXml): + """ + Replace TOSA argument data types with their matching Serialization argument data types. + Delete TOSA arguments where the type couldn't be determined. + Add Serialization arguments that have no matching TOSA argument. + """ + tosaArgTypes = getTosaArgTypes(tosaXml) + serArgsDict = {arg["name"]: arg for arg in serializeArgs} + tosaArgsNames = [arg["name"] for arg in tosaArgs] + delTosaArgs = [] + # Replace TOSA argument data types with their matching Serialization argument data types. + for tosaArg in tosaArgs: + if tosaArg["type"] in tosaArgTypes: + if tosaArg["name"] in serArgsDict: + tosaArg["type"] = serArgsDict[tosaArg["name"]]["dType"] + else: + # Delete TOSA argument whose data type can't be determined + delTosaArgs.append(tosaArgsNames.index(tosaArg["name"])) + # Delete corresponding length argument if one exists + lenArgName = f"{tosaArg['name']}_len" + if lenArgName in tosaArgsNames: + delTosaArgs.append(tosaArgsNames.index(lenArgName)) + # Delete TOSA arguments where the type couldn't be determined + for index in sorted(delTosaArgs, key=int, reverse=True): + del tosaArgs[index] + # Add Serialization arguments that have no matching TOSA argument + tosaArgNames = [arg["name"] for arg in tosaArgs] + for serArg in serializeArgs: + if (serArg["name"] not in tosaArgNames) and ( + not serArg["dType"] == "tosa::DType" + ): + serArgName = serArg["name"] + if serArg["SV"] == "V": + # For vector data types, insert a matching length argument + tosaArgs.insert( + len(tosaArgs) - 1, + { + "name": f"{serArgName}_len", + "type": "int32_t", + "shape": "", + "category": "", + }, + ) + init = f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)" + shape = "[]" + else: + init = f" = client_{serArg['name']}" + shape = "" + serArg["init"] = init + # Insert new argument + tosaArgs.insert( + len(tosaArgs) - 1, + { + "name": serArgName, + "type": serArg["dType"], + "shape": shape, + "category": "", + }, + ) + + +def getOperators(tosaXml): + """ + Return a list of TOSA operators as defined by tosa.xml. + """ + operators = [] + ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"] + opsXml = tosaXml.getElementsByTagName("operator") + allSerializeArgs = getSerializeArgs() + for opXml in opsXml: + opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower() + if opName not in ignoreOps: + operator = {"name": opName} + operator["serializeAttType"] = getSerializeOpType(opName) + tosaArgs = getTosaArgs(opXml) + serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs) + updateTosaArgs(tosaArgs, serializeArgs, tosaXml) + operator["arguments"] = tosaArgs + operator["serializeArgs"] = serializeArgs + operator["inputs"] = [ + arg["name"] for arg in tosaArgs if arg["category"] == "input" + ] + operator["outputs"] = [ + arg["name"] for arg in tosaArgs if arg["category"] == "output" + ] + operators.append(operator) + return operators + + +def getTosaArgs(opXml): + """ + Return the arguments required for the TOSA operator specified. + """ + arguments = [] + argsXml = opXml.getElementsByTagName("argument") + tosaTensorTypes = getTosaArgTypes(tosaXml) + tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"} + for xmlArg in argsXml: + argName = xmlArg.getAttribute("name").lower() + argType = xmlArg.getAttribute("type") + argShape = xmlArg.getAttribute("shape") + argCategory = xmlArg.getAttribute("category") + # Update argument type + if argType[-1:] == "*": + argType = argType[:-1] + if argCategory in ["input", "output"] and argType in tosaTensorTypes: + argType = "tosa_tensor_t" + argShape = "" + if argType in tosaTypeMap: + argType = tosaTypeMap[argType] + # Add a length argument for arrays with unknown compile-time size + if argShape != "" and argShape[0] == "[" and not argShape[1:-1].isnumeric(): + argShape = "[]" + arguments.append( + { + "name": f"{argName}_len", + "type": "int32_t", + "shape": "", + "category": "", + } + ) + elif argShape == "" or not argShape[0] == "[": + argShape = "" + # Append argument + arguments.append( + { + "name": argName, + "type": argType, + "shape": argShape, + "category": argCategory, + } + ) + return arguments + + +def clangFormat(filename): + cmd = ["clang-format", "-i", filename] + with open(os.devnull, "w") as devnull: + subprocess.check_call(cmd, stdout=devnull) + + +def getSerializeArgs(): + """ + Parse attribute.def file and return a dictionary where the keys are Serialization library operator names. + The values are the arguments required by each Serialization library operator. + """ + serializeArgs = {} + with open("../../thirdparty/serialization_lib/include/attribute.def") as file: + preamble = True + inAtt = False + opName = "" + args = [] + for line in file: + if preamble and not line[: len("DEF_ATTRIBUTE(")] == "DEF_ATTRIBUTE(": + continue + else: + preamble = False + line = line.lstrip().rstrip() + if not inAtt and "DEF_ATTRIBUTE(" in line: + opName = line[len("DEF_ATTRIBUTE(") : line.find(",")] + inAtt = True + elif inAtt: + vals = line.split(",") + argName = vals[2].lstrip().strip() + if ")" in argName: + argName = argName[:-1] + arg = { + "name": argName, + "dType": vals[0].lstrip().strip(), + "SV": vals[1].lstrip().strip(), + } + args.append(arg) + if ")" in line: + serializeArgs[opName] = args + opName = "" + args = [] + inAtt = False + return serializeArgs + + +def renderTemplate(environment, dataTypes, operators, template, outfile): + content = template.render(dataTypes=dataTypes, operators=operators) + with open(outfile, mode="w", encoding="utf-8") as output: + output.write(content) + print(f"Created {outfile}") + + clangFormat(outfile) + + +def generate(environment, dataTypes, operators): + # Generate include/operators.h + template = environment.get_template("operators_h.j2") + outfile = os.path.join("..", "..", "reference_model", "include", "operators.h") + renderTemplate(environment, dataTypes, operators, template, outfile) + + # Generate src/operators.cc + template = environment.get_template("operators_cc.j2") + outfile = os.path.join("..", "..", "reference_model", "src", "operators.cc") + renderTemplate(environment, dataTypes, operators, template, outfile) + + +def getSerializeOpTypeMap(): + """ + Utility function for generating the map used in getSerializeOpType() + """ + import re + + allSerializeArgs = getSerializeArgs() + serArgs = [ + re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower() + for name in allSerializeArgs.keys() + ] + serArgs = sorted(serArgs, key=len, reverse=True) + tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml") + opsXml = tosaXml.getElementsByTagName("operator") + opNames = [ + op.getElementsByTagName("name")[0].firstChild.data.lower() for op in opsXml + ] + map = {} + for opName in opNames: + for serArg in serArgs: + if serArg in opName: + components = serArg.split("_") + map[opName] = "".join(x.title() for x in components) + return map + + +if __name__ == "__main__": + environment = Environment(loader=FileSystemLoader("templates/")) + tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml") + dataTypes = getTosaDataTypes(tosaXml) + operators = getOperators(tosaXml) + generate(environment, dataTypes, operators) diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 new file mode 100644 index 0000000..6b0ed6e --- /dev/null +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -0,0 +1,176 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// THIS FILE IS GENERATED. DO NOT EDIT! +// See scripts/operator_api/generate_api.py + +#include "operators.h" +#include "model_runner_impl.h" +#include "ops/op_factory.h" + +#define TOSA_RETURN_ON_ERROR(status) \ + do \ + { \ + if (status != 0) \ + { \ + return tosa_status_error; \ + } \ + } while (false) + +#define TOSA_RETURN_ON_GRAPH_STATUS_ERROR(status) \ + do \ + { \ + if (status != GraphStatus::TOSA_VALID) \ + { \ + auto ustatus = static_cast<std::underlying_type_t<GraphStatus>>(status); \ + return static_cast<tosa_status_t>(ustatus); \ + } \ + } while (false) + +namespace { + +tosa::DType translate_client_datatype(tosa_datatype_t type) +{ + switch (type) + { + case tosa_datatype_fp16_t: + return tosa::DType::DType_FP16; + case tosa_datatype_fp32_t: + return tosa::DType::DType_FP32; + default: + return tosa::DType::DType_UNKNOWN; + } +}; + +tosa::TosaSerializationTensor* translate_client_tensor(tosa_tensor_t& tensor, const std::string& name) +{ + std::vector<int32_t> shape(tensor.shape, tensor.shape + tensor.num_dims); + return new tosa::TosaSerializationTensor(name, shape, translate_client_datatype(tensor.data_type), {}); +} + +tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { + switch(mode) { + case tosa_mode_nearest: + return tosa::ResizeMode_NEAREST; + case tosa_mode_max: + case tosa_mode_bilinear: + return tosa::ResizeMode_BILINEAR; + default: + return tosa::ResizeMode_UNKNOWN; + } +} + +} // namespace + +extern "C" +{ + {% for operator in operators: %} + tosa_status_t tosa_run_{{ operator.name }} ( + {%- for arg in operator.arguments: -%} + {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} + {% if loop.index < operator.arguments|length %},{% endif %} + {%- endfor -%} + ) + { + // Create operator attributes + {% for arg in operator.serializeArgs: %} + {%- if arg.SV == "V": -%} + const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}}; + {%- else: -%} + const {{arg.dType}} {{arg.name}}{{arg.init}}; + {%- endif -%} + {%- endfor -%} + + Tosa{{operator.serializeAttType}}Attribute attr + {%- if operator.serializeArgs|length > 0 -%} + ( + {%- for arg in operator.serializeArgs: -%} + {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %} + {%- endfor -%} + ) + {%- endif -%}; + + // Create tensors + {% for input in operator.inputs: -%} + tosa::TosaSerializationTensor* {{input}} = translate_client_tensor(client_{{input}}, "{{input}}"); + {%- endfor -%} + {% for output in operator.outputs: %} + tosa::TosaSerializationTensor* {{output}} = translate_client_tensor(client_{{output}}, "{{output}}"); + {%- endfor %} + + // Create operator + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_{{operator.name|upper}}, + {%- if operator.serializeAttType != "None" -%} + tosa::Attribute::Attribute_{{operator.serializeAttType}}Attribute + {%- else -%} + tosa::Attribute::Attribute_NONE + {%- endif -%}, + &attr, { + {%- for input in operator.inputs: -%} + {{input}}->GetName() + {%- if loop.index < operator.inputs|length -%},{%- endif -%} + {%- endfor -%} + }, + { + {%- for output in operator.outputs: -%} + {{output}}->GetName() + {%- if loop.index < operator.outputs|length -%},{%- endif -%} + {%- endfor -%} + }); + + // Create a tosa single-op basic block + tosa::TosaSerializationBasicBlock block("{{operator.name}}", { op }, + { + {%- for input in operator.inputs: -%} + {{input}}, + {%- endfor -%} + {%- for output in operator.outputs: -%} + {{output}} + {%- if loop.index < operator.outputs|length -%},{%- endif -%} + {%- endfor -%} + }, + { + {%- for input in operator.inputs: -%} + {{input}}->GetName() + {%- if loop.index < operator.inputs|length -%},{%- endif -%} + {%- endfor -%} + }, + { + {%- for output in operator.outputs: -%} + {{output}}->GetName() + {%- if loop.index < operator.outputs|length -%},{%- endif -%} + {%- endfor -%} + }); + + // Setup model + TosaReference::ModelRunnerImpl runner; + TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); + {% for input in operator.inputs: -%} + TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size)); + {%- endfor %} + + // Execute + TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); + + // Extract outputs + {% for output in operator.outputs: -%} + TOSA_RETURN_ON_ERROR(runner.getOutput({{output}}->GetName(), client_{{output}}.data, client_{{output}}.size)); + {%- endfor %} + + return tosa_status_valid; + } + {% endfor %} + +} // extern "C"
\ No newline at end of file diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2 new file mode 100644 index 0000000..803b76a --- /dev/null +++ b/scripts/operator_api/templates/operators_h.j2 @@ -0,0 +1,74 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// THIS FILE IS GENERATED. DO NOT EDIT! +// See scripts/operator_api/generate_api.py + +#ifndef OPERATORS_H_ +#define OPERATORS_H_ + +#include <stddef.h> +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + + // Note status needs to be aligned with graph_status + enum tosa_status_t + { + tosa_status_valid = 0, + tosa_status_unpredictable = 1, + tosa_status_error = 2 + }; + + enum tosa_mode_t + { + tosa_mode_unknown = 0, + tosa_mode_nearest = 1, + tosa_mode_bilinear = 2, + tosa_mode_min = 3, + tosa_mode_max = 4 + }; + + enum tosa_datatype_t + { + {% for dataType in dataTypes: -%} + {{dataType}} = {{loop.index-1}}, + {% endfor -%} + }; + + struct tosa_tensor_t + { + int32_t* shape; + int32_t num_dims; + tosa_datatype_t data_type; + uint8_t* data; + size_t size; + }; + + {% for operator in operators: %} + tosa_status_t tosa_run_{{ operator.name }} ( + {%- for arg in operator.arguments: -%} + {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} + {% if loop.index < operator.arguments|length %},{% endif %} + {%- endfor -%}); + {% endfor %} + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#endif // OPERATORS_H_
\ No newline at end of file |