diff options
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r-- | reference_model/src/operators.cc | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index 9b3721b..14065ad 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -95,6 +95,21 @@ tosa::ResizeMode translate_client_tosa_mode(tosa_mode_t mode) } } +tosa::DType translate_client_acc_size(tosa_acc_size_t acc_size) +{ + switch (acc_size) + { + case tosa_acc_size_int32_t: + return tosa::DType::DType_INT32; + case tosa_acc_size_fp16_t: + return tosa::DType::DType_FP16; + case tosa_acc_size_fp32_t: + return tosa::DType::DType_FP32; + default: + return tosa::DType::DType_UNKNOWN; + } +} + } // namespace extern "C" @@ -138,6 +153,7 @@ extern "C" const int32_t client_kernel[2], const int32_t client_stride[2], const int32_t client_pad[4], + const tosa_acc_size_t client_acc_size, const int32_t client_input_zp, const int32_t client_output_zp, tosa_tensor_t client_output, @@ -147,7 +163,7 @@ extern "C" const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]); const std::vector<int32_t> kernel(&client_kernel[0], &client_kernel[2]); const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]); - const tosa::DType accum_dtype = tosa::DType::DType_FP32; + const tosa::DType accum_dtype = translate_client_acc_size(client_acc_size); TosaPoolAttribute attr(pad, kernel, stride, client_input_zp, client_output_zp, accum_dtype); // Create tensors |