diff options
Diffstat (limited to 'reference_model/src/operators.cc')
-rw-r--r-- | reference_model/src/operators.cc | 20 |
1 files changed, 9 insertions, 11 deletions
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index ecebe52..842847e 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -2073,34 +2073,32 @@ extern "C" tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1, tosa_tensor_t client_shape, - const int32_t client_new_shape_len, - const int32_t client_new_shape[], tosa_tensor_t client_output, const func_ctx_t& func_ctx) { // Create operator attributes - const std::vector<int32_t> new_shape(&client_new_shape[0], &client_new_shape[0] + client_new_shape_len); - TosaReshapeAttribute attr(new_shape); + std::vector<int32_t> shape; + size_t shape_size = client_shape.size / sizeof(int32_t); + int32_t* shape_data = reinterpret_cast<int32_t*>(client_shape.data); + shape.assign(shape_data, shape_data + shape_size); + TosaReshapeAttribute attr(shape); // Create tensors tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1"); - tosa::TosaSerializationTensor* shape = translate_client_tensor(client_shape, "shape"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = - new tosa::TosaSerializationOperator(tosa::Op::Op_RESHAPE, tosa::Attribute::Attribute_ReshapeAttribute, - &attr, { input1->GetName(), shape->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_RESHAPE, tosa::Attribute::Attribute_ReshapeAttribute, + &attr, { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, shape, output }, - { input1->GetName(), shape->GetName() }, { output->GetName() }); + tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, output }, { input1->GetName() }, + { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug); TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size)); - TOSA_RETURN_ON_ERROR(runner.setInput(shape->GetName(), client_shape.data, client_shape.size)); // Execute TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); |