From 0c0a263bf6742e943bebd42ccf97dcdbd8f4e1c8 Mon Sep 17 00:00:00 2001 From: Dmitrii Agibov Date: Thu, 21 Sep 2023 11:05:58 +0100 Subject: Pass parameter acc_size to AvgPool2d operator Signed-off-by: Dmitrii Agibov Change-Id: I4cd818af0db5e6e8a96641246cd3263ba4878f56 --- reference_model/include/operators.h | 1 + reference_model/include/types.h | 7 +++++++ reference_model/src/operators.cc | 18 +++++++++++++++++- reference_model/test/model_runner_tests.cpp | 4 +++- scripts/operator_api/generate_api.py | 12 ++++++++++-- scripts/operator_api/templates/operators_cc.j2 | 13 +++++++++++++ 6 files changed, 51 insertions(+), 4 deletions(-) diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h index 08da277..1399233 100644 --- a/reference_model/include/operators.h +++ b/reference_model/include/operators.h @@ -46,6 +46,7 @@ extern "C" const int32_t client_kernel[2], const int32_t client_stride[2], const int32_t client_pad[4], + const tosa_acc_size_t client_acc_size, const int32_t client_input_zp, const int32_t client_output_zp, tosa_tensor_t client_output, diff --git a/reference_model/include/types.h b/reference_model/include/types.h index 42040bf..a371d04 100644 --- a/reference_model/include/types.h +++ b/reference_model/include/types.h @@ -58,6 +58,13 @@ extern "C" tosa_datatype_fp64_t = 99 }; + enum tosa_acc_size_t + { + tosa_acc_size_int32_t = 0, + tosa_acc_size_fp16_t = 1, + tosa_acc_size_fp32_t = 2 + }; + struct tosa_tensor_t { const char* name; diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index 9b3721b..14065ad 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -95,6 +95,21 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) } } +tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size) +{ + switch (acc_size) + { + case tosa_acc_size_int32_t: + return tosa::DType::DType_INT32; + case tosa_acc_size_fp16_t: + return tosa::DType::DType_FP16; + case tosa_acc_size_fp32_t: + return tosa::DType::DType_FP32; + default: + return tosa::DType::DType_UNKNOWN; + } +} + } // namespace extern "C" @@ -138,6 +153,7 @@ extern "C" const int32_t client_kernel[2], const int32_t client_stride[2], const int32_t client_pad[4], + const tosa_acc_size_t client_acc_size, const int32_t client_input_zp, const int32_t client_output_zp, tosa_tensor_t client_output, @@ -147,7 +163,7 @@ extern "C" const std::vector pad(&client_pad[0], &client_pad[4]); const std::vector kernel(&client_kernel[0], &client_kernel[2]); const std::vector stride(&client_stride[0], &client_stride[2]); - const tosa::DType accum_dtype = tosa::DType::DType_FP32; + const tosa::DType accum_dtype = translate_client_acc_size(client_acc_size); TosaPoolAttribute attr(pad, kernel, stride, client_input_zp, client_output_zp, accum_dtype); // Create tensors diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp index 820ed63..7cf9d68 100644 --- a/reference_model/test/model_runner_tests.cpp +++ b/reference_model/test/model_runner_tests.cpp @@ -111,8 +111,10 @@ TEST_SUITE("model_runner") output.data = reinterpret_cast(dstData.data()); output.size = dstData.size() * sizeof(float); + tosa_acc_size_t acc_size = tosa_acc_size_fp32_t; + // Execution - auto status = tosa_run_avg_pool2d(input, kernel, stride, pad, 0, 0, output, {}); + auto status = tosa_run_avg_pool2d(input, kernel, stride, pad, acc_size, 0, 0, output, {}); CHECK((status == tosa_status_valid)); // Compare results diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index c5c762d..31ee151 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -102,6 +102,9 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs): if attType in serTosaTypeMap.keys(): init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});" # Initialize Serialization library attributes to their matching function parameter + elif tosaOpName == "avg_pool2d" and attName == "accum_dtype": + init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);" + att["dType"] = "tosa::DType" elif attName in tosaArgsDict: if att["SV"] == "V": if tosaArgsDict[attName]["type"] == "tosa_tensor_t": @@ -270,10 +273,15 @@ def getTosaArgs(opXml): argsXml = opXml.getElementsByTagName("argument") tosaTensorTypes = getTosaArgTypes(tosaXml) tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"} + tensorElemTypeMap = { + "resize_mode_t": "tosa_mode_t", + "acc_size_t": "tosa_acc_size_t", + } for xmlArg in argsXml: argName = xmlArg.getAttribute("name").lower() - if xmlArg.getAttribute("tensor-element-type") == "resize_mode_t": - argType = "tosa_mode_t" + tensorElemType = xmlArg.getAttribute("tensor-element-type") + if tensorElemType in tensorElemTypeMap: + argType = tensorElemTypeMap[tensorElemType] else: argType = xmlArg.getAttribute("type") argShape = xmlArg.getAttribute("shape") diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index 6b6f864..1de103a 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -92,6 +92,19 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) { } } +tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size) { + switch(acc_size) { + case tosa_acc_size_int32_t: + return tosa::DType::DType_INT32; + case tosa_acc_size_fp16_t: + return tosa::DType::DType_FP16; + case tosa_acc_size_fp32_t: + return tosa::DType::DType_FP32; + default: + return tosa::DType::DType_UNKNOWN; + } +} + } // namespace extern "C" -- cgit v1.2.1