aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorGrant Watson <grant.watson@arm.com>2023-06-23 16:52:12 +0100
committerEric Kunze <eric.kunze@arm.com>2023-06-26 15:44:59 +0000
commiteb74106e1bc52127e5631736e10e8f8b0b7a1d07 (patch)
tree1e5b4286ca1c55eb37bd3ce3a80669b420a4299b /scripts
parentfe36fa9f38824d03250393488fe468b7dacc72ed (diff)
downloadreference_model-eb74106e1bc52127e5631736e10e8f8b0b7a1d07.tar.gz
Upgrade to latest version of TOSA specification
Signed-off-by: Grant Watson <grant.watson@arm.com> Change-Id: I1296f968baca335ea88691bc973e2d01b2aa2c5b
Diffstat (limited to 'scripts')
-rw-r--r--scripts/operator_api/generate_api.py17
-rw-r--r--scripts/operator_api/templates/operators_cc.j22
2 files changed, 16 insertions, 3 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 5038973..499eadb 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -16,7 +16,15 @@ def getTosaArgTypes(tosaXml):
"""
Returns a list of the TOSA argument types from tosa.xml.
"""
- argTypes = {"in_t", "out_t", "mul_t", "weight_t", "in_out_t"}
+ argTypes = {
+ "tensor_t",
+ "in_t",
+ "out_t",
+ "mul_t",
+ "weight_t",
+ "in_out_t",
+ "tensor_list_t",
+ }
argTypesXml = tosaXml.getElementsByTagName("type")
for argTypeXml in argTypesXml:
argTypes.add(argTypeXml.getAttribute("name"))
@@ -182,7 +190,7 @@ def getOperators(tosaXml):
Return a list of TOSA operators as defined by tosa.xml.
"""
operators = []
- ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
+ ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d", "erf"]
opsXml = tosaXml.getElementsByTagName("operator")
allSerializeArgs = getSerializeArgs()
for opXml in opsXml:
@@ -227,7 +235,10 @@ def getTosaArgs(opXml):
tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
for xmlArg in argsXml:
argName = xmlArg.getAttribute("name").lower()
- argType = xmlArg.getAttribute("type")
+ if xmlArg.getAttribute("tensor-element-type") == "resize_mode_t":
+ argType = "tosa_mode_t"
+ else:
+ argType = xmlArg.getAttribute("type")
argShape = xmlArg.getAttribute("shape")
argCategory = xmlArg.getAttribute("category")
# Update argument type
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index 3f2acb5..37a0af6 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -49,6 +49,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
return tosa::DType::DType_FP16;
case tosa_datatype_fp32_t:
return tosa::DType::DType_FP32;
+ case tosa_datatype_bool_t:
+ return tosa::DType::DType_BOOL;
default:
return tosa::DType::DType_UNKNOWN;
}