diff options
author | James Ward <james.ward@arm.com> | 2023-01-18 14:51:25 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-02-02 22:03:25 +0000 |
commit | d34b3fc5eeef48ecc781a02433ce022a28e3373c (patch) | |
tree | 13aa36aa89c618e56eb2f51915a172ff8e4276d9 /reference_model/src/operators.cc | |
parent | 512c1caa8b6d494de81f3ac83a6ebb96e1e0f8e0 (diff) | |
download | reference_model-d34b3fc5eeef48ecc781a02433ce022a28e3373c.tar.gz |
Remove accumulator attributes from all but AVG_POOL2D
Signed-off-by: James Ward <james.ward@arm.com>
Change-Id: If67f503a1848967bc1671646c3011d055b622c52
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r-- | reference_model/src/operators.cc | 42 |
1 files changed, 18 insertions, 24 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index af348ca..a627322 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -168,10 +168,9 @@ extern "C" const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]); const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]); const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]); - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -219,10 +218,9 @@ extern "C" const std::vector<int32_t> pad(&client_pad[0], &client_pad[6]); const std::vector<int32_t> stride(&client_stride[0], &client_stride[3]); const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[3]); - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -270,10 +268,9 @@ extern "C" const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]); const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]); const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]); - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -313,10 +310,9 @@ extern "C" tosa_tensor_t client_output) { // Create operator attributes - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaFullyConnectedAttribute attr(input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaFullyConnectedAttribute attr(input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -352,10 +348,9 @@ extern "C" tosa_tensor_t client_output) { // Create operator attributes - const int32_t a_zp = client_a_zp; - const int32_t b_zp = client_b_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaMatMulAttribute attr(a_zp, b_zp, accum_dtype); + const int32_t a_zp = client_a_zp; + const int32_t b_zp = client_b_zp; + TosaMatMulAttribute attr(a_zp, b_zp); // Create tensors tosa::TosaSerializationTensor* a = translate_client_tensor(client_a, "a"); @@ -446,10 +441,9 @@ extern "C" const std::vector<int32_t> pad(&client_pad[0], &client_pad[0] + client_pad_len); const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]); const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[0] + client_dilation_len); - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); |