aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/operators.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r--reference_model/src/operators.cc20
1 files changed, 9 insertions, 11 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index ecebe52..842847e 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -2073,34 +2073,32 @@ extern "C"
tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1,
tosa_tensor_t client_shape,
- const int32_t client_new_shape_len,
- const int32_t client_new_shape[],
tosa_tensor_t client_output,
const func_ctx_t& func_ctx)
{
// Create operator attributes
- const std::vector<int32_t> new_shape(&client_new_shape[0], &client_new_shape[0] + client_new_shape_len);
- TosaReshapeAttribute attr(new_shape);
+ std::vector<int32_t> shape;
+ size_t shape_size = client_shape.size / sizeof(int32_t);
+ int32_t* shape_data = reinterpret_cast<int32_t*>(client_shape.data);
+ shape.assign(shape_data, shape_data + shape_size);
+ TosaReshapeAttribute attr(shape);
// Create tensors
tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1");
- tosa::TosaSerializationTensor* shape = translate_client_tensor(client_shape, "shape");
tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output");
// Create operator
- auto op =
- new tosa::TosaSerializationOperator(tosa::Op::Op_RESHAPE, tosa::Attribute::Attribute_ReshapeAttribute,
- &attr, { input1->GetName(), shape->GetName() }, { output->GetName() });
+ auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_RESHAPE, tosa::Attribute::Attribute_ReshapeAttribute,
+ &attr, { input1->GetName() }, { output->GetName() });
// Create a tosa single-op basic block
- tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, shape, output },
- { input1->GetName(), shape->GetName() }, { output->GetName() });
+ tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, output }, { input1->GetName() },
+ { output->GetName() });
// Setup model
TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size));
- TOSA_RETURN_ON_ERROR(runner.setInput(shape->GetName(), client_shape.data, client_shape.size));
// Execute
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run());