aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/operators.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r--reference_model/src/operators.cc14
1 files changed, 6 insertions, 8 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index 842847e..9c7f9ef 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -465,21 +465,19 @@ extern "C"
tosa_status_t tosa_run_transpose_conv2d(tosa_tensor_t client_input,
tosa_tensor_t client_weight,
tosa_tensor_t client_bias,
+ const int32_t client_out_pad[4],
const int32_t client_stride[2],
+ const int32_t client_out_shape[4],
const int32_t client_input_zp,
const int32_t client_weight_zp,
- const int32_t client_pad_len,
- const int32_t client_pad[],
- const int32_t client_dilation_len,
- const int32_t client_dilation[],
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
// Create operator attributes
- const std::vector<int32_t> pad(&client_pad[0], &client_pad[0] + client_pad_len);
+ const std::vector<int32_t> out_pad(&client_out_pad[0], &client_out_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
- const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[0] + client_dilation_len);
- TosaConvAttribute attr(pad, stride, dilation, client_input_zp, client_weight_zp);
+ const std::vector<int32_t> out_shape(&client_out_shape[0], &client_out_shape[4]);
+ TosaTransposeConvAttribute attr(out_pad, stride, out_shape, client_input_zp, client_weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -489,7 +487,7 @@ extern "C"
// Create operator
auto op = new tosa::TosaSerializationOperator(
- tosa::Op::Op_TRANSPOSE_CONV2D, tosa::Attribute::Attribute_ConvAttribute, &attr,
+ tosa::Op::Op_TRANSPOSE_CONV2D, tosa::Attribute::Attribute_TransposeConvAttribute, &attr,
{ input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() });
// Create a tosa single-op basic block