aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorGrant Watson <grant.watson@arm.com>2023-09-12 10:46:36 +0100
committerEric Kunze <eric.kunze@arm.com>2023-09-14 17:02:30 +0000
commiteff703839274d44743e837300bb273046ca2a639 (patch)
tree3322b948eaeb863bbd13d22b4a014bd71e6ae55b /scripts
parent48df8c7509f51b145e97619a45aa25836e702767 (diff)
downloadreference_model-eff703839274d44743e837300bb273046ca2a639.tar.gz
Upgrade to latest version of TOSA specification
- Updates TOSA specification to the latest version - Updates generate_api.py to generate the operator API correctly for ops with additional tensor inputs. - Removes default arguments for func_debug and func_config to make the API C compliant again. - Updates model_runner_tests.cpp for operators that have changed. - Adds a unit test for the Tile operator to check that generated code for additional tensor inputs works correctly. Signed-off-by: Grant Watson <grant.watson@arm.com> Change-Id: I1e26065c6ed333b2ca4b3da39972d30f896fa6e5
Diffstat (limited to 'scripts')
-rw-r--r--scripts/operator_api/generate_api.py171
-rw-r--r--scripts/operator_api/templates/operators_cc.j226
-rw-r--r--scripts/operator_api/templates/operators_h.j237
3 files changed, 123 insertions, 111 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index f1cb6e0..c5c762d 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -81,64 +81,87 @@ def getSerializeOpType(tosaOpName):
return map[tosaOpName]
-def getSerializeArgsForOp(tosaOpName, allSerializeArgs, tosaArgs):
+def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
"""
- Returns the arguments required by the Serialization library for the TOSA operator specified.
- Generates code to initialize Serialization arguments. If a matching TOSA argument exists,
+ Returns the attributes required by the Serialization library for the TOSA operator specified.
+ Generates code to initialize Serialization library attributes. If a matching TOSA argument exists,
that value is used for initialization, otherwise a default value e.g. 0 is used.
"""
- serOpType = getSerializeOpType(tosaOpName)
- if serOpType not in allSerializeArgs.keys():
+ serLibOpType = getSerializeOpType(tosaOpName)
+ if serLibOpType not in allSerialLibAtts.keys():
return {}
else:
- serOpArgs = copy.deepcopy(allSerializeArgs[serOpType])
+ serLibOpAtts = copy.deepcopy(allSerialLibAtts[serLibOpType])
tosaArgsDict = {arg["name"]: arg for arg in tosaArgs}
serTosaTypeMap = {"ResizeMode": "tosa_mode"}
- for arg in serOpArgs:
- argName = arg["name"]
+ for att in serLibOpAtts:
+ attName = att["name"]
+ attType = att["dType"]
init = ""
- # Translate TOSA data types to Serialization data types for initialization
- if arg["dType"] in serTosaTypeMap.keys():
- init = f" = translate_client_{serTosaTypeMap[arg['dType']]}(client_{argName})"
- # Initialize Serialization arguments to their matching function parameter
- elif argName in tosaArgsDict:
- if arg["SV"] == "V":
- shape = tosaArgsDict[argName]["shape"]
- if shape == "[]":
- init = f"(&client_{argName}[0], &client_{argName}[0] + client_{argName}_len)"
+ # Translate TOSA data types to Serialization library data types for initialization
+ if attType in serTosaTypeMap.keys():
+ init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});"
+ # Initialize Serialization library attributes to their matching function parameter
+ elif attName in tosaArgsDict:
+ if att["SV"] == "V":
+ if tosaArgsDict[attName]["type"] == "tosa_tensor_t":
+ init = f"std::vector<{attType}> {attName};"
+ init = (
+ init
+ + f"size_t {attName}_size = client_{attName}.size / sizeof({attType});"
+ )
+ init = (
+ init
+ + f"{attType}* {attName}_data = reinterpret_cast<{attType}*>(client_{attName}.data);"
+ )
+ init = (
+ init
+ + f"{attName}.assign({attName}_data, {attName}_data + {attName}_size);"
+ )
else:
- init = f"(&client_{argName}[0], &client_{argName}{shape})"
+ init = f"const std::vector<{attType}> {attName}"
+ shape = tosaArgsDict[attName]["shape"]
+ if shape == "[]":
+ init = (
+ init
+ + f"(&client_{attName}[0], &client_{attName}[0] + client_{attName}_len);"
+ )
+ else:
+ init = (
+ init
+ + f"(&client_{attName}[0], &client_{attName}{shape});"
+ )
else:
- init = f" = client_{argName}"
- else:
- # Initialize Serialization arguments with no matching fuction parameter
- if arg["SV"] == "V":
init = ""
+ else:
+ # Initialize Serialization library attributes with no matching fuction parameter
+ if att["SV"] == "V":
+ init = f"std::vector<int32_t> {attName};"
else:
- if arg["dType"] == "DType":
- arg["dType"] = "tosa::DType"
- init = " = tosa::DType::DType_FP32"
+ if att["dType"] == "DType":
+ att["dType"] = "tosa::DType"
+ init = f"const tosa::DType {attName} = tosa::DType::DType_FP32;"
else:
- init = " = 0"
- arg["init"] = init
- return serOpArgs
+ init = f"const {attType} {attName} = 0;"
+ att["init"] = init
+ return serLibOpAtts
-def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
+def updateTosaArgs(tosaArgs, serialLibAtts, tosaXml):
"""
- Replace TOSA argument data types with their matching Serialization argument data types.
+ Replace TOSA argument data types with their matching Serialization attribute data types.
Delete TOSA arguments where the type couldn't be determined.
- Add Serialization arguments that have no matching TOSA argument.
+ Add Serialization attributes that have no matching TOSA argument.
"""
tosaArgTypes = getTosaArgTypes(tosaXml)
- serArgsDict = {arg["name"]: arg for arg in serializeArgs}
+ serAttsDict = {att["name"]: att for att in serialLibAtts}
tosaArgsNames = [arg["name"] for arg in tosaArgs]
delTosaArgs = []
- # Replace TOSA argument data types with their matching Serialization argument data types.
+ # Replace TOSA argument data types with their matching Serialization attribute data types.
for tosaArg in tosaArgs:
if tosaArg["type"] in tosaArgTypes:
- if tosaArg["name"] in serArgsDict:
- tosaArg["type"] = serArgsDict[tosaArg["name"]]["dType"]
+ if tosaArg["name"] in serAttsDict:
+ tosaArg["type"] = serAttsDict[tosaArg["name"]]["dType"]
else:
# Delete TOSA argument whose data type can't be determined
delTosaArgs.append(tosaArgsNames.index(tosaArg["name"]))
@@ -149,36 +172,36 @@ def updateTosaArgs(tosaArgs, serializeArgs, tosaXml):
# Delete TOSA arguments where the type couldn't be determined
for index in sorted(delTosaArgs, key=int, reverse=True):
del tosaArgs[index]
- # Add Serialization arguments that have no matching TOSA argument
+ # Add Serialization attributes that have no matching TOSA argument
tosaArgNames = [arg["name"] for arg in tosaArgs]
- for serArg in serializeArgs:
- if (serArg["name"] not in tosaArgNames) and (
- not serArg["dType"] == "tosa::DType"
- ):
- serArgName = serArg["name"]
- if serArg["SV"] == "V":
+ for serAtt in serialLibAtts:
+ attName = serAtt["name"]
+ attType = serAtt["dType"]
+ if (attName not in tosaArgNames) and (not attType == "tosa::DType"):
+ serAttName = serAtt["name"]
+ if serAtt["SV"] == "V":
# For vector data types, insert a matching length argument
tosaArgs.insert(
len(tosaArgs) - 1,
{
- "name": f"{serArgName}_len",
+ "name": f"{serAttName}_len",
"type": "int32_t",
"shape": "",
"category": "",
},
)
- init = f"(&client_{serArgName}[0], &client_{serArgName}[0] + client_{serArgName}_len)"
+ init = f"const std::vector<{attType}> {attName}(&client_{serAttName}[0], &client_{serAttName}[0] + client_{serAttName}_len);"
shape = "[]"
else:
- init = f" = client_{serArg['name']}"
+ init = ""
shape = ""
- serArg["init"] = init
+ serAtt["init"] = init
# Insert new argument
tosaArgs.insert(
len(tosaArgs) - 1,
{
- "name": serArgName,
- "type": serArg["dType"],
+ "name": serAttName,
+ "type": serAtt["dType"],
"shape": shape,
"category": "",
},
@@ -190,33 +213,47 @@ def getOperators(tosaXml):
Return a list of TOSA operators as defined by tosa.xml.
"""
operators = []
- ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
+ ignoreOps = [
+ "while_loop",
+ "cond_if",
+ "const",
+ "custom",
+ "fft2d",
+ "rfft2d",
+ "variable",
+ "variable_read",
+ "variable_write",
+ ]
opsXml = tosaXml.getElementsByTagName("operator")
- allSerializeArgs = getSerializeArgs()
+ allSerialLibAtts = getSerialLibAtts()
for opXml in opsXml:
opName = opXml.getElementsByTagName("name")[0].firstChild.data.lower()
if opName not in ignoreOps:
operator = {"name": opName}
operator["serializeAttType"] = getSerializeOpType(opName)
tosaArgs = getTosaArgs(opXml)
- serializeArgs = getSerializeArgsForOp(opName, allSerializeArgs, tosaArgs)
+ serialLibAtts = getSerialLibAttsForOp(opName, allSerialLibAtts, tosaArgs)
# Handle "axis" arguments
axisList = [arg["name"] for arg in tosaArgs if arg["name"] == "axis"]
if operator["serializeAttType"] == "None" and len(axisList) > 0:
operator["serializeAttType"] = "Axis"
- serializeArgs = [
+ serialLibAtts = [
{
"name": "axis",
"dType": "int32_t",
"SV": "S",
- "init": "= client_axis",
+ "init": "",
}
]
- updateTosaArgs(tosaArgs, serializeArgs, tosaXml)
+ updateTosaArgs(tosaArgs, serialLibAtts, tosaXml)
operator["arguments"] = tosaArgs
- operator["serializeArgs"] = serializeArgs
+ operator["serialLibAtts"] = serialLibAtts
+ serializationAttNames = [att["name"] for att in serialLibAtts]
operator["inputs"] = [
- arg["name"] for arg in tosaArgs if arg["category"] == "input"
+ arg["name"]
+ for arg in tosaArgs
+ if arg["category"] == "input"
+ and arg["name"] not in serializationAttNames
]
operator["outputs"] = [
arg["name"] for arg in tosaArgs if arg["category"] == "output"
@@ -283,12 +320,12 @@ def clangFormat(filename):
subprocess.check_call(cmd, stdout=devnull)
-def getSerializeArgs():
+def getSerialLibAtts():
"""
Parse attribute.def file and return a dictionary where the keys are Serialization library operator names.
The values are the arguments required by each Serialization library operator.
"""
- serializeArgs = {}
+ serialLibAtts = {}
with open("../../thirdparty/serialization_lib/include/attribute.def") as file:
preamble = True
inAtt = False
@@ -315,11 +352,11 @@ def getSerializeArgs():
}
args.append(arg)
if ")" in line:
- serializeArgs[opName] = args
+ serialLibAtts[opName] = args
opName = ""
args = []
inAtt = False
- return serializeArgs
+ return serialLibAtts
def renderTemplate(environment, dataTypes, operators, template, outfile):
@@ -349,12 +386,12 @@ def getSerializeOpTypeMap():
"""
import re
- allSerializeArgs = getSerializeArgs()
- serArgs = [
+ allSerialLibAtts = getSerialLibAtts()
+ serAtts = [
re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
- for name in allSerializeArgs.keys()
+ for name in allSerialLibAtts.keys()
]
- serArgs = sorted(serArgs, key=len, reverse=True)
+ serAtts = sorted(serAtts, key=len, reverse=True)
tosaXml = minidom.parse("../../thirdparty/specification/tosa.xml")
opsXml = tosaXml.getElementsByTagName("operator")
opNames = [
@@ -362,9 +399,9 @@ def getSerializeOpTypeMap():
]
map = {}
for opName in opNames:
- for serArg in serArgs:
- if serArg in opName:
- components = serArg.split("_")
+ for serAtt in serAtts:
+ if serAtt in opName:
+ components = serAtt.split("_")
map[opName] = "".join(x.title() for x in components)
return map
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index a8f1c24..6b6f864 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -67,6 +67,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
return tosa::DType::DType_UINT16;
case tosa_datatype_uint8_t:
return tosa::DType::DType_UINT8;
+ case tosa_datatype_shape_t:
+ return tosa::DType::DType_SHAPE;
default:
return tosa::DType::DType_UNKNOWN;
}
@@ -99,24 +101,24 @@ extern "C"
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
- {%- endfor -%}, const func_config_t &func_config,
- const func_debug_t &func_debug
+ {%- endfor -%},const func_ctx_t& func_ctx
)
{
// Create operator attributes
- {% for arg in operator.serializeArgs: %}
- {%- if arg.SV == "V": -%}
- const std::vector<{{arg.dType}}> {{arg.name}}{{arg.init}};
- {%- else: -%}
- const {{arg.dType}} {{arg.name}}{{arg.init}};
- {%- endif -%}
+ {% for att in operator.serialLibAtts: -%}
+ {{att.init}}
{%- endfor -%}
Tosa{{operator.serializeAttType}}Attribute attr
- {%- if operator.serializeArgs|length > 0 -%}
+ {%- if operator.serialLibAtts|length > 0 -%}
(
- {%- for arg in operator.serializeArgs: -%}
- {{arg.name}}{% if loop.index < operator.serializeArgs|length %}, {% endif %}
+ {%- for att in operator.serialLibAtts: -%}
+ {%- if att.init == "" -%}
+ client_{{att.name}}
+ {%- else -%}
+ {{att.name}}
+ {%- endif -%}
+ {% if loop.index < operator.serialLibAtts|length %}, {% endif %}
{%- endfor -%}
)
{%- endif -%};
@@ -174,7 +176,7 @@ extern "C"
});
// Setup model
- TosaReference::ModelRunnerImpl runner(func_config, func_debug);
+ TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
{% for input in operator.inputs: -%}
TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2
index 042d7a5..0c98da8 100644
--- a/scripts/operator_api/templates/operators_h.j2
+++ b/scripts/operator_api/templates/operators_h.j2
@@ -21,6 +21,7 @@
#include "func_config.h"
#include "func_debug.h"
+#include "types.h"
#include <stddef.h>
#include <stdint.h>
@@ -29,37 +30,10 @@
extern "C" {
#endif /* __cplusplus */
- // Note status needs to be aligned with graph_status
- enum tosa_status_t
+ struct func_ctx_t
{
- tosa_status_valid = 0,
- tosa_status_unpredictable = 1,
- tosa_status_error = 2
- };
-
- enum tosa_mode_t
- {
- tosa_mode_unknown = 0,
- tosa_mode_nearest = 1,
- tosa_mode_bilinear = 2,
- tosa_mode_min = 3,
- tosa_mode_max = 4
- };
-
- enum tosa_datatype_t
- {
- {% for dataType in dataTypes: -%}
- {{dataType}} = {{loop.index-1}},
- {% endfor -%}
- };
-
- struct tosa_tensor_t
- {
- int32_t* shape;
- int32_t num_dims;
- tosa_datatype_t data_type;
- uint8_t* data;
- size_t size;
+ func_config_t func_config = func_config_t{};
+ func_debug_t func_debug = func_debug_t{};
};
{% for operator in operators: %}
@@ -67,8 +41,7 @@ extern "C" {
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
- {%- endfor -%}, const func_config_t &func_config = func_config_t{},
- const func_debug_t &func_debug = func_debug_t{});
+ {%- endfor -%},const func_ctx_t& func_ctx);
{% endfor %}
#ifdef __cplusplus