From ed1a15b641128e38e6efc9128e7cff97b782446d Mon Sep 17 00:00:00 2001 From: Jiacheng Liang Date: Thu, 27 Jul 2023 16:50:15 +0100 Subject: Fixed missing tensor inputs in fully_connected in model runner Signed-off-by: Jiacheng Liang Change-Id: I473adc1525319b5574ee0e36d10a530277d9215d --- reference_model/include/operators.h | 2 ++ reference_model/src/operators.cc | 17 ++++++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h index b12604f..6b4ec58 100644 --- a/reference_model/include/operators.h +++ b/reference_model/include/operators.h @@ -109,6 +109,8 @@ extern "C" tosa_tensor_t client_output); tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input, + tosa_tensor_t client_weight, + tosa_tensor_t client_bias, const int32_t client_input_zp, const int32_t client_weight_zp, tosa_tensor_t client_output); diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index f3023dd..1ae0683 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -324,6 +324,8 @@ extern "C" } tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input, + tosa_tensor_t client_weight, + tosa_tensor_t client_bias, const int32_t client_input_zp, const int32_t client_weight_zp, tosa_tensor_t client_output) @@ -335,21 +337,26 @@ extern "C" // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); + tosa::TosaSerializationTensor* weight = translate_client_tensor(client_weight, "weight"); + tosa::TosaSerializationTensor* bias = translate_client_tensor(client_bias, "bias"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_FULLY_CONNECTED, - tosa::Attribute::Attribute_FullyConnectedAttribute, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator( + tosa::Op::Op_FULLY_CONNECTED, tosa::Attribute::Attribute_FullyConnectedAttribute, &attr, + { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, output }, - { input->GetName() }, { output->GetName() }); + tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, weight, bias, output }, + { input->GetName(), weight->GetName(), bias->GetName() }, + { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner; TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size)); + TOSA_RETURN_ON_ERROR(runner.setInput(weight->GetName(), client_weight.data, client_weight.size)); + TOSA_RETURN_ON_ERROR(runner.setInput(bias->GetName(), client_bias.data, client_bias.size)); // Execute TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); -- cgit v1.2.1