diff options
author | Jiacheng Liang <jiacheng.liang@arm.com> | 2023-07-27 16:50:15 +0100 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-07-31 15:38:01 +0000 |
commit | ed1a15b641128e38e6efc9128e7cff97b782446d (patch) | |
tree | b4faaa1b57b4c07fc7b34fba44577d6fe4f26850 | |
parent | e5cabbf7528849aac35b498ce0711a144c1a08d5 (diff) | |
download | reference_model-ed1a15b641128e38e6efc9128e7cff97b782446d.tar.gz |
Fixed missing tensor inputs in fully_connected in model runner
Signed-off-by: Jiacheng Liang <jiacheng.liang@arm.com>
Change-Id: I473adc1525319b5574ee0e36d10a530277d9215d
-rw-r--r-- | reference_model/include/operators.h | 2 | ||||
-rw-r--r-- | reference_model/src/operators.cc | 17 |
2 files changed, 14 insertions, 5 deletions
diff --git a/reference_model/include/operators.h b/reference_model/include/operators.h index b12604f..6b4ec58 100644 --- a/reference_model/include/operators.h +++ b/reference_model/include/operators.h @@ -109,6 +109,8 @@ extern "C" tosa_tensor_t client_output); tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input, + tosa_tensor_t client_weight, + tosa_tensor_t client_bias, const int32_t client_input_zp, const int32_t client_weight_zp, tosa_tensor_t client_output); diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index f3023dd..1ae0683 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -324,6 +324,8 @@ extern "C" } tosa_status_t tosa_run_fully_connected(tosa_tensor_t client_input, + tosa_tensor_t client_weight, + tosa_tensor_t client_bias, const int32_t client_input_zp, const int32_t client_weight_zp, tosa_tensor_t client_output) @@ -335,21 +337,26 @@ extern "C" // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); + tosa::TosaSerializationTensor* weight = translate_client_tensor(client_weight, "weight"); + tosa::TosaSerializationTensor* bias = translate_client_tensor(client_bias, "bias"); tosa::TosaSerializationTensor* output = translate_client_tensor(client_output, "output"); // Create operator - auto op = new tosa::TosaSerializationOperator(tosa::Op::Op_FULLY_CONNECTED, - tosa::Attribute::Attribute_FullyConnectedAttribute, &attr, - { input->GetName() }, { output->GetName() }); + auto op = new tosa::TosaSerializationOperator( + tosa::Op::Op_FULLY_CONNECTED, tosa::Attribute::Attribute_FullyConnectedAttribute, &attr, + { input->GetName(), weight->GetName(), bias->GetName() }, { output->GetName() }); // Create a tosa single-op basic block - tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, output }, - { input->GetName() }, { output->GetName() }); + tosa::TosaSerializationBasicBlock block("fully_connected", "main", { op }, { input, weight, bias, output }, + { input->GetName(), weight->GetName(), bias->GetName() }, + { output->GetName() }); // Setup model TosaReference::ModelRunnerImpl runner; TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); TOSA_RETURN_ON_ERROR(runner.setInput(input->GetName(), client_input.data, client_input.size)); + TOSA_RETURN_ON_ERROR(runner.setInput(weight->GetName(), client_weight.data, client_weight.size)); + TOSA_RETURN_ON_ERROR(runner.setInput(bias->GetName(), client_bias.data, client_bias.size)); // Execute TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.run()); |