aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/operators.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r--reference_model/src/operators.cc18
1 files changed, 17 insertions, 1 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 9b3721b..14065ad 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -95,6 +95,21 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode)
}
}
+tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size)
+{
+ switch (acc_size)
+ {
+ case tosa_acc_size_int32_t:
+ return tosa::DType::DType_INT32;
+ case tosa_acc_size_fp16_t:
+ return tosa::DType::DType_FP16;
+ case tosa_acc_size_fp32_t:
+ return tosa::DType::DType_FP32;
+ default:
+ return tosa::DType::DType_UNKNOWN;
+ }
+}
+
} // namespace
extern "C"
@@ -138,6 +153,7 @@ extern "C"
const int32_t client_kernel[2],
const int32_t client_stride[2],
const int32_t client_pad[4],
+ const tosa_acc_size_t client_acc_size,
const int32_t client_input_zp,
const int32_t client_output_zp,
tosa_tensor_t client_output,
@@ -147,7 +163,7 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> kernel(&client_kernel[0], &client_kernel[2]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
- const tosa::DType accum_dtype = tosa::DType::DType_FP32;
+ const tosa::DType accum_dtype = translate_client_acc_size(client_acc_size);
TosaPoolAttribute attr(pad, kernel, stride, client_input_zp, client_output_zp, accum_dtype);
// Create tensors