aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2023-09-21 11:05:58 +0100
committerEric Kunze <eric.kunze@arm.com>2023-11-28 16:42:24 -0800
commit0c0a263bf6742e943bebd42ccf97dcdbd8f4e1c8 (patch)
tree19c1d48d61b1774ccaf94b566900d17e9dbc0fb7
parent4762564da970eb1883a54aa66582e05c0dbd2b81 (diff)
downloadreference_model-0c0a263bf6742e943bebd42ccf97dcdbd8f4e1c8.tar.gz
Pass parameter acc_size to AvgPool2d operator
Signed-off-by: Dmitrii Agibov <dmitrii.agibov@arm.com> Change-Id: I4cd818af0db5e6e8a96641246cd3263ba4878f56
-rw-r--r--reference_model/include/operators.h1
-rw-r--r--reference_model/include/types.h7
-rw-r--r--reference_model/src/operators.cc18
-rw-r--r--reference_model/test/model_runner_tests.cpp4
-rw-r--r--scripts/operator_api/generate_api.py12
-rw-r--r--scripts/operator_api/templates/operators_cc.j213
6 files changed, 51 insertions, 4 deletions
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index 08da277..1399233 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -46,6 +46,7 @@ extern "C"
const int32_t client_kernel[2],
const int32_t client_stride[2],
const int32_t client_pad[4],
+ const tosa_acc_size_t client_acc_size,
const int32_t client_input_zp,
const int32_t client_output_zp,
tosa_tensor_t client_output,
diff --git a/reference_model/include/types.h b/reference_model/include/types.h
index 42040bf..a371d04 100644
--- a/reference_model/include/types.h
+++ b/reference_model/include/types.h
@@ -58,6 +58,13 @@ extern "C"
tosa_datatype_fp64_t = 99
};
+ enum tosa_acc_size_t
+ {
+ tosa_acc_size_int32_t = 0,
+ tosa_acc_size_fp16_t = 1,
+ tosa_acc_size_fp32_t = 2
+ };
+
struct tosa_tensor_t
{
const char* name;
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 9b3721b..14065ad 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -95,6 +95,21 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode)
}
}
+tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size)
+{
+ switch (acc_size)
+ {
+ case tosa_acc_size_int32_t:
+ return tosa::DType::DType_INT32;
+ case tosa_acc_size_fp16_t:
+ return tosa::DType::DType_FP16;
+ case tosa_acc_size_fp32_t:
+ return tosa::DType::DType_FP32;
+ default:
+ return tosa::DType::DType_UNKNOWN;
+ }
+}
+
} // namespace
extern "C"
@@ -138,6 +153,7 @@ extern "C"
const int32_t client_kernel[2],
const int32_t client_stride[2],
const int32_t client_pad[4],
+ const tosa_acc_size_t client_acc_size,
const int32_t client_input_zp,
const int32_t client_output_zp,
tosa_tensor_t client_output,
@@ -147,7 +163,7 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> kernel(&client_kernel[0], &client_kernel[2]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
- const tosa::DType accum_dtype = tosa::DType::DType_FP32;
+ const tosa::DType accum_dtype = translate_client_acc_size(client_acc_size);
TosaPoolAttribute attr(pad, kernel, stride, client_input_zp, client_output_zp, accum_dtype);
// Create tensors
diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp
index 820ed63..7cf9d68 100644
--- a/reference_model/test/model_runner_tests.cpp
+++ b/reference_model/test/model_runner_tests.cpp
@@ -111,8 +111,10 @@ TEST_SUITE("model_runner")
output.data = reinterpret_cast<uint8_t*>(dstData.data());
output.size = dstData.size() * sizeof(float);
+ tosa_acc_size_t acc_size = tosa_acc_size_fp32_t;
+
// Execution
- auto status = tosa_run_avg_pool2d(input, kernel, stride, pad, 0, 0, output, {});
+ auto status = tosa_run_avg_pool2d(input, kernel, stride, pad, acc_size, 0, 0, output, {});
CHECK((status == tosa_status_valid));
// Compare results
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index c5c762d..31ee151 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -102,6 +102,9 @@ def getSerialLibAttsForOp(tosaOpName, allSerialLibAtts, tosaArgs):
if attType in serTosaTypeMap.keys():
init = f"const {attType} {attName} = translate_client_{serTosaTypeMap[att['dType']]}(client_{attName});"
# Initialize Serialization library attributes to their matching function parameter
+ elif tosaOpName == "avg_pool2d" and attName == "accum_dtype":
+ init = f"const tosa::DType {attName} = translate_client_acc_size(client_acc_size);"
+ att["dType"] = "tosa::DType"
elif attName in tosaArgsDict:
if att["SV"] == "V":
if tosaArgsDict[attName]["type"] == "tosa_tensor_t":
@@ -270,10 +273,15 @@ def getTosaArgs(opXml):
argsXml = opXml.getElementsByTagName("argument")
tosaTensorTypes = getTosaArgTypes(tosaXml)
tosaTypeMap = {"bool_t": "bool", "uint6_t": "uint8_t", "mode_t": "tosa_mode_t"}
+ tensorElemTypeMap = {
+ "resize_mode_t": "tosa_mode_t",
+ "acc_size_t": "tosa_acc_size_t",
+ }
for xmlArg in argsXml:
argName = xmlArg.getAttribute("name").lower()
- if xmlArg.getAttribute("tensor-element-type") == "resize_mode_t":
- argType = "tosa_mode_t"
+ tensorElemType = xmlArg.getAttribute("tensor-element-type")
+ if tensorElemType in tensorElemTypeMap:
+ argType = tensorElemTypeMap[tensorElemType]
else:
argType = xmlArg.getAttribute("type")
argShape = xmlArg.getAttribute("shape")
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index 6b6f864..1de103a 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -92,6 +92,19 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) {
}
}
+tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size) {
+ switch(acc_size) {
+ case tosa_acc_size_int32_t:
+ return tosa::DType::DType_INT32;
+ case tosa_acc_size_fp16_t:
+ return tosa::DType::DType_FP16;
+ case tosa_acc_size_fp32_t:
+ return tosa::DType::DType_FP32;
+ default:
+ return tosa::DType::DType_UNKNOWN;
+ }
+}
+
} // namespace
extern "C"