aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-11-13 20:18:14 +0000
committerTai Ly <tai.ly@arm.com>2023-11-14 23:10:30 +0000
commitfd8fde80452ba68a21de4de53517ebc4b4aac9ea (patch)
tree42b46a8b185ac984a12f5e834f15d909439dc812
parentaee62afba99a74f772b97356fd4c18f3fdf37073 (diff)
downloadreference_model-fd8fde80452ba68a21de4de53517ebc4b4aac9ea.tar.gz
[reference_model] Add local_bound support
Add support for local_bound attributes. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ie1acb65ca2495fb7d1512bf120568c695635d631
-rw-r--r--reference_model/include/operators.h4
-rw-r--r--reference_model/src/operators.cc13
-rw-r--r--reference_model/src/ops/tensor_ops.cc7
-rw-r--r--reference_model/src/ops/tensor_ops.h1
-rw-r--r--reference_model/test/model_runner_tests.cpp14
m---------thirdparty/serialization_lib0
6 files changed, 29 insertions, 10 deletions
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h
index 1650ea4..d2bcf87 100644
--- a/reference_model/include/operators.h
+++ b/reference_model/include/operators.h
@@ -60,6 +60,7 @@ extern "C"
const int32_t client_dilation[2],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx);
@@ -71,6 +72,7 @@ extern "C"
const int32_t client_dilation[3],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx);
@@ -82,6 +84,7 @@ extern "C"
const int32_t client_dilation[2],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx);
@@ -117,6 +120,7 @@ extern "C"
const int32_t client_out_shape[4],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx);
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 9c7f9ef..13e8b12 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -200,6 +200,7 @@ extern "C"
const int32_t client_dilation[2],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
@@ -207,7 +208,7 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]);
- TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp, client_local_bound);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -249,6 +250,7 @@ extern "C"
const int32_t client_dilation[3],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
@@ -256,7 +258,7 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[6]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[3]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[3]);
- TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp, client_local_bound);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -298,6 +300,7 @@ extern "C"
const int32_t client_dilation[2],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
@@ -305,7 +308,7 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]);
- TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp, client_local_bound);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -470,6 +473,7 @@ extern "C"
const int32_t client_out_shape[4],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
@@ -477,7 +481,8 @@ extern "C"
const std::vector<int32_t> out_pad(&client_out_pad[0], &client_out_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> out_shape(&client_out_shape[0], &client_out_shape[4]);
- TosaTransposeConvAttribute attr(out_pad, stride, out_shape, client_input_zp, client_weight_zp);
+ TosaTransposeConvAttribute attr(out_pad, stride, out_shape, client_input_zp, client_weight_zp,
+ client_local_bound);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 2d54d8e..3f0e7b2 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1736,11 +1736,16 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_
{
setRequiredOperands(1, 2);
setRequiredRank(3, 3);
+
+ INIT_ATTRIBUTE(RFFT);
}
template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::~OpRFFT2d()
-{}
+{
+ if (attribute)
+ delete attribute;
+}
template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::checkTensorAttributes()
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index f5fcd7f..e2bb811 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -291,6 +291,7 @@ protected:
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out_real;
TosaReference::TensorTemplate<TOut>* out_imag;
+ tosa::TosaRFFTAttribute* attribute;
};
template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp
index 71e26c9..2037b73 100644
--- a/reference_model/test/model_runner_tests.cpp
+++ b/reference_model/test/model_runner_tests.cpp
@@ -170,9 +170,11 @@ TEST_SUITE("model_runner")
const int32_t input_zp = 0;
const int32_t weight_zp = 0;
+ const bool local_bound = false;
// Execution
- auto status = tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, output, {});
+ auto status =
+ tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, local_bound, output, {});
CHECK((status == tosa_status_valid));
// Compare results
@@ -229,10 +231,11 @@ TEST_SUITE("model_runner")
const int32_t input_zp = 0;
const int32_t weight_zp = 0;
+ const bool local_bound = false;
// Execution
- auto status =
- tosa_run_transpose_conv2d(input, weight, bias, out_pad, stride, out_shape, input_zp, weight_zp, output, {});
+ auto status = tosa_run_transpose_conv2d(input, weight, bias, out_pad, stride, out_shape, input_zp, weight_zp,
+ local_bound, output, {});
CHECK((status == tosa_status_valid));
// Compare results
@@ -288,12 +291,13 @@ TEST_SUITE("model_runner")
const int32_t input_zp = 0;
const int32_t weight_zp = 0;
+ const bool local_bound = false;
// Execution
func_ctx_t func_ctx;
func_ctx.func_config.abs_mode = true;
- auto status =
- tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, output, func_ctx);
+ auto status = tosa_run_conv2d(input, weight, bias, pad, stride, dilation, input_zp, weight_zp, local_bound,
+ output, func_ctx);
CHECK((status == tosa_status_valid));
// Compare results
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject c0a60300951a59c33d2afaea0f6ca0889cabf34
+Subproject 4881c29247d4b411de446b13d9bd58ea93737aa