From eb74106e1bc52127e5631736e10e8f8b0b7a1d07 Mon Sep 17 00:00:00 2001 From: Grant Watson Date: Fri, 23 Jun 2023 16:52:12 +0100 Subject: Upgrade to latest version of TOSA specification Signed-off-by: Grant Watson Change-Id: I1296f968baca335ea88691bc973e2d01b2aa2c5b --- reference_model/include/operators.h | 8 ++------ reference_model/src/operators.cc | 8 ++------ scripts/operator_api/generate_api.py | 17 ++++++++++++++--- scripts/operator_api/templates/operators_cc.j2 | 2 ++ thirdparty/specification | 2 +- 5 files changed, 21 insertions(+), 16 deletions(-) diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h index 6e21e95..6efb655 100644 --- a/reference_model/include/operators.h +++ b/reference_model/include/operators.h @@ -130,9 +130,7 @@ extern "C" tosa_status_t tosa_run_transpose_conv2d(tosa_tensor_t client_input, tosa_tensor_t client_weight, tosa_tensor_t client_bias, - const int32_t client_out_pad[4], const int32_t client_stride[2], - const int32_t client_out_shape[4], const int32_t client_input_zp, const int32_t client_weight_zp, const int32_t client_pad_len, @@ -196,7 +194,7 @@ extern "C" tosa_status_t tosa_run_mul(tosa_tensor_t client_input1, tosa_tensor_t client_input2, - const uint8_t client_shift, + const int32_t client_shift, tosa_tensor_t client_output); tosa_status_t tosa_run_pow(tosa_tensor_t client_input1, tosa_tensor_t client_input2, tosa_tensor_t client_output); @@ -288,8 +286,6 @@ extern "C" tosa_tensor_t client_output); tosa_status_t tosa_run_tile(tosa_tensor_t client_input1, - const int32_t client_multiplies_len, - const int32_t client_multiplies[], const int32_t client_multiples_len, const int32_t client_multiples[], tosa_tensor_t client_output); @@ -323,7 +319,7 @@ extern "C" const int32_t client_multiplier_len, const int32_t client_multiplier[], const int32_t client_shift_len, - const uint8_t client_shift[], + const int32_t client_shift[], const bool client_scale32, const bool client_double_round, const bool client_per_channel); diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index e9d6cad..a0b5013 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -429,9 +429,7 @@ extern "C" tosa_status_t tosa_run_transpose_conv2d(tosa_tensor_t client_input, tosa_tensor_t client_weight, tosa_tensor_t client_bias, - const int32_t client_out_pad[4], const int32_t client_stride[2], - const int32_t client_out_shape[4], const int32_t client_input_zp, const int32_t client_weight_zp, const int32_t client_pad_len, @@ -1032,7 +1030,7 @@ extern "C" tosa_status_t tosa_run_mul(tosa_tensor_t client_input1, tosa_tensor_t client_input2, - const uint8_t client_shift, + const int32_t client_shift, tosa_tensor_t client_output) { // Create operator attributes @@ -2032,8 +2030,6 @@ extern "C" } tosa_status_t tosa_run_tile(tosa_tensor_t client_input1, - const int32_t client_multiplies_len, - const int32_t client_multiplies[], const int32_t client_multiples_len, const int32_t client_multiples[], tosa_tensor_t client_output) @@ -2256,7 +2252,7 @@ extern "C" const int32_t client_multiplier_len, const int32_t client_multiplier[], const int32_t client_shift_len, - const uint8_t client_shift[], + const int32_t client_shift[], const bool client_scale32, const bool client_double_round, const bool client_per_channel) diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 5038973..499eadb 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -16,7 +16,15 @@ def getTosaArgTypes(tosaXml): """ Returns a list of the TOSA argument types from tosa.xml. """ - argTypes = {"in_t", "out_t", "mul_t", "weight_t", "in_out_t"} + argTypes = { + "tensor_t", + "in_t", + "out_t", + "mul_t", + "weight_t", + "in_out_t", + "tensor_list_t", + } argTypesXml = tosaXml.getElementsByTagName("type") for argTypeXml in argTypesXml: argTypes.add(argTypeXml.getAttribute("name")) @@ -182,7 +190,7 @@ def getOperators(tosaXml): Return a list of TOSA operators as defined by tosa.xml. """ operators = [] - ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"] + ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d", "erf"] opsXml = tosaXml.getElementsByTagName("operator") allSerializeArgs = getSerializeArgs() for opXml in opsXml: @@ -227,7 +235,10 @@ def getTosaArgs(opXml): tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"} for xmlArg in argsXml: argName = xmlArg.getAttribute("name").lower() - argType = xmlArg.getAttribute("type") + if xmlArg.getAttribute("tensor-element-type") == "resize_mode_t": + argType = "tosa_mode_t" + else: + argType = xmlArg.getAttribute("type") argShape = xmlArg.getAttribute("shape") argCategory = xmlArg.getAttribute("category") # Update argument type diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index 3f2acb5..37a0af6 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -49,6 +49,8 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) return tosa::DType::DType_FP16; case tosa_datatype_fp32_t: return tosa::DType::DType_FP32; + case tosa_datatype_bool_t: + return tosa::DType::DType_BOOL; default: return tosa::DType::DType_UNKNOWN; } diff --git a/thirdparty/specification b/thirdparty/specification index 0205d99..8e14dcd 160000 --- a/thirdparty/specification +++ b/thirdparty/specification @@ -1 +1 @@ -Subproject commit 0205d99cbff58797bf6602ee5718d50c00d8309b +Subproject commit 8e14dcd2f86e9a3b9c2283fb0f0325088565bbe7 -- cgit v1.2.1