From ce53cd103cc2ac09b43b4fdf586249e626bd5627 Mon Sep 17 00:00:00 2001 From: Grant Watson Date: Tue, 31 Oct 2023 19:02:14 +0000 Subject: Fix Reshape in operator API - The API incorrectly requires the new shape to be passed in twice. - This fix changes the name of the attribute from new_shape to shape in the generate_api.py script. - Adds a unit test to verify that the reshape operator works correctly. Signed-off-by: Grant Watson Change-Id: I07dd0ef786c747896b6e54f4eada0e7b97c6cef3 --- reference_model/src/operators.cc | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) (limited to 'reference_model/src/operators.cc') diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index ecebe52..842847e 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -2073,34 +2073,32 @@ extern "C" tosa_status_t tosa_run_reshape(tosa_tensor_t client_input1, tosa_tensor_t client_shape, - const int32_t client_new_shape_len, - const int32_t client_new_shape[], tosa_tensor_t client_output, const func_ctx_t& func_ctx) { // Create operator attributes - const std::vector new_shape(&client_new_shape[0], &client_new_shape[0] + client_new_shape_len); - TosaReshapeAttribute attr(new_shape); + std::vector shape; + size_t shape_size = client_shape.size / sizeof(int32_t); + int32_t* shape_data = reinterpret_cast(client_shape.data); + shape.assign(shape_data, shape_data + shape_size); + TosaReshapeAttribute attr(shape); // Create tensors tosa::TosaSerializationTensor* input1 = translate_client_tensor(client_input1, "input1"); - tosa::TosaSerializationTensor* shape = translate_client_tensor(client_shape, "shape"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = - new tosa::TosaSerializationOperator(tosa::Op::Op_RESHAPE, tosa::Attribute::Attribute_ReshapeAttribute, - &attr, { input1->GetName(), shape->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_RESHAPE, tosa::Attribute::Attribute_ReshapeAttribute, + &attr, { input1->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, shape, output }, - { input1->GetName(), shape->GetName() }, { output->GetName() }); + tosa::TosaSerializationBasicBlock block("reshape", "main", { op }, { input1, output }, { input1->GetName() }, + { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner(func_ctx.func_config, func_ctx.func_debug); TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); TOSA_RETURN_ON_ERROR(runner.setInput(input1->GetName(), client_input1.data, client_input1.size)); - TOSA_RETURN_ON_ERROR(runner.setInput(shape->GetName(), client_shape.data, client_shape.size)); // Execute TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); -- cgit v1.2.1