From bf59c596b94603d8f534017a5f2c296d7f0efc58 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Mon, 13 Nov 2023 20:18:14 +0000 Subject: [reference_model] Add local_bound support Add support for local_bound attributes. Signed-off-by: Tai Ly Change-Id: Ie1acb65ca2495fb7d1512bf120568c695635d631 --- reference_model/include/operators.h | 4 ++++ reference_model/src/operators.cc | 13 +++++++++---- reference_model/src/ops/tensor_ops.cc | 7 ++++++- reference_model/src/ops/tensor_ops.h | 1 + reference_model/test/model_runner_tests.cpp | 14 +++++++++----- thirdparty/serialization_lib | 2 +- 6 files changed, 30 insertions(+), 11 deletions(-) diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h index 1650ea4..d2bcf87 100644 --- a/reference_model/include/operators.h +++ b/reference_model/include/operators.h @@ -60,6 +60,7 @@ extern "C" const int32_t client_dilation[2], const int32_t client_input_zp, const int32_t client_weight_zp, + const bool client_local_bound, tosa_tensor_t client_output, const func_ctx_t& func_ctx); @@ -71,6 +72,7 @@ extern "C" const int32_t client_dilation[3], const int32_t client_input_zp, const int32_t client_weight_zp, + const bool client_local_bound, tosa_tensor_t client_output, const func_ctx_t& func_ctx); @@ -82,6 +84,7 @@ extern "C" const int32_t client_dilation[2], const int32_t client_input_zp, const int32_t client_weight_zp, + const bool client_local_bound, tosa_tensor_t client_output, const func_ctx_t& func_ctx); @@ -117,6 +120,7 @@ extern "C" const int32_t client_out_shape[4], const int32_t client_input_zp, const int32_t client_weight_zp, + const bool client_local_bound, tosa_tensor_t client_output, const func_ctx_t& func_ctx); diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index 9c7f9ef..13e8b12 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -200,6 +200,7 @@ extern "C" const int32_t client_dilation[2], const int32_t client_input_zp, const int32_t client_weight_zp, + const bool client_local_bound, tosa_tensor_t client_output, const func_ctx_t& func_ctx) { @@ -207,7 +208,7 @@ extern "C" const std::vector pad(&client_pad[0], &client_pad[4]); const std::vector stride(&client_stride[0], &client_stride[2]); const std::vector dilation(&client_dilation[0], &client_dilation[2]); - TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp); + TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp, client_local_bound); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -249,6 +250,7 @@ extern "C" const int32_t client_dilation[3], const int32_t client_input_zp, const int32_t client_weight_zp, + const bool client_local_bound, tosa_tensor_t client_output, const func_ctx_t& func_ctx) { @@ -256,7 +258,7 @@ extern "C" const std::vector pad(&client_pad[0], &client_pad[6]); const std::vector stride(&client_stride[0], &client_stride[3]); const std::vector dilation(&client_dilation[0], &client_dilation[3]); - TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp); + TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp, client_local_bound); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -298,6 +300,7 @@ extern "C" const int32_t client_dilation[2], const int32_t client_input_zp, const int32_t client_weight_zp, + const bool client_local_bound, tosa_tensor_t client_output, const func_ctx_t& func_ctx) { @@ -305,7 +308,7 @@ extern "C" const std::vector pad(&client_pad[0], &client_pad[4]); const std::vector stride(&client_stride[0], &client_stride[2]); const std::vector dilation(&client_dilation[0], &client_dilation[2]); - TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp); + TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp, client_local_bound); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -470,6 +473,7 @@ extern "C" const int32_t client_out_shape[4], const int32_t client_input_zp, const int32_t client_weight_zp, + const bool client_local_bound, tosa_tensor_t client_output, const func_ctx_t& func_ctx) { @@ -477,7 +481,8 @@ extern "C" const std::vector out_pad(&client_out_pad[0], &client_out_pad[4]); const std::vector stride(&client_stride[0], &client_stride[2]); const std::vector out_shape(&client_out_shape[0], &client_out_shape[4]); - TosaTransposeConvAttribute attr(out_pad, stride, out_shape, client_input_zp, client_weight_zp); + TosaTransposeConvAttribute attr(out_pad, stride, out_shape, client_input_zp, client_weight_zp, + client_local_bound); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index d9608b7..3aa7830 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -1786,11 +1786,16 @@ OpRFFT2d::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_ { setRequiredOperands(1, 2); setRequiredRank(3, 3); + + INIT_ATTRIBUTE(RFFT); } template OpRFFT2d::~OpRFFT2d() -{} +{ + if (attribute) + delete attribute; +} template int OpRFFT2d::checkTensorAttributes() diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index f5fcd7f..e2bb811 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -291,6 +291,7 @@ protected: TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out_real; TosaReference::TensorTemplate* out_imag; + tosa::TosaRFFTAttribute* attribute; }; template diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp index 71e26c9..2037b73 100644 --- a/reference_model/test/model_runner_tests.cpp +++ b/reference_model/test/model_runner_tests.cpp @@ -170,9 +170,11 @@ TEST_SUITE("model_runner") const int32_t input_zp = 0; const int32_t weight_zp = 0; + const bool local_bound = false; // Execution - auto status = tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, output, {}); + auto status = + tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, local_bound, output, {}); CHECK((status == tosa_status_valid)); // Compare results @@ -229,10 +231,11 @@ TEST_SUITE("model_runner") const int32_t input_zp = 0; const int32_t weight_zp = 0; + const bool local_bound = false; // Execution - auto status = - tosa_run_transpose_conv2d(input, weight, bias, out_pad, stride, out_shape, input_zp, weight_zp, output, {}); + auto status = tosa_run_transpose_conv2d(input, weight, bias, out_pad, stride, out_shape, input_zp, weight_zp, + local_bound, output, {}); CHECK((status == tosa_status_valid)); // Compare results @@ -288,12 +291,13 @@ TEST_SUITE("model_runner") const int32_t input_zp = 0; const int32_t weight_zp = 0; + const bool local_bound = false; // Execution func_ctx_t func_ctx; func_ctx.func_config.abs_mode = true; - auto status = - tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, output, func_ctx); + auto status = tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, local_bound, + output, func_ctx); CHECK((status == tosa_status_valid)); // Compare results diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index 2d33385..cce787e 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit 2d3338562530f903bf03391077015914514e2274 +Subproject commit cce787e55f2f4e9fa14fd4643ddfd34de754c500 -- cgit v1.2.1