aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/operators.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-11-13 20:18:14 +0000
committerEric Kunze <eric.kunze@arm.com>2023-11-29 05:19:59 +0000
commitbf59c596b94603d8f534017a5f2c296d7f0efc58 (patch)
treeb9b51a3175305a8d4e385f18c712c5112e18718e /reference_model/src/operators.cc
parent614a911397f138eee6a108d802c0d5c7251a1897 (diff)
downloadreference_model-bf59c596b94603d8f534017a5f2c296d7f0efc58.tar.gz
[reference_model] Add local_bound support
Add support for local_bound attributes. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ie1acb65ca2495fb7d1512bf120568c695635d631
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r--reference_model/src/operators.cc13
1 files changed, 9 insertions, 4 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 9c7f9ef..13e8b12 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -200,6 +200,7 @@ extern "C"
const int32_t client_dilation[2],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
@@ -207,7 +208,7 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]);
- TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp, client_local_bound);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -249,6 +250,7 @@ extern "C"
const int32_t client_dilation[3],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
@@ -256,7 +258,7 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[6]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[3]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[3]);
- TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp, client_local_bound);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -298,6 +300,7 @@ extern "C"
const int32_t client_dilation[2],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
@@ -305,7 +308,7 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]);
- TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
+ TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp, client_local_bound);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -470,6 +473,7 @@ extern "C"
const int32_t client_out_shape[4],
const int32_t client_input_zp,
const int32_t client_weight_zp,
+ const bool client_local_bound,
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
@@ -477,7 +481,8 @@ extern "C"
const std::vector<int32_t> out_pad(&client_out_pad[0], &client_out_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> out_shape(&client_out_shape[0], &client_out_shape[4]);
- TosaTransposeConvAttribute attr(out_pad, stride, out_shape, client_input_zp, client_weight_zp);
+ TosaTransposeConvAttribute attr(out_pad, stride, out_shape, client_input_zp, client_weight_zp,
+ client_local_bound);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");